mirror of
https://github.com/apache/superset.git
synced 2026-05-01 14:04:21 +00:00
Compare commits
1 Commits
semantic-l
...
rls-sqlglo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e7b9f84aa |
@@ -172,7 +172,9 @@ class DatasourceKind(StrEnum):
|
|||||||
PHYSICAL = "physical"
|
PHYSICAL = "physical"
|
||||||
|
|
||||||
|
|
||||||
class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods
|
class BaseDatasource(
|
||||||
|
AuditMixinNullable, ImportExportMixin
|
||||||
|
): # pylint: disable=too-many-public-methods
|
||||||
"""A common interface to objects that are queryable
|
"""A common interface to objects that are queryable
|
||||||
(tables and datasources)"""
|
(tables and datasources)"""
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ CONNECTION_HOST_DOWN_REGEX = re.compile(
|
|||||||
class MssqlEngineSpec(BaseEngineSpec):
|
class MssqlEngineSpec(BaseEngineSpec):
|
||||||
engine = "mssql"
|
engine = "mssql"
|
||||||
engine_name = "Microsoft SQL Server"
|
engine_name = "Microsoft SQL Server"
|
||||||
limit_method = LimitMethod.WRAP_SQL
|
limit_method = LimitMethod.FORCE_LIMIT
|
||||||
max_column_name_length = 128
|
max_column_name_length = 128
|
||||||
allows_cte_in_subquery = False
|
allows_cte_in_subquery = False
|
||||||
allow_limit_clause = False
|
allow_limit_clause = False
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class TeradataEngineSpec(BaseEngineSpec):
|
|||||||
|
|
||||||
engine = "teradatasql"
|
engine = "teradatasql"
|
||||||
engine_name = "Teradata"
|
engine_name = "Teradata"
|
||||||
limit_method = LimitMethod.WRAP_SQL
|
limit_method = LimitMethod.FORCE_LIMIT
|
||||||
max_column_name_length = 30 # since 14.10 this is 128
|
max_column_name_length = 30 # since 14.10 this is 128
|
||||||
allow_limit_clause = False
|
allow_limit_clause = False
|
||||||
select_keywords = {"SELECT", "SEL"}
|
select_keywords = {"SELECT", "SEL"}
|
||||||
|
|||||||
@@ -17,6 +17,8 @@
|
|||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""a collection of model-related helper classes and functions"""
|
"""a collection of model-related helper classes and functions"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
@@ -806,7 +808,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||||||
|
|
||||||
def get_sqla_row_level_filters(
|
def get_sqla_row_level_filters(
|
||||||
self,
|
self,
|
||||||
template_processor: Optional[BaseTemplateProcessor] = None,
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> list[TextClause]:
|
) -> list[TextClause]:
|
||||||
"""
|
"""
|
||||||
Return the appropriate row level security filters for this table and the
|
Return the appropriate row level security filters for this table and the
|
||||||
|
|||||||
@@ -53,9 +53,8 @@ from superset.models.sql_lab import Query
|
|||||||
from superset.result_set import SupersetResultSet
|
from superset.result_set import SupersetResultSet
|
||||||
from superset.sql_parse import (
|
from superset.sql_parse import (
|
||||||
CtasMethod,
|
CtasMethod,
|
||||||
insert_rls_as_subquery,
|
|
||||||
insert_rls_in_predicate,
|
|
||||||
ParsedQuery,
|
ParsedQuery,
|
||||||
|
SQLStatement,
|
||||||
Table,
|
Table,
|
||||||
)
|
)
|
||||||
from superset.sqllab.limiting_factor import LimitingFactor
|
from superset.sqllab.limiting_factor import LimitingFactor
|
||||||
@@ -205,67 +204,49 @@ def execute_sql_statement( # pylint: disable=too-many-statements
|
|||||||
database: Database = query.database
|
database: Database = query.database
|
||||||
db_engine_spec = database.db_engine_spec
|
db_engine_spec = database.db_engine_spec
|
||||||
|
|
||||||
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
|
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
|
||||||
|
|
||||||
if is_feature_enabled("RLS_IN_SQLLAB"):
|
if is_feature_enabled("RLS_IN_SQLLAB"):
|
||||||
# There are two ways to insert RLS: either replacing the table with a subquery
|
default_schema = database.get_default_schema_for_query(query)
|
||||||
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
|
parsed_statement = parsed_statement.apply_rls(query.catalog, default_schema)
|
||||||
# safer, but not supported in all databases.
|
|
||||||
insert_rls = (
|
|
||||||
insert_rls_as_subquery
|
|
||||||
if database.db_engine_spec.allows_subqueries
|
|
||||||
and database.db_engine_spec.allows_alias_in_select
|
|
||||||
else insert_rls_in_predicate
|
|
||||||
)
|
|
||||||
|
|
||||||
# Insert any applicable RLS predicates
|
if parsed_statement.is_dml() and not database.allow_dml:
|
||||||
parsed_query = ParsedQuery(
|
|
||||||
str(
|
|
||||||
insert_rls(
|
|
||||||
parsed_query._parsed[0], # pylint: disable=protected-access
|
|
||||||
database.id,
|
|
||||||
query.schema,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
engine=db_engine_spec.engine,
|
|
||||||
)
|
|
||||||
|
|
||||||
sql = parsed_query.stripped()
|
|
||||||
|
|
||||||
# This is a test to see if the query is being
|
|
||||||
# limited by either the dropdown or the sql.
|
|
||||||
# We are testing to see if more rows exist than the limit.
|
|
||||||
increased_limit = None if query.limit is None else query.limit + 1
|
|
||||||
|
|
||||||
if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
|
|
||||||
raise SupersetErrorException(
|
raise SupersetErrorException(
|
||||||
SupersetError(
|
SupersetError(
|
||||||
message=__("Only SELECT statements are allowed against this database."),
|
message=__("DML statements are not allowed in this database."),
|
||||||
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
|
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
|
||||||
level=ErrorLevel.ERROR,
|
level=ErrorLevel.ERROR,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if apply_ctas:
|
if apply_ctas:
|
||||||
if not query.tmp_table_name:
|
if not query.tmp_table_name:
|
||||||
start_dttm = datetime.fromtimestamp(query.start_time)
|
start_dttm = datetime.fromtimestamp(query.start_time)
|
||||||
query.tmp_table_name = (
|
query.tmp_table_name = (
|
||||||
f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
|
f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
|
||||||
)
|
)
|
||||||
sql = parsed_query.as_create_table(
|
parsed_statement = parsed_statement.as_create_table(
|
||||||
query.tmp_table_name,
|
Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
|
||||||
schema_name=query.tmp_schema_name,
|
|
||||||
method=query.ctas_method,
|
method=query.ctas_method,
|
||||||
)
|
)
|
||||||
query.select_as_cta_used = True
|
query.select_as_cta_used = True
|
||||||
|
|
||||||
|
increased_limit = None if query.limit is None else query.limit + 1
|
||||||
|
|
||||||
# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
|
# Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
|
||||||
if db_engine_spec.is_select_query(parsed_query) and not (
|
if parsed_statement.is_select() and not (
|
||||||
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
|
query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
|
||||||
):
|
):
|
||||||
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
|
if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
|
||||||
query.limit = SQL_MAX_ROW
|
query.limit = SQL_MAX_ROW
|
||||||
sql = apply_limit_if_exists(database, increased_limit, query, sql)
|
|
||||||
|
if query.limit:
|
||||||
|
# Increase limit by one so we can test if there are more rows when the
|
||||||
|
# database returns exactly the number of rows requested by the user.
|
||||||
|
parsed_statement = parsed_statement.apply_limit(increased_limit)
|
||||||
|
|
||||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||||
|
sql = parsed_statement.format(strip=True)
|
||||||
sql = database.mutate_sql_based_on_config(sql)
|
sql = database.mutate_sql_based_on_config(sql)
|
||||||
try:
|
try:
|
||||||
query.executed_sql = sql
|
query.executed_sql = sql
|
||||||
@@ -333,19 +314,6 @@ def execute_sql_statement( # pylint: disable=too-many-statements
|
|||||||
return SupersetResultSet(data, cursor_description, db_engine_spec)
|
return SupersetResultSet(data, cursor_description, db_engine_spec)
|
||||||
|
|
||||||
|
|
||||||
def apply_limit_if_exists(
|
|
||||||
database: Database, increased_limit: Optional[int], query: Query, sql: str
|
|
||||||
) -> str:
|
|
||||||
if query.limit and increased_limit:
|
|
||||||
# We are fetching one more than the requested limit in order
|
|
||||||
# to test whether there are more rows than the limit. According to the DB
|
|
||||||
# Engine support it will choose top or limit parse
|
|
||||||
# Later, the extra row will be dropped before sending
|
|
||||||
# the results back to the user.
|
|
||||||
sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
|
|
||||||
return sql
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_payload(
|
def _serialize_payload(
|
||||||
payload: dict[Any, Any], use_msgpack: Optional[bool] = False
|
payload: dict[Any, Any], use_msgpack: Optional[bool] = False
|
||||||
) -> Union[bytes, str]:
|
) -> Union[bytes, str]:
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import sqlparse
|
|||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
from jinja2 import nodes
|
from jinja2 import nodes
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
from sqlglot import exp, parse, parse_one
|
from sqlglot import exp
|
||||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||||
from sqlglot.errors import ParseError, SqlglotError
|
from sqlglot.errors import ParseError, SqlglotError
|
||||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||||
@@ -308,7 +308,7 @@ def extract_tables_from_statement(
|
|||||||
return set()
|
return set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
|
pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
|
||||||
except ParseError:
|
except ParseError:
|
||||||
return set()
|
return set()
|
||||||
sources = pseudo_query.find_all(exp.Table)
|
sources = pseudo_query.find_all(exp.Table)
|
||||||
@@ -433,7 +433,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
def format(self, comments: bool = True, strip: bool = False) -> str:
|
||||||
"""
|
"""
|
||||||
Format the statement, optionally omitting comments.
|
Format the statement, optionally omitting comments.
|
||||||
"""
|
"""
|
||||||
@@ -451,10 +451,93 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def apply_rls(
|
||||||
|
self,
|
||||||
|
catalog: str | None,
|
||||||
|
schema: str | None,
|
||||||
|
) -> InternalRepresentation:
|
||||||
|
"""
|
||||||
|
Apply Row Level Security to the SQL.
|
||||||
|
|
||||||
|
:param database: The database where the SQL will run
|
||||||
|
:param catalog: The default catalog for non-qualified table names
|
||||||
|
:param schema: The default schema for non-qualified table names
|
||||||
|
:return: The SQL with RLS applied
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.format()
|
return self.format()
|
||||||
|
|
||||||
|
|
||||||
|
class RLSAsPredicate:
|
||||||
|
"""
|
||||||
|
Apply Row Level Security role as a predicate.
|
||||||
|
|
||||||
|
This transformer will apply any RLS predicates to the relevant tables. For example,
|
||||||
|
given the RLS rule:
|
||||||
|
|
||||||
|
table: some_table
|
||||||
|
clause: id = 42
|
||||||
|
|
||||||
|
If a user subject to the rule runs the following query:
|
||||||
|
|
||||||
|
SELECT foo FROM some_table WHERE bar = 'baz'
|
||||||
|
|
||||||
|
The query will be modified to:
|
||||||
|
|
||||||
|
SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
|
||||||
|
|
||||||
|
This approach is probably less secure than using subqueries, so it's only used for
|
||||||
|
databases without support for subqueries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, rules: dict[Table, str]) -> None:
|
||||||
|
self.rules = rules
|
||||||
|
|
||||||
|
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||||
|
if not isinstance(node, exp.Select):
|
||||||
|
return node
|
||||||
|
|
||||||
|
table_node = node.find(exp.Table)
|
||||||
|
if not table_node:
|
||||||
|
return node
|
||||||
|
|
||||||
|
table = Table(
|
||||||
|
str(table_node.this),
|
||||||
|
str(table_node.db) if table_node.db else None,
|
||||||
|
str(table_node.catalog) if table_node.catalog else None,
|
||||||
|
)
|
||||||
|
if predicate := self.rules.get(table):
|
||||||
|
if where := node.args.get("where"):
|
||||||
|
predicate = exp.And(this=predicate, expression=where.this)
|
||||||
|
|
||||||
|
node.set("where", exp.Where(this=predicate))
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
class RLSAsSubquery:
|
||||||
|
def __init__(self, rules: dict[Table, str]) -> None:
|
||||||
|
self.rules = rules
|
||||||
|
|
||||||
|
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||||
|
if not isinstance(node, exp.Table):
|
||||||
|
return node
|
||||||
|
|
||||||
|
table = Table(
|
||||||
|
str(node.this),
|
||||||
|
str(node.db) if node.db else None,
|
||||||
|
str(node.catalog) if node.catalog else None,
|
||||||
|
)
|
||||||
|
if predicate := self.rules.get(table):
|
||||||
|
alias = node.alias
|
||||||
|
node.set("alias", None)
|
||||||
|
return f"(SELECT * FROM {node} WHERE {predicate}) AS {alias}"
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
class SQLStatement(BaseSQLStatement[exp.Expression]):
|
class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||||
"""
|
"""
|
||||||
A SQL statement.
|
A SQL statement.
|
||||||
@@ -521,12 +604,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||||
return extract_tables_from_statement(parsed, dialect)
|
return extract_tables_from_statement(parsed, dialect)
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
def format(self, comments: bool = True, strip: bool = False) -> str:
|
||||||
"""
|
"""
|
||||||
Pretty-format the SQL statement.
|
Pretty-format the SQL statement.
|
||||||
"""
|
"""
|
||||||
write = Dialect.get_or_raise(self._dialect)
|
write = Dialect.get_or_raise(self._dialect)
|
||||||
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
|
output = write.generate(
|
||||||
|
self._parsed,
|
||||||
|
copy=False,
|
||||||
|
comments=comments,
|
||||||
|
pretty=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output.strip(" \t\r\n;") if strip else output
|
||||||
|
|
||||||
def get_settings(self) -> dict[str, str | bool]:
|
def get_settings(self) -> dict[str, str | bool]:
|
||||||
"""
|
"""
|
||||||
@@ -543,6 +633,186 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||||||
for eq in set_item.find_all(exp.EQ)
|
for eq in set_item.find_all(exp.EQ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def apply_rls(
|
||||||
|
self,
|
||||||
|
catalog: str | None,
|
||||||
|
schema: str | None,
|
||||||
|
) -> SQLStatement:
|
||||||
|
"""
|
||||||
|
Apply Row Level Security to the SQL.
|
||||||
|
|
||||||
|
:param catalog: The default catalog for non-qualified table names
|
||||||
|
:param schema: The default schema for non-qualified table names
|
||||||
|
:return: The SQL with RLS applied
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs import load_engine_specs
|
||||||
|
|
||||||
|
statement = self._parsed.copy()
|
||||||
|
|
||||||
|
# collect all relevant RLS rules
|
||||||
|
rules = {}
|
||||||
|
for table in self.tables:
|
||||||
|
if rls := self._get_rls_for_table(table, catalog, schema):
|
||||||
|
rules[table] = rls
|
||||||
|
|
||||||
|
if not rules:
|
||||||
|
return statement
|
||||||
|
|
||||||
|
use_subquery = all(
|
||||||
|
engine_spec.allows_subqueries
|
||||||
|
for engine_spec in load_engine_specs()
|
||||||
|
if engine_spec.engine == self.engine
|
||||||
|
)
|
||||||
|
transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)
|
||||||
|
|
||||||
|
return SQLStatement(statement.transform(transformer), self.engine)
|
||||||
|
|
||||||
|
def _get_rls_for_table(
|
||||||
|
self,
|
||||||
|
database: Database,
|
||||||
|
table: Table,
|
||||||
|
catalog: str | None,
|
||||||
|
schema: str | None,
|
||||||
|
) -> exp.Expression | None:
|
||||||
|
"""
|
||||||
|
Get the RLS for a table.
|
||||||
|
|
||||||
|
:param table: The table to get the RLS for
|
||||||
|
:param catalog: The default catalog for non-qualified table names
|
||||||
|
:param schema: The default schema for non-qualified table names
|
||||||
|
:return: The RLS for the table
|
||||||
|
"""
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from superset import db
|
||||||
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
|
||||||
|
dataset = db.session.query(SqlaTable).filter(
|
||||||
|
and_(
|
||||||
|
SqlaTable.database_id == database.id,
|
||||||
|
SqlaTable.catalog == table.catalog or catalog,
|
||||||
|
SqlaTable.schema == table.schema or schema,
|
||||||
|
SqlaTable.table_name == table.table,
|
||||||
|
).one_or_none()
|
||||||
|
)
|
||||||
|
if not dataset:
|
||||||
|
return None
|
||||||
|
|
||||||
|
filters = dataset.get_sqla_row_level_filters()
|
||||||
|
if not filters:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rls = and_(*filters).compile(
|
||||||
|
dialect=database.get_dialect(),
|
||||||
|
compile_kwargs={"literal_binds": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
return sqlglot.parse_one(str(rls), dialect=self._dialect)
|
||||||
|
|
||||||
|
def is_dml(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the statement is DML.
|
||||||
|
|
||||||
|
:return: True if the statement is DML
|
||||||
|
"""
|
||||||
|
for node in self._parsed.walk():
|
||||||
|
if isinstance(
|
||||||
|
node,
|
||||||
|
(
|
||||||
|
exp.Insert,
|
||||||
|
exp.Update,
|
||||||
|
exp.Delete,
|
||||||
|
exp.Merge,
|
||||||
|
exp.Create,
|
||||||
|
exp.Alter,
|
||||||
|
exp.Drop,
|
||||||
|
exp.TruncateTable,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
|
||||||
|
"""
|
||||||
|
Convert the statement to a CREATE TABLE statement.
|
||||||
|
"""
|
||||||
|
create_table = exp.Create(
|
||||||
|
this=sqlglot.parse_one(table, into=exp.Table),
|
||||||
|
kind=method.value,
|
||||||
|
expression=self._parsed.copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return SQLStatement(create_table, self.engine)
|
||||||
|
|
||||||
|
def is_select(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the statement is a SELECT statement.
|
||||||
|
|
||||||
|
:return: True if the statement is a SELECT statement
|
||||||
|
"""
|
||||||
|
return isinstance(self._parsed, exp.Select)
|
||||||
|
|
||||||
|
def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
|
||||||
|
"""
|
||||||
|
Apply a limit to the SQL.
|
||||||
|
|
||||||
|
There are 3 strategies to limit queries, defined in the DB engine spec:
|
||||||
|
|
||||||
|
1. `FORCE_LIMIT`: a limit is added to the query, or the existing one is
|
||||||
|
replaced. This is the most efficient, since the database will produce at
|
||||||
|
most the number of rows that Superset will display.
|
||||||
|
2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is applied
|
||||||
|
to the outer query. This might be inefficient, since the database
|
||||||
|
optimizer might not be able to push the limit down to the inner query.
|
||||||
|
3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are fetched from
|
||||||
|
the database. This is the least efficient, unless the database computes
|
||||||
|
rows as they are read by the cursor, which is unlikely.
|
||||||
|
|
||||||
|
:param limit: The limit to apply
|
||||||
|
:param force: Apply limit even when a lower one is present
|
||||||
|
:return: The SQL with the limit applied
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs import load_engine_specs
|
||||||
|
from superset.db_engine_specs.base import LimitMethod
|
||||||
|
|
||||||
|
methods = {
|
||||||
|
engine_spec.limit_method
|
||||||
|
for engine_spec in load_engine_specs()
|
||||||
|
if engine_spec.engine == self.engine
|
||||||
|
}
|
||||||
|
if not methods:
|
||||||
|
methods = {LimitMethod.FETCH_MANY}
|
||||||
|
|
||||||
|
# When multiple methods are supported, we prefer the more generic one --
|
||||||
|
# usually less efficient.
|
||||||
|
preference = [
|
||||||
|
LimitMethod.FETCH_MANY,
|
||||||
|
LimitMethod.WRAP_SQL,
|
||||||
|
LimitMethod.FORCE_LIMIT,
|
||||||
|
]
|
||||||
|
method = sorted(methods, key=preference.index)[0]
|
||||||
|
|
||||||
|
if not self.is_select() or method == LimitMethod.FETCH_MANY:
|
||||||
|
return SQLStatement(self._parsed.copy(), self.engine)
|
||||||
|
|
||||||
|
if method == LimitMethod.WRAP_SQL:
|
||||||
|
limited = exp.Select(
|
||||||
|
expressions=[exp.Star()],
|
||||||
|
from_=exp.Subquery(subquery=self._parsed.copy(), alias="inner_qry"),
|
||||||
|
limit=exp.Literal.number(limit),
|
||||||
|
)
|
||||||
|
return SQLStatement(limited, self.engine)
|
||||||
|
|
||||||
|
current_limit: int | None = None
|
||||||
|
for node in self._parsed.find_all(exp.Limit):
|
||||||
|
current_limit = int(node.expression.this)
|
||||||
|
break
|
||||||
|
|
||||||
|
if force or current_limit is None or limit < current_limit:
|
||||||
|
return SQLStatement(self._parsed.limit(limit), self.engine)
|
||||||
|
|
||||||
|
return SQLStatement(self._parsed.copy(), self.engine)
|
||||||
|
|
||||||
|
|
||||||
class KQLSplitState(enum.Enum):
|
class KQLSplitState(enum.Enum):
|
||||||
"""
|
"""
|
||||||
@@ -666,11 +936,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
|||||||
)
|
)
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
def format(self, comments: bool = True, strip: bool = False) -> str:
|
||||||
"""
|
"""
|
||||||
Pretty-format the SQL statement.
|
Pretty-format the SQL statement.
|
||||||
"""
|
"""
|
||||||
return self._parsed
|
return self._parsed.strip(" \t\r\n;") if strip else self._parsed
|
||||||
|
|
||||||
def get_settings(self) -> dict[str, str | bool]:
|
def get_settings(self) -> dict[str, str | bool]:
|
||||||
"""
|
"""
|
||||||
@@ -712,7 +982,10 @@ class SQLScript:
|
|||||||
"""
|
"""
|
||||||
Pretty-format the SQL query.
|
Pretty-format the SQL query.
|
||||||
"""
|
"""
|
||||||
return ";\n".join(statement.format(comments) for statement in self.statements)
|
return (
|
||||||
|
";\n".join(statement.format(comments) for statement in self.statements)
|
||||||
|
+ ";"
|
||||||
|
)
|
||||||
|
|
||||||
def get_settings(self) -> dict[str, str | bool]:
|
def get_settings(self) -> dict[str, str | bool]:
|
||||||
"""
|
"""
|
||||||
@@ -792,7 +1065,7 @@ class ParsedQuery:
|
|||||||
Note: this uses sqlglot, since it's better at catching more edge cases.
|
Note: this uses sqlglot, since it's better at catching more edge cases.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
statements = parse(self.stripped(), dialect=self._dialect)
|
statements = sqlglot.parse(self.stripped(), dialect=self._dialect)
|
||||||
except SqlglotError as ex:
|
except SqlglotError as ex:
|
||||||
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
|
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
|
||||||
|
|
||||||
@@ -846,7 +1119,7 @@ class ParsedQuery:
|
|||||||
return set()
|
return set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pseudo_query = parse_one(
|
pseudo_query = sqlglot.parse_one(
|
||||||
f"SELECT {literal.this}",
|
f"SELECT {literal.this}",
|
||||||
dialect=self._dialect,
|
dialect=self._dialect,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import Optional
|
|||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import sqlglot
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
@@ -41,6 +42,8 @@ from superset.sql_parse import (
|
|||||||
insert_rls_in_predicate,
|
insert_rls_in_predicate,
|
||||||
KustoKQLStatement,
|
KustoKQLStatement,
|
||||||
ParsedQuery,
|
ParsedQuery,
|
||||||
|
RLSAsPredicate,
|
||||||
|
RLSAsSubquery,
|
||||||
sanitize_clause,
|
sanitize_clause,
|
||||||
split_kql,
|
split_kql,
|
||||||
SQLScript,
|
SQLScript,
|
||||||
@@ -119,8 +122,9 @@ def test_extract_tables_subselect() -> None:
|
|||||||
"""
|
"""
|
||||||
Test that tables inside subselects are parsed correctly.
|
Test that tables inside subselects are parsed correctly.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT sub.*
|
SELECT sub.*
|
||||||
FROM (
|
FROM (
|
||||||
SELECT *
|
SELECT *
|
||||||
@@ -129,10 +133,13 @@ FROM (
|
|||||||
) sub, s2.t2
|
) sub, s2.t2
|
||||||
WHERE sub.resolution = 'NONE'
|
WHERE sub.resolution = 'NONE'
|
||||||
"""
|
"""
|
||||||
) == {Table("t1", "s1"), Table("t2", "s2")}
|
)
|
||||||
|
== {Table("t1", "s1"), Table("t2", "s2")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT sub.*
|
SELECT sub.*
|
||||||
FROM (
|
FROM (
|
||||||
SELECT *
|
SELECT *
|
||||||
@@ -141,10 +148,13 @@ FROM (
|
|||||||
) sub
|
) sub
|
||||||
WHERE sub.resolution = 'NONE'
|
WHERE sub.resolution = 'NONE'
|
||||||
"""
|
"""
|
||||||
) == {Table("t1", "s1")}
|
)
|
||||||
|
== {Table("t1", "s1")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT * FROM t1
|
SELECT * FROM t1
|
||||||
WHERE s11 > ANY (
|
WHERE s11 > ANY (
|
||||||
SELECT COUNT(*) /* no hint */ FROM t2
|
SELECT COUNT(*) /* no hint */ FROM t2
|
||||||
@@ -156,7 +166,9 @@ WHERE s11 > ANY (
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
)
|
||||||
|
== {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_select_in_expression() -> None:
|
def test_extract_tables_select_in_expression() -> None:
|
||||||
@@ -227,24 +239,30 @@ def test_extract_tables_select_array() -> None:
|
|||||||
"""
|
"""
|
||||||
Test that queries selecting arrays work as expected.
|
Test that queries selecting arrays work as expected.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT ARRAY[1, 2, 3] AS my_array
|
SELECT ARRAY[1, 2, 3] AS my_array
|
||||||
FROM t1 LIMIT 10
|
FROM t1 LIMIT 10
|
||||||
"""
|
"""
|
||||||
) == {Table("t1")}
|
)
|
||||||
|
== {Table("t1")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_select_if() -> None:
|
def test_extract_tables_select_if() -> None:
|
||||||
"""
|
"""
|
||||||
Test that queries with an ``IF`` work as expected.
|
Test that queries with an ``IF`` work as expected.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
|
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
|
||||||
FROM t1 LIMIT 10
|
FROM t1 LIMIT 10
|
||||||
"""
|
"""
|
||||||
) == {Table("t1")}
|
)
|
||||||
|
== {Table("t1")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_with_catalog() -> None:
|
def test_extract_tables_with_catalog() -> None:
|
||||||
@@ -312,29 +330,38 @@ def test_extract_tables_where_subquery() -> None:
|
|||||||
"""
|
"""
|
||||||
Test that tables in a ``WHERE`` subquery are parsed correctly.
|
Test that tables in a ``WHERE`` subquery are parsed correctly.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT name
|
SELECT name
|
||||||
FROM t1
|
FROM t1
|
||||||
WHERE regionkey = (SELECT max(regionkey) FROM t2)
|
WHERE regionkey = (SELECT max(regionkey) FROM t2)
|
||||||
"""
|
"""
|
||||||
) == {Table("t1"), Table("t2")}
|
)
|
||||||
|
== {Table("t1"), Table("t2")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT name
|
SELECT name
|
||||||
FROM t1
|
FROM t1
|
||||||
WHERE regionkey IN (SELECT regionkey FROM t2)
|
WHERE regionkey IN (SELECT regionkey FROM t2)
|
||||||
"""
|
"""
|
||||||
) == {Table("t1"), Table("t2")}
|
)
|
||||||
|
== {Table("t1"), Table("t2")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT name
|
SELECT name
|
||||||
FROM t1
|
FROM t1
|
||||||
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
|
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
|
||||||
"""
|
"""
|
||||||
) == {Table("t1"), Table("t2")}
|
)
|
||||||
|
== {Table("t1"), Table("t2")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_describe() -> None:
|
def test_extract_tables_describe() -> None:
|
||||||
@@ -348,12 +375,15 @@ def test_extract_tables_show_partitions() -> None:
|
|||||||
"""
|
"""
|
||||||
Test ``SHOW PARTITIONS``.
|
Test ``SHOW PARTITIONS``.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SHOW PARTITIONS FROM orders
|
SHOW PARTITIONS FROM orders
|
||||||
WHERE ds >= '2013-01-01' ORDER BY ds DESC
|
WHERE ds >= '2013-01-01' ORDER BY ds DESC
|
||||||
"""
|
"""
|
||||||
) == {Table("orders")}
|
)
|
||||||
|
== {Table("orders")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_join() -> None:
|
def test_extract_tables_join() -> None:
|
||||||
@@ -365,8 +395,9 @@ def test_extract_tables_join() -> None:
|
|||||||
Table("t2"),
|
Table("t2"),
|
||||||
}
|
}
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT a.date, b.name
|
SELECT a.date, b.name
|
||||||
FROM left_table a
|
FROM left_table a
|
||||||
JOIN (
|
JOIN (
|
||||||
@@ -377,10 +408,13 @@ JOIN (
|
|||||||
) b
|
) b
|
||||||
ON a.date = b.date
|
ON a.date = b.date
|
||||||
"""
|
"""
|
||||||
) == {Table("left_table"), Table("right_table")}
|
)
|
||||||
|
== {Table("left_table"), Table("right_table")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT a.date, b.name
|
SELECT a.date, b.name
|
||||||
FROM left_table a
|
FROM left_table a
|
||||||
LEFT INNER JOIN (
|
LEFT INNER JOIN (
|
||||||
@@ -391,10 +425,13 @@ LEFT INNER JOIN (
|
|||||||
) b
|
) b
|
||||||
ON a.date = b.date
|
ON a.date = b.date
|
||||||
"""
|
"""
|
||||||
) == {Table("left_table"), Table("right_table")}
|
)
|
||||||
|
== {Table("left_table"), Table("right_table")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT a.date, b.name
|
SELECT a.date, b.name
|
||||||
FROM left_table a
|
FROM left_table a
|
||||||
RIGHT OUTER JOIN (
|
RIGHT OUTER JOIN (
|
||||||
@@ -405,10 +442,13 @@ RIGHT OUTER JOIN (
|
|||||||
) b
|
) b
|
||||||
ON a.date = b.date
|
ON a.date = b.date
|
||||||
"""
|
"""
|
||||||
) == {Table("left_table"), Table("right_table")}
|
)
|
||||||
|
== {Table("left_table"), Table("right_table")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT a.date, b.name
|
SELECT a.date, b.name
|
||||||
FROM left_table a
|
FROM left_table a
|
||||||
FULL OUTER JOIN (
|
FULL OUTER JOIN (
|
||||||
@@ -419,15 +459,18 @@ FULL OUTER JOIN (
|
|||||||
) b
|
) b
|
||||||
ON a.date = b.date
|
ON a.date = b.date
|
||||||
"""
|
"""
|
||||||
) == {Table("left_table"), Table("right_table")}
|
)
|
||||||
|
== {Table("left_table"), Table("right_table")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_semi_join() -> None:
|
def test_extract_tables_semi_join() -> None:
|
||||||
"""
|
"""
|
||||||
Test ``LEFT SEMI JOIN``.
|
Test ``LEFT SEMI JOIN``.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT a.date, b.name
|
SELECT a.date, b.name
|
||||||
FROM left_table a
|
FROM left_table a
|
||||||
LEFT SEMI JOIN (
|
LEFT SEMI JOIN (
|
||||||
@@ -438,15 +481,18 @@ LEFT SEMI JOIN (
|
|||||||
) b
|
) b
|
||||||
ON a.data = b.date
|
ON a.data = b.date
|
||||||
"""
|
"""
|
||||||
) == {Table("left_table"), Table("right_table")}
|
)
|
||||||
|
== {Table("left_table"), Table("right_table")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_combinations() -> None:
|
def test_extract_tables_combinations() -> None:
|
||||||
"""
|
"""
|
||||||
Test a complex case with nested queries.
|
Test a complex case with nested queries.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT * FROM t1
|
SELECT * FROM t1
|
||||||
WHERE s11 > ANY (
|
WHERE s11 > ANY (
|
||||||
SELECT * FROM t1 UNION ALL SELECT * FROM (
|
SELECT * FROM t1 UNION ALL SELECT * FROM (
|
||||||
@@ -460,10 +506,13 @@ WHERE s11 > ANY (
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
|
)
|
||||||
|
== {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT * FROM (
|
SELECT * FROM (
|
||||||
SELECT * FROM (
|
SELECT * FROM (
|
||||||
SELECT * FROM (
|
SELECT * FROM (
|
||||||
@@ -472,45 +521,56 @@ SELECT * FROM (
|
|||||||
) AS S2
|
) AS S2
|
||||||
) AS S3
|
) AS S3
|
||||||
"""
|
"""
|
||||||
) == {Table("EmployeeS")}
|
)
|
||||||
|
== {Table("EmployeeS")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_with() -> None:
|
def test_extract_tables_with() -> None:
|
||||||
"""
|
"""
|
||||||
Test ``WITH``.
|
Test ``WITH``.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
WITH
|
WITH
|
||||||
x AS (SELECT a FROM t1),
|
x AS (SELECT a FROM t1),
|
||||||
y AS (SELECT a AS b FROM t2),
|
y AS (SELECT a AS b FROM t2),
|
||||||
z AS (SELECT b AS c FROM t3)
|
z AS (SELECT b AS c FROM t3)
|
||||||
SELECT c FROM z
|
SELECT c FROM z
|
||||||
"""
|
"""
|
||||||
) == {Table("t1"), Table("t2"), Table("t3")}
|
)
|
||||||
|
== {Table("t1"), Table("t2"), Table("t3")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
WITH
|
WITH
|
||||||
x AS (SELECT a FROM t1),
|
x AS (SELECT a FROM t1),
|
||||||
y AS (SELECT a AS b FROM x),
|
y AS (SELECT a AS b FROM x),
|
||||||
z AS (SELECT b AS c FROM y)
|
z AS (SELECT b AS c FROM y)
|
||||||
SELECT c FROM z
|
SELECT c FROM z
|
||||||
"""
|
"""
|
||||||
) == {Table("t1")}
|
)
|
||||||
|
== {Table("t1")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_reusing_aliases() -> None:
|
def test_extract_tables_reusing_aliases() -> None:
|
||||||
"""
|
"""
|
||||||
Test that the parser follows aliases.
|
Test that the parser follows aliases.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
with q1 as ( select key from q2 where key = '5'),
|
with q1 as ( select key from q2 where key = '5'),
|
||||||
q2 as ( select key from src where key = '5')
|
q2 as ( select key from src where key = '5')
|
||||||
select * from (select key from q1) a
|
select * from (select key from q1) a
|
||||||
"""
|
"""
|
||||||
) == {Table("src")}
|
)
|
||||||
|
== {Table("src")}
|
||||||
|
)
|
||||||
|
|
||||||
# weird query with circular dependency
|
# weird query with circular dependency
|
||||||
assert (
|
assert (
|
||||||
@@ -547,8 +607,9 @@ def test_extract_tables_complex() -> None:
|
|||||||
"""
|
"""
|
||||||
Test a few complex queries.
|
Test a few complex queries.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT sum(m_examples) AS "sum__m_example"
|
SELECT sum(m_examples) AS "sum__m_example"
|
||||||
FROM (
|
FROM (
|
||||||
SELECT
|
SELECT
|
||||||
@@ -569,23 +630,29 @@ FROM (
|
|||||||
ORDER BY "sum__m_example" DESC
|
ORDER BY "sum__m_example" DESC
|
||||||
LIMIT 10;
|
LIMIT 10;
|
||||||
"""
|
"""
|
||||||
) == {
|
)
|
||||||
Table("my_l_table"),
|
== {
|
||||||
Table("my_b_table"),
|
Table("my_l_table"),
|
||||||
Table("my_t_table"),
|
Table("my_b_table"),
|
||||||
Table("inner_table"),
|
Table("my_t_table"),
|
||||||
}
|
Table("inner_table"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM table_a AS a, table_b AS b, table_c as c
|
FROM table_a AS a, table_b AS b, table_c as c
|
||||||
WHERE a.id = b.id and b.id = c.id
|
WHERE a.id = b.id and b.id = c.id
|
||||||
"""
|
"""
|
||||||
) == {Table("table_a"), Table("table_b"), Table("table_c")}
|
)
|
||||||
|
== {Table("table_a"), Table("table_b"), Table("table_c")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT somecol AS somecol
|
SELECT somecol AS somecol
|
||||||
FROM (
|
FROM (
|
||||||
WITH bla AS (
|
WITH bla AS (
|
||||||
@@ -629,51 +696,63 @@ FROM (
|
|||||||
LIMIT 50000
|
LIMIT 50000
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
|
)
|
||||||
|
== {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_mixed_from_clause() -> None:
|
def test_extract_tables_mixed_from_clause() -> None:
|
||||||
"""
|
"""
|
||||||
Test that the parser handles a ``FROM`` clause with table and subselect.
|
Test that the parser handles a ``FROM`` clause with table and subselect.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM table_a AS a, (select * from table_b) AS b, table_c as c
|
FROM table_a AS a, (select * from table_b) AS b, table_c as c
|
||||||
WHERE a.id = b.id and b.id = c.id
|
WHERE a.id = b.id and b.id = c.id
|
||||||
"""
|
"""
|
||||||
) == {Table("table_a"), Table("table_b"), Table("table_c")}
|
)
|
||||||
|
== {Table("table_a"), Table("table_b"), Table("table_c")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_nested_select() -> None:
|
def test_extract_tables_nested_select() -> None:
|
||||||
"""
|
"""
|
||||||
Test that the parser handles selects inside functions.
|
Test that the parser handles selects inside functions.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
||||||
from INFORMATION_SCHEMA.COLUMNS
|
from INFORMATION_SCHEMA.COLUMNS
|
||||||
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
||||||
""",
|
""",
|
||||||
"mysql",
|
"mysql",
|
||||||
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
)
|
||||||
|
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||||
|
)
|
||||||
|
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
||||||
from INFORMATION_SCHEMA.COLUMNS
|
from INFORMATION_SCHEMA.COLUMNS
|
||||||
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
|
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
|
||||||
""",
|
""",
|
||||||
"mysql",
|
"mysql",
|
||||||
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
)
|
||||||
|
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_complex_cte_with_prefix() -> None:
|
def test_extract_tables_complex_cte_with_prefix() -> None:
|
||||||
"""
|
"""
|
||||||
Test that the parser handles CTEs with prefixes.
|
Test that the parser handles CTEs with prefixes.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
|
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
|
||||||
AS (
|
AS (
|
||||||
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
|
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
|
||||||
@@ -685,21 +764,26 @@ FROM CTE__test
|
|||||||
GROUP BY SalesYear, SalesPersonID
|
GROUP BY SalesYear, SalesPersonID
|
||||||
ORDER BY SalesPersonID, SalesYear;
|
ORDER BY SalesPersonID, SalesYear;
|
||||||
"""
|
"""
|
||||||
) == {Table("SalesOrderHeader")}
|
)
|
||||||
|
== {Table("SalesOrderHeader")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
|
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
|
||||||
"""
|
"""
|
||||||
Test that aliases that are keywords are parsed correctly.
|
Test that aliases that are keywords are parsed correctly.
|
||||||
"""
|
"""
|
||||||
assert extract_tables(
|
assert (
|
||||||
"""
|
extract_tables(
|
||||||
|
"""
|
||||||
WITH
|
WITH
|
||||||
f AS (SELECT * FROM foo),
|
f AS (SELECT * FROM foo),
|
||||||
match AS (SELECT * FROM f)
|
match AS (SELECT * FROM f)
|
||||||
SELECT * FROM match
|
SELECT * FROM match
|
||||||
"""
|
"""
|
||||||
) == {Table("foo")}
|
)
|
||||||
|
== {Table("foo")}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_update() -> None:
|
def test_update() -> None:
|
||||||
@@ -1841,7 +1925,7 @@ def test_sqlquery() -> None:
|
|||||||
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
|
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
|
||||||
|
|
||||||
assert len(script.statements) == 2
|
assert len(script.statements) == 2
|
||||||
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
|
assert script.format() == "SELECT\n 1;\nSELECT\n 2;"
|
||||||
assert script.statements[0].format() == "SELECT\n 1"
|
assert script.statements[0].format() == "SELECT\n 1"
|
||||||
|
|
||||||
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
|
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
|
||||||
@@ -2058,3 +2142,120 @@ on $left.Day1 == $right.Day
|
|||||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||||
""",
|
""",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sql,rules,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM some_table AS t",
|
||||||
|
{Table("some_table"): "id = 42"},
|
||||||
|
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
|
||||||
|
{Table("some_table"): "id = 42"},
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
|
||||||
|
"WHERE bar = 'baz'"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM schema1.some_table AS t",
|
||||||
|
{Table("some_table", "schema1"): "id = 42"},
|
||||||
|
"SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM schema1.some_table AS t",
|
||||||
|
{Table("some_table", "schema2"): "id = 42"},
|
||||||
|
"SELECT t.foo FROM schema1.some_table AS t",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||||
|
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||||
|
"SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||||
|
{Table("some_table", "schema1", "catalog2"): "id = 42"},
|
||||||
|
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
|
||||||
|
"""
|
||||||
|
Test the `RLSAsSubquery` transformer.
|
||||||
|
"""
|
||||||
|
statement = sqlglot.parse_one(sql)
|
||||||
|
transformer = RLSAsSubquery(rules)
|
||||||
|
assert str(statement.transform(transformer)) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sql,rules,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM some_table AS t",
|
||||||
|
{Table("some_table"): "id = 42"},
|
||||||
|
"SELECT t.foo FROM some_table AS t WHERE id = 42",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
|
||||||
|
{Table("some_table"): "id = 42"},
|
||||||
|
"SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
|
||||||
|
"""
|
||||||
|
Test the `RLSAsPredicate` transformer.
|
||||||
|
"""
|
||||||
|
statement = sqlglot.parse_one(sql)
|
||||||
|
transformer = RLSAsPredicate(rules)
|
||||||
|
assert str(statement.transform(transformer)) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sql,engine,limit,force,expected",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"SELECT TOP 10 * FROM Customers",
|
||||||
|
"teradatasql",
|
||||||
|
5,
|
||||||
|
False,
|
||||||
|
"SELECT\nTOP 5\n *\nFROM Customers",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT TOP 10 * FROM Customers",
|
||||||
|
"teradatasql",
|
||||||
|
15,
|
||||||
|
False,
|
||||||
|
"SELECT\nTOP 10\n *\nFROM Customers",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT TOP 10 * FROM Customers",
|
||||||
|
"teradatasql",
|
||||||
|
15,
|
||||||
|
True,
|
||||||
|
"SELECT\nTOP 15\n *\nFROM Customers",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"SELECT TOP 10 * FROM Customers",
|
||||||
|
"mssql",
|
||||||
|
15,
|
||||||
|
True,
|
||||||
|
"SELECT\nTOP 15\n *\nFROM Customers",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_apply_limit(
|
||||||
|
sql: str,
|
||||||
|
engine: str,
|
||||||
|
limit: int,
|
||||||
|
force: bool,
|
||||||
|
expected: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test the `apply_limit` function.
|
||||||
|
"""
|
||||||
|
assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected
|
||||||
|
|||||||
Reference in New Issue
Block a user