mirror of
https://github.com/apache/superset.git
synced 2026-05-03 06:54:19 +00:00
Compare commits
1 Commits
docs/testi
...
new-rls
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63813b19ed |
@@ -63,7 +63,8 @@ from superset.constants import TimeGrain as TimeGrainConstants
|
|||||||
from superset.databases.utils import get_table_metadata, make_url_safe
|
from superset.databases.utils import get_table_metadata, make_url_safe
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
|
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
|
||||||
from superset.sql_parse import ParsedQuery, SQLScript, Table
|
from superset.sql.parse import SQLScript, Table
|
||||||
|
from superset.sql_parse import ParsedQuery
|
||||||
from superset.superset_typing import (
|
from superset.superset_typing import (
|
||||||
OAuth2ClientConfig,
|
OAuth2ClientConfig,
|
||||||
OAuth2State,
|
OAuth2State,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
|
|||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import SupersetException, SupersetSecurityException
|
from superset.exceptions import SupersetException, SupersetSecurityException
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
from superset.sql_parse import SQLScript
|
from superset.sql.parse import SQLScript
|
||||||
from superset.utils import core as utils, json
|
from superset.utils import core as utils, json
|
||||||
from superset.utils.core import GenericDataType
|
from superset.utils.core import GenericDataType
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,9 @@
|
|||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@@ -304,12 +307,30 @@ class SupersetParseError(SupersetErrorException):
|
|||||||
|
|
||||||
status = 422
|
status = 422
|
||||||
|
|
||||||
def __init__(self, sql: str, engine: Optional[str] = None):
|
def __init__( # pylint: disable=too-many-arguments
|
||||||
|
self,
|
||||||
|
sql: str,
|
||||||
|
engine: Optional[str] = None,
|
||||||
|
message: Optional[str] = None,
|
||||||
|
highlight: Optional[str] = None,
|
||||||
|
line: Optional[int] = None,
|
||||||
|
column: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if message is None:
|
||||||
|
parts = [_("Error parsing")]
|
||||||
|
if highlight:
|
||||||
|
parts.append(_(" near '%(highlight)s'", highlight=highlight))
|
||||||
|
if line:
|
||||||
|
parts.append(_(" at line %(line)d", line=line))
|
||||||
|
if column:
|
||||||
|
parts.append(_(":%(column)d", column=column))
|
||||||
|
message = "".join(parts)
|
||||||
|
|
||||||
error = SupersetError(
|
error = SupersetError(
|
||||||
message=_("The SQL is invalid and cannot be parsed."),
|
message=message,
|
||||||
error_type=SupersetErrorType.INVALID_SQL_ERROR,
|
error_type=SupersetErrorType.INVALID_SQL_ERROR,
|
||||||
level=ErrorLevel.ERROR,
|
level=ErrorLevel.ERROR,
|
||||||
extra={"sql": sql, "engine": engine},
|
extra={"sql": sql, "engine": engine, "line": line, "column": column},
|
||||||
)
|
)
|
||||||
super().__init__(error)
|
super().__init__(error)
|
||||||
|
|
||||||
|
|||||||
@@ -68,13 +68,12 @@ from superset.exceptions import (
|
|||||||
)
|
)
|
||||||
from superset.extensions import feature_flag_manager
|
from superset.extensions import feature_flag_manager
|
||||||
from superset.jinja_context import BaseTemplateProcessor
|
from superset.jinja_context import BaseTemplateProcessor
|
||||||
|
from superset.sql.parse import SQLScript, SQLStatement
|
||||||
from superset.sql_parse import (
|
from superset.sql_parse import (
|
||||||
has_table_query,
|
has_table_query,
|
||||||
insert_rls_in_predicate,
|
insert_rls_in_predicate,
|
||||||
ParsedQuery,
|
ParsedQuery,
|
||||||
sanitize_clause,
|
sanitize_clause,
|
||||||
SQLScript,
|
|
||||||
SQLStatement,
|
|
||||||
)
|
)
|
||||||
from superset.superset_typing import (
|
from superset.superset_typing import (
|
||||||
AdhocMetric,
|
AdhocMetric,
|
||||||
|
|||||||
16
superset/sql/__init__.py
Normal file
16
superset/sql/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
648
superset/sql/parse.py
Normal file
648
superset/sql/parse.py
Normal file
@@ -0,0 +1,648 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import urllib.parse
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
|
import sqlglot
|
||||||
|
from sqlglot import exp
|
||||||
|
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||||
|
from sqlglot.errors import ParseError
|
||||||
|
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||||
|
|
||||||
|
from superset.exceptions import SupersetParseError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping between Superset DB engine spec names and the sqlglot dialect used to parse
# their SQL. Entries that are commented out with `???` have no dialect assigned in this
# mapping; looking them up with `SQLGLOT_DIALECTS.get(engine)` returns `None`.
SQLGLOT_DIALECTS = {
    "base": Dialects.DIALECT,
    "ascend": Dialects.HIVE,
    "awsathena": Dialects.PRESTO,
    "bigquery": Dialects.BIGQUERY,
    "clickhouse": Dialects.CLICKHOUSE,
    "clickhousedb": Dialects.CLICKHOUSE,
    "cockroachdb": Dialects.POSTGRES,
    "couchbase": Dialects.MYSQL,
    # "crate": ???
    # "databend": ???
    "databricks": Dialects.DATABRICKS,
    # "db2": ???
    # "dremio": ???
    "drill": Dialects.DRILL,
    # "druid": ???
    "duckdb": Dialects.DUCKDB,
    # "dynamodb": ???
    # "elasticsearch": ???
    # "exa": ???
    # "firebird": ???
    # "firebolt": ???
    "gsheets": Dialects.SQLITE,
    "hana": Dialects.POSTGRES,
    "hive": Dialects.HIVE,
    # "ibmi": ???
    # "impala": ???
    # "kustokql": ???
    # "kylin": ???
    "mssql": Dialects.TSQL,
    "mysql": Dialects.MYSQL,
    "netezza": Dialects.POSTGRES,
    # "ocient": ???
    # "odelasticsearch": ???
    "oracle": Dialects.ORACLE,
    # "pinot": ???
    "postgresql": Dialects.POSTGRES,
    "presto": Dialects.PRESTO,
    "pydoris": Dialects.DORIS,
    "redshift": Dialects.REDSHIFT,
    # "risingwave": ???
    # "rockset": ???
    "shillelagh": Dialects.SQLITE,
    "snowflake": Dialects.SNOWFLAKE,
    # "solr": ???
    "spark": Dialects.SPARK,
    "sqlite": Dialects.SQLITE,
    "starrocks": Dialects.STARROCKS,
    "superset": Dialects.SQLITE,
    "teradatasql": Dialects.TERADATA,
    "trino": Dialects.TRINO,
    "vertica": Dialects.POSTGRES,
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=True, frozen=True)
class Table:
    """
    Immutable reference to a SQL table, in the form [[catalog.]schema.]table.
    """

    table: str
    schema: str | None = None
    catalog: str | None = None

    def __str__(self) -> str:
        """
        Render the table as a dot-separated, fully qualified name.

        Every part is percent-encoded, including any literal dot, so the dots in the
        result only ever separate parts. The quoting is not engine-specific: use this
        for logging and debugging only, never for SQL generation.
        """
        encoded_parts = []
        for component in (self.catalog, self.schema, self.table):
            # missing (or empty) parts are simply left out of the name
            if not component:
                continue
            quoted = urllib.parse.quote(component, safe="")
            encoded_parts.append(quoted.replace(".", "%2E"))

        return ".".join(encoded_parts)

    def __eq__(self, other: Any) -> bool:
        # Equality is defined on the string form, so a table also compares equal to
        # any object whose `str()` matches its fully qualified name.
        own_name = str(self)
        other_name = str(other)
        return own_name == other_name
|
||||||
|
|
||||||
|
|
||||||
|
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
# store the AST as a string (the original query), and manipulate it with regular
# expressions.
InternalRepresentation = TypeVar("InternalRepresentation")

# The base type. This helps type checking the `split_script` method correctly, since
# each derived class has a more specific return type (the class itself). This will no
# longer be needed once Python 3.11 is the lowest version supported. See PEP 673 for
# more information: https://peps.python.org/pep-0673/
TBaseSQLStatement = TypeVar("TBaseSQLStatement")  # pylint: disable=invalid-name


class BaseSQLStatement(Generic[InternalRepresentation]):
    """
    Common interface for a single statement of a script.

    An instance can be built either from the text of the statement or, to avoid
    parsing twice, from an already parsed AST. The latter is handy with parsers such
    as `sqlglot.parse`, which split a script into several pre-parsed statements.

    The `engine` parameter comes from the `engine` attribute of a Superset DB engine
    spec.
    """

    def __init__(
        self,
        statement: str | InternalRepresentation,
        engine: str,
    ):
        # only text needs parsing; anything else is taken to be the AST already
        if isinstance(statement, str):
            parsed: InternalRepresentation = self._parse_statement(statement, engine)
        else:
            parsed = statement

        self._parsed = parsed
        self.engine = engine
        self.tables = self._extract_tables_from_statement(parsed, engine)

    @classmethod
    def split_script(
        cls: type[TBaseSQLStatement],
        script: str,
        engine: str,
    ) -> list[TBaseSQLStatement]:
        """
        Break a full script into one instantiated statement per statement found.

        `SQLScript` relies on this helper to build the statements of a script.
        Subclasses must implement it.
        """
        raise NotImplementedError()

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> InternalRepresentation:
        """
        Parse text holding a single statement and return its AST.

        Implementations must not take for granted that `statement` really holds a
        single statement: they MUST validate it explicitly. The check depends on the
        parser, which is why it is delegated to the subclasses.
        """
        raise NotImplementedError()

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: InternalRepresentation,
        engine: str,
    ) -> set[Table]:
        """
        Return every table referenced by the parsed statement.
        """
        raise NotImplementedError()

    def format(self, comments: bool = True) -> str:
        """
        Return the formatted statement, optionally omitting comments.
        """
        raise NotImplementedError()

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings configured by the statement, if any.

        Given this statement:

            sql> SET foo = 'bar';

        the method should return `{"foo": "'bar'"}`. Note the single quotes.
        """
        raise NotImplementedError()

    def is_mutating(self) -> bool:
        """
        Tell whether the statement mutates data (DDL/DML).

        :return: True if the statement mutates data.
        """
        raise NotImplementedError()

    def __str__(self) -> str:
        # the string form is the formatted statement, comments included
        return self.format()
|
||||||
|
|
||||||
|
|
||||||
|
class SQLStatement(BaseSQLStatement[exp.Expression]):
    """
    A SQL statement.

    This class is used for all engines with dialects that can be parsed using sqlglot.
    """

    def __init__(
        self,
        statement: str | exp.Expression,
        engine: str,
    ):
        # `None` when the engine has no entry in `SQLGLOT_DIALECTS`
        self._dialect = SQLGLOT_DIALECTS.get(engine)
        super().__init__(statement, engine)

    @classmethod
    def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
        """
        Parse a script with sqlglot and return one AST per statement.

        :param script: The SQL text, possibly holding several statements.
        :param engine: The DB engine spec name, used to pick the sqlglot dialect.
        :raises SupersetParseError: If sqlglot cannot parse the script.
        """
        dialect = SQLGLOT_DIALECTS.get(engine)
        try:
            return sqlglot.parse(script, dialect=dialect)
        except sqlglot.errors.ParseError as ex:
            # A `ParseError` does not always carry structured details; without this
            # guard, building the error would itself fail with an `IndexError`.
            if not ex.errors:
                raise SupersetParseError(
                    script,
                    engine,
                    message="Unable to parse script",
                ) from ex

            error = ex.errors[0]
            raise SupersetParseError(
                script,
                engine,
                highlight=error.get("highlight"),
                line=error.get("line"),
                column=error.get("col"),
            ) from ex
        except sqlglot.errors.SqlglotError as ex:
            raise SupersetParseError(
                script,
                engine,
                message="Unable to parse script",
            ) from ex

    @classmethod
    def split_script(
        cls,
        script: str,
        engine: str,
    ) -> list[SQLStatement]:
        """
        Split a script into statements, skipping empty entries from the parser.
        """
        return [
            cls(statement, engine)
            for statement in cls._parse(script, engine)
            if statement
        ]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> exp.Expression:
        """
        Parse a single SQL statement.

        :raises SupersetParseError: If the text cannot be parsed, or if it does not
            contain exactly one statement.
        """
        statements = cls.split_script(statement, engine)
        if len(statements) != 1:
            # The first positional argument of `SupersetParseError` is the SQL, so
            # the explanation has to be passed explicitly as `message`.
            raise SupersetParseError(
                statement,
                engine,
                message="SQLStatement should have exactly one statement",
            )

        return statements[0]._parsed  # pylint: disable=protected-access

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: exp.Expression,
        engine: str,
    ) -> set[Table]:
        """
        Find all referenced tables.
        """
        dialect = SQLGLOT_DIALECTS.get(engine)
        return extract_tables_from_statement(parsed, dialect)

    def is_mutating(self) -> bool:
        """
        Check if the statement mutates data (DDL/DML).

        :return: True if the statement mutates data.
        """
        for node in self._parsed.walk():
            if isinstance(
                node,
                (
                    exp.Insert,
                    exp.Update,
                    exp.Delete,
                    exp.Merge,
                    exp.Create,
                    exp.Drop,
                    exp.TruncateTable,
                ),
            ):
                return True

            if isinstance(node, exp.Command) and node.name == "ALTER":
                return True

        # Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see
        # https://www.postgresql.org/docs/current/sql-explain.html
        if (
            self._dialect == Dialects.POSTGRES
            and isinstance(self._parsed, exp.Command)
            and self._parsed.name == "EXPLAIN"
            # a bare `EXPLAIN` has no expression attached to the command
            and self._parsed.expression is not None
            and self._parsed.expression.name.upper().startswith("ANALYZE ")
        ):
            analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
            return SQLStatement(analyzed_sql, self.engine).is_mutating()

        return False

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL statement.

        :param comments: Whether to keep comments in the output.
        """
        write = Dialect.get_or_raise(self._dialect)
        return write.generate(self._parsed, copy=False, comments=comments, pretty=True)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = SQLStatement("SET foo = 'bar'", "base")
            >>> statement.get_settings()
            {"foo": "'bar'"}

        """
        return {
            eq.this.sql(): eq.expression.sql()
            for set_item in self._parsed.find_all(exp.SetItem)
            for eq in set_item.find_all(exp.EQ)
        }
|
||||||
|
|
||||||
|
|
||||||
|
class KQLSplitState(enum.Enum):
    """
    State machine for splitting a KQL script.

    The state machine keeps track of whether we're inside a string or not, so we
    don't split the script in a semi-colon that's part of a string.
    """

    OUTSIDE_STRING = enum.auto()
    INSIDE_SINGLE_QUOTED_STRING = enum.auto()
    INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
    INSIDE_MULTILINE_STRING = enum.auto()


def split_kql(kql: str) -> list[str]:
    """
    Custom function for splitting KQL statements.

    The script is split at every semi-colon that is not part of a string. The
    statements are returned verbatim, without the semi-colon and without stripping
    whitespace.

    :param kql: The KQL script.
    :return: The statements of the script, in order.
    """
    # the character that closes each kind of quoted string
    closing_quotes = {
        KQLSplitState.INSIDE_SINGLE_QUOTED_STRING: "'",
        KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING: '"',
    }

    statements = []
    state = KQLSplitState.OUTSIDE_STRING
    statement_start = 0
    # True when the previous character was an unescaped backslash inside a quoted
    # string. Tracking this (instead of just looking at the previous character)
    # lets a string ending in an escaped backslash, like 'a\\', close correctly.
    escaped = False

    # a trailing semi-colon guarantees that the last statement is flushed
    script = kql if kql.endswith(";") else kql + ";"
    for i, character in enumerate(script):
        if state == KQLSplitState.OUTSIDE_STRING:
            if character == ";":
                statements.append(script[statement_start:i])
                statement_start = i + 1
            elif character == "'":
                state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
            elif character == '"':
                state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
            elif character == "`" and i >= 2 and script[i - 2 : i] == "``":
                # third backtick in a row: start of a multi-line string
                state = KQLSplitState.INSIDE_MULTILINE_STRING

        elif state in closing_quotes:
            if escaped:
                # this character is escaped, whatever it is
                escaped = False
            elif character == "\\":
                escaped = True
            elif character == closing_quotes[state]:
                state = KQLSplitState.OUTSIDE_STRING

        elif (
            state == KQLSplitState.INSIDE_MULTILINE_STRING
            and character == "`"
            and script[i - 2 : i] == "``"
        ):
            state = KQLSplitState.OUTSIDE_STRING

    return statements
|
||||||
|
|
||||||
|
|
||||||
|
class KustoKQLStatement(BaseSQLStatement[str]):
    """
    Special class for Kusto KQL.

    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
    like this:

        StormEvents
        | summarize PropertyDamage = sum(DamageProperty) by State
        | join kind=innerunique PopulationData on State
        | project State, PropertyDamagePerCapita = PropertyDamage / Population
        | sort by PropertyDamagePerCapita

    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
    details about it.

    Since there is no parser, the "AST" stored by this class is the statement itself,
    as a string.
    """

    @classmethod
    def split_script(
        cls,
        script: str,
        engine: str,
    ) -> list[KustoKQLStatement]:
        """
        Split a script at semi-colons.

        Since we don't have a parser, we use a simple state machine based function. See
        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
        for more information.
        """
        return [cls(statement, engine) for statement in split_kql(script)]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> str:
        """
        Validate a single KQL statement and return it stripped of whitespace.

        :raises SupersetParseError: If the engine is not `kustokql`, or if the text
            does not contain exactly one statement.
        """
        # The first positional argument of `SupersetParseError` is the SQL, so the
        # explanation has to be passed explicitly as `message`.
        if engine != "kustokql":
            raise SupersetParseError(
                statement,
                engine,
                message=f"Invalid engine: {engine}",
            )

        statements = split_kql(statement)
        if len(statements) != 1:
            raise SupersetParseError(
                statement,
                engine,
                message="KustoKQLStatement should have exactly one statement",
            )

        return statements[0].strip()

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: str,
        engine: str,
    ) -> set[Table]:
        """
        Extract all tables referenced in the statement.

            StormEvents
            | where InjuriesDirect + InjuriesIndirect > 50
            | join (PopulationData) on State
            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect

        Table extraction is not implemented for KQL: a warning is logged and an empty
        set is returned.
        """
        logger.warning(
            "Kusto KQL doesn't support table extraction. This means that data access "
            "roles will not be enforced by Superset in the database."
        )
        return set()

    def format(self, comments: bool = True) -> str:
        """
        Return the statement as stored; no formatting is applied and `comments` is
        ignored.
        """
        return self._parsed

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
            >>> statement.get_settings()
            {"querytrace": True}

        """
        # `set name` or `set name = value`; a setting without a value maps to `True`
        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
            return {match.group("name"): match.group("value") or True}

        return {}

    def is_mutating(self) -> bool:
        """
        Check if the statement mutates data (DDL/DML).

        Statements starting with a dot are treated as mutating, except for those
        starting with `.show`.

        :return: True if the statement mutates data.
        """
        return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
||||||
|
|
||||||
|
|
||||||
|
class SQLScript:
    """
    A script made of zero or more statements.
    """

    # Engines whose language can't be parsed using sqlglot, mapped to the class that
    # handles their statements. Supporting non-SQL engines adds a lot of complexity to
    # Superset, so we should avoid adding new engines to this data structure.
    special_engines = {
        "kustokql": KustoKQLStatement,
    }

    def __init__(
        self,
        script: str,
        engine: str,
    ):
        self.engine = engine

        # every engine not listed as special is handled by sqlglot
        splitter = self.special_engines.get(engine, SQLStatement)
        self.statements = splitter.split_script(script, engine)

    def format(self, comments: bool = True) -> str:
        """
        Return the formatted statements of the script, separated by semi-colons.
        """
        formatted = [statement.format(comments) for statement in self.statements]
        return ";\n".join(formatted)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings configured by the script; later statements win.

            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'", "base")
            >>> statement.get_settings()
            {"foo": "'baz'"}

        """
        return {
            name: value
            for statement in self.statements
            for name, value in statement.get_settings().items()
        }

    def has_mutation(self) -> bool:
        """
        Tell whether any statement of the script mutates data.

        :return: True if the script contains mutating statements
        """
        for statement in self.statements:
            if statement.is_mutating():
                return True

        return False
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tables_from_statement(
    statement: exp.Expression,
    dialect: Dialects | None,
) -> set[Table]:
    """
    Collect the tables referenced by a single parsed statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.
    """
    sources: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # sqlglot reports no sources for a `DESCRIBE` query, so every table node in
        # the statement is collected instead.
        sources = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # A command such as `SHOW COLUMNS FROM foo` is rewritten as a pseudo `SELECT`
        # statement so that its tables can be extracted.
        literal = statement.find(exp.Literal)
        if not literal:
            return set()

        try:
            pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
        except ParseError:
            return set()
        sources = pseudo_query.find_all(exp.Table)
    else:
        # regular statements: walk every scope, leaving out references to CTEs
        scoped_tables: list[exp.Table] = []
        for scope in traverse_scope(statement):
            for source in scope.sources.values():
                if isinstance(source, exp.Table) and not is_cte(source, scope):
                    scoped_tables.append(source)
        sources = scoped_tables

    tables: set[Table] = set()
    for source in sources:
        # empty schema/catalog names are normalized to `None`
        schema = source.db or None
        catalog = source.catalog or None
        tables.add(Table(source.name, schema, catalog))

    return tables
|
||||||
|
|
||||||
|
|
||||||
|
def is_cte(source: exp.Table, scope: Scope) -> bool:
    """
    Tell whether the source refers to a CTE of the parent scope.

    CTEs in the parent scope look like tables (and are represented by
    exp.Table objects), but should not be considered as such;
    otherwise a user with access to table `foo` could access any table
    with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo

    """
    # without a parent scope there are no CTEs the source could refer to
    if not scope.parent:
        return False

    for name, parent_source in scope.parent.sources.items():
        if (
            name == source.name
            and isinstance(parent_source, Scope)
            and parent_source.scope_type == ScopeType.CTE
        ):
            return True

    return False
|
||||||
@@ -51,13 +51,12 @@ from superset.extensions import celery_app, event_logger
|
|||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
from superset.result_set import SupersetResultSet
|
from superset.result_set import SupersetResultSet
|
||||||
|
from superset.sql.parse import SQLStatement, Table
|
||||||
from superset.sql_parse import (
|
from superset.sql_parse import (
|
||||||
CtasMethod,
|
CtasMethod,
|
||||||
insert_rls_as_subquery,
|
insert_rls_as_subquery,
|
||||||
insert_rls_in_predicate,
|
insert_rls_in_predicate,
|
||||||
ParsedQuery,
|
ParsedQuery,
|
||||||
SQLStatement,
|
|
||||||
Table,
|
|
||||||
)
|
)
|
||||||
from superset.sqllab.limiting_factor import LimitingFactor
|
from superset.sqllab.limiting_factor import LimitingFactor
|
||||||
from superset.sqllab.utils import write_ipc_buffer
|
from superset.sqllab.utils import write_ipc_buffer
|
||||||
|
|||||||
@@ -19,23 +19,16 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import urllib.parse
|
from collections.abc import Iterator
|
||||||
from collections.abc import Iterable, Iterator
|
from typing import Any, cast, TYPE_CHECKING
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar
|
|
||||||
|
|
||||||
import sqlglot
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
from jinja2 import nodes
|
from jinja2 import nodes
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
from sqlglot import exp, parse, parse_one
|
from sqlglot.dialects.dialect import Dialects
|
||||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
|
||||||
from sqlglot.errors import ParseError, SqlglotError
|
|
||||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
|
||||||
from sqlparse import keywords
|
from sqlparse import keywords
|
||||||
from sqlparse.lexer import Lexer
|
from sqlparse.lexer import Lexer
|
||||||
from sqlparse.sql import (
|
from sqlparse.sql import (
|
||||||
@@ -68,6 +61,7 @@ from superset.exceptions import (
|
|||||||
SupersetParseError,
|
SupersetParseError,
|
||||||
SupersetSecurityException,
|
SupersetSecurityException,
|
||||||
)
|
)
|
||||||
|
from superset.sql.parse import extract_tables_from_statement, SQLScript, Table
|
||||||
from superset.utils.backports import StrEnum
|
from superset.utils.backports import StrEnum
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -226,7 +220,9 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
|
|||||||
|
|
||||||
|
|
||||||
def check_sql_functions_exist(
|
def check_sql_functions_exist(
|
||||||
sql: str, function_list: set[str], engine: str | None = None
|
sql: str,
|
||||||
|
function_list: set[str],
|
||||||
|
engine: str = "base",
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the SQL statement contains any of the specified functions.
|
Check if the SQL statement contains any of the specified functions.
|
||||||
@@ -238,7 +234,7 @@ def check_sql_functions_exist(
|
|||||||
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
|
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
|
||||||
|
|
||||||
|
|
||||||
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
|
||||||
"""
|
"""
|
||||||
Strips comments from a SQL statement, does a simple test first
|
Strips comments from a SQL statement, does a simple test first
|
||||||
to avoid always instantiating the expensive ParsedQuery constructor
|
to avoid always instantiating the expensive ParsedQuery constructor
|
||||||
@@ -255,554 +251,18 @@ def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=True, frozen=True)
|
|
||||||
class Table:
|
|
||||||
"""
|
|
||||||
A fully qualified SQL table conforming to [[catalog.]schema.]table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
table: str
|
|
||||||
schema: str | None = None
|
|
||||||
catalog: str | None = None
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""
|
|
||||||
Return the fully qualified SQL table name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return ".".join(
|
|
||||||
urllib.parse.quote(part, safe="").replace(".", "%2E")
|
|
||||||
for part in [self.catalog, self.schema, self.table]
|
|
||||||
if part
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, __o: object) -> bool:
|
|
||||||
return str(self) == str(__o)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_tables_from_statement(
|
|
||||||
statement: exp.Expression,
|
|
||||||
dialect: Dialects | None,
|
|
||||||
) -> set[Table]:
|
|
||||||
"""
|
|
||||||
Extract all table references in a single statement.
|
|
||||||
|
|
||||||
Please note that this is not trivial; consider the following queries:
|
|
||||||
|
|
||||||
DESCRIBE some_table;
|
|
||||||
SHOW PARTITIONS FROM some_table;
|
|
||||||
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
|
|
||||||
|
|
||||||
See the unit tests for other tricky cases.
|
|
||||||
"""
|
|
||||||
sources: Iterable[exp.Table]
|
|
||||||
|
|
||||||
if isinstance(statement, exp.Describe):
|
|
||||||
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
|
|
||||||
# query for all tables.
|
|
||||||
sources = statement.find_all(exp.Table)
|
|
||||||
elif isinstance(statement, exp.Command):
|
|
||||||
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
|
|
||||||
# `SELECT` statement in order to extract tables.
|
|
||||||
literal = statement.find(exp.Literal)
|
|
||||||
if not literal:
|
|
||||||
return set()
|
|
||||||
|
|
||||||
try:
|
|
||||||
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
|
|
||||||
except ParseError:
|
|
||||||
return set()
|
|
||||||
sources = pseudo_query.find_all(exp.Table)
|
|
||||||
else:
|
|
||||||
sources = [
|
|
||||||
source
|
|
||||||
for scope in traverse_scope(statement)
|
|
||||||
for source in scope.sources.values()
|
|
||||||
if isinstance(source, exp.Table) and not is_cte(source, scope)
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
Table(
|
|
||||||
source.name,
|
|
||||||
source.db if source.db != "" else None,
|
|
||||||
source.catalog if source.catalog != "" else None,
|
|
||||||
)
|
|
||||||
for source in sources
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def is_cte(source: exp.Table, scope: Scope) -> bool:
|
|
||||||
"""
|
|
||||||
Is the source a CTE?
|
|
||||||
|
|
||||||
CTEs in the parent scope look like tables (and are represented by
|
|
||||||
exp.Table objects), but should not be considered as such;
|
|
||||||
otherwise a user with access to table `foo` could access any table
|
|
||||||
with a query like this:
|
|
||||||
|
|
||||||
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
|
|
||||||
|
|
||||||
"""
|
|
||||||
parent_sources = scope.parent.sources if scope.parent else {}
|
|
||||||
ctes_in_scope = {
|
|
||||||
name
|
|
||||||
for name, parent_scope in parent_sources.items()
|
|
||||||
if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE
|
|
||||||
}
|
|
||||||
|
|
||||||
return source.name in ctes_in_scope
|
|
||||||
|
|
||||||
|
|
||||||
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
|
|
||||||
# an "internal representation", which is the AST of the SQL statement. For most of the
|
|
||||||
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
|
|
||||||
# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
|
|
||||||
# store the AST as a string (the original query), and manipulate it with regular
|
|
||||||
# expressions.
|
|
||||||
InternalRepresentation = TypeVar("InternalRepresentation")
|
|
||||||
|
|
||||||
# The base type. This helps type checking the `split_query` method correctly, since each
|
|
||||||
# derived class has a more specific return type (the class itself). This will no longer
|
|
||||||
# be needed once Python 3.11 is the lowest version supported. See PEP 673 for more
|
|
||||||
# information: https://peps.python.org/pep-0673/
|
|
||||||
TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSQLStatement(Generic[InternalRepresentation]):
|
|
||||||
"""
|
|
||||||
Base class for SQL statements.
|
|
||||||
|
|
||||||
The class can be instantiated with a string representation of the query or, for
|
|
||||||
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
|
|
||||||
which will split a query in multiple already parsed statements.
|
|
||||||
|
|
||||||
The `engine` parameter comes from the `engine` attribute in a Superset DB engine
|
|
||||||
spec.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
statement: str | InternalRepresentation,
|
|
||||||
engine: str,
|
|
||||||
):
|
|
||||||
self._parsed: InternalRepresentation = (
|
|
||||||
self._parse_statement(statement, engine)
|
|
||||||
if isinstance(statement, str)
|
|
||||||
else statement
|
|
||||||
)
|
|
||||||
self.engine = engine
|
|
||||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def split_query(
|
|
||||||
cls: type[TBaseSQLStatement],
|
|
||||||
query: str,
|
|
||||||
engine: str,
|
|
||||||
) -> list[TBaseSQLStatement]:
|
|
||||||
"""
|
|
||||||
Split a query into multiple instantiated statements.
|
|
||||||
|
|
||||||
This is a helper function to split a full SQL query into multiple
|
|
||||||
`BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
|
|
||||||
statements within a query.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _parse_statement(
|
|
||||||
cls,
|
|
||||||
statement: str,
|
|
||||||
engine: str,
|
|
||||||
) -> InternalRepresentation:
|
|
||||||
"""
|
|
||||||
Parse a string containing a single SQL statement, and returns the parsed AST.
|
|
||||||
|
|
||||||
Derived classes should not assume that `statement` contains a single statement,
|
|
||||||
and MUST explicitly validate that. Since this validation is parser dependent the
|
|
||||||
responsibility is left to the children classes.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _extract_tables_from_statement(
|
|
||||||
cls,
|
|
||||||
parsed: InternalRepresentation,
|
|
||||||
engine: str,
|
|
||||||
) -> set[Table]:
|
|
||||||
"""
|
|
||||||
Extract all table references in a given statement.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
|
||||||
"""
|
|
||||||
Format the statement, optionally omitting comments.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_settings(self) -> dict[str, str | bool]:
|
|
||||||
"""
|
|
||||||
Return any settings set by the statement.
|
|
||||||
|
|
||||||
For example, for this statement:
|
|
||||||
|
|
||||||
sql> SET foo = 'bar';
|
|
||||||
|
|
||||||
The method should return `{"foo": "'bar'"}`. Note the single quotes.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def is_mutating(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the statement mutates data (DDL/DML).
|
|
||||||
|
|
||||||
:return: True if the statement mutates data.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.format()
|
|
||||||
|
|
||||||
|
|
||||||
class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|
||||||
"""
|
|
||||||
A SQL statement.
|
|
||||||
|
|
||||||
This class is used for all engines with dialects that can be parsed using sqlglot.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
statement: str | exp.Expression,
|
|
||||||
engine: str,
|
|
||||||
):
|
|
||||||
self._dialect = SQLGLOT_DIALECTS.get(engine)
|
|
||||||
super().__init__(statement, engine)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def split_query(
|
|
||||||
cls,
|
|
||||||
query: str,
|
|
||||||
engine: str,
|
|
||||||
) -> list[SQLStatement]:
|
|
||||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
|
||||||
|
|
||||||
try:
|
|
||||||
statements = sqlglot.parse(query, dialect=dialect)
|
|
||||||
except sqlglot.errors.ParseError as ex:
|
|
||||||
raise SupersetParseError("Unable to split query") from ex
|
|
||||||
|
|
||||||
return [cls(statement, engine) for statement in statements if statement]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _parse_statement(
|
|
||||||
cls,
|
|
||||||
statement: str,
|
|
||||||
engine: str,
|
|
||||||
) -> exp.Expression:
|
|
||||||
"""
|
|
||||||
Parse a single SQL statement.
|
|
||||||
"""
|
|
||||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
|
||||||
|
|
||||||
# We could parse with `sqlglot.parse_one` to get a single statement, but we need
|
|
||||||
# to verify that the string contains exactly one statement.
|
|
||||||
try:
|
|
||||||
statements = sqlglot.parse(statement, dialect=dialect)
|
|
||||||
except sqlglot.errors.ParseError as ex:
|
|
||||||
raise SupersetParseError("Unable to split query") from ex
|
|
||||||
|
|
||||||
statements = [statement for statement in statements if statement]
|
|
||||||
if len(statements) != 1:
|
|
||||||
raise SupersetParseError("SQLStatement should have exactly one statement")
|
|
||||||
|
|
||||||
return statements[0]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _extract_tables_from_statement(
|
|
||||||
cls,
|
|
||||||
parsed: exp.Expression,
|
|
||||||
engine: str,
|
|
||||||
) -> set[Table]:
|
|
||||||
"""
|
|
||||||
Find all referenced tables.
|
|
||||||
"""
|
|
||||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
|
||||||
return extract_tables_from_statement(parsed, dialect)
|
|
||||||
|
|
||||||
def is_mutating(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the statement mutates data (DDL/DML).
|
|
||||||
|
|
||||||
:return: True if the statement mutates data.
|
|
||||||
"""
|
|
||||||
for node in self._parsed.walk():
|
|
||||||
if isinstance(
|
|
||||||
node,
|
|
||||||
(
|
|
||||||
exp.Insert,
|
|
||||||
exp.Update,
|
|
||||||
exp.Delete,
|
|
||||||
exp.Merge,
|
|
||||||
exp.Create,
|
|
||||||
exp.Drop,
|
|
||||||
exp.TruncateTable,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if isinstance(node, exp.Command) and node.name == "ALTER":
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see
|
|
||||||
# https://www.postgresql.org/docs/current/sql-explain.html
|
|
||||||
if (
|
|
||||||
self._dialect == Dialects.POSTGRES
|
|
||||||
and isinstance(self._parsed, exp.Command)
|
|
||||||
and self._parsed.name == "EXPLAIN"
|
|
||||||
and self._parsed.expression.name.upper().startswith("ANALYZE ")
|
|
||||||
):
|
|
||||||
analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
|
|
||||||
return SQLStatement(analyzed_sql, self.engine).is_mutating()
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
|
||||||
"""
|
|
||||||
Pretty-format the SQL statement.
|
|
||||||
"""
|
|
||||||
write = Dialect.get_or_raise(self._dialect)
|
|
||||||
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
|
|
||||||
|
|
||||||
def get_settings(self) -> dict[str, str | bool]:
|
|
||||||
"""
|
|
||||||
Return the settings for the SQL statement.
|
|
||||||
|
|
||||||
>>> statement = SQLStatement("SET foo = 'bar'")
|
|
||||||
>>> statement.get_settings()
|
|
||||||
{"foo": "'bar'"}
|
|
||||||
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
eq.this.sql(): eq.expression.sql()
|
|
||||||
for set_item in self._parsed.find_all(exp.SetItem)
|
|
||||||
for eq in set_item.find_all(exp.EQ)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class KQLSplitState(enum.Enum):
|
|
||||||
"""
|
|
||||||
State machine for splitting a KQL query.
|
|
||||||
|
|
||||||
The state machine keeps track of whether we're inside a string or not, so we
|
|
||||||
don't split the query in a semi-colon that's part of a string.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTSIDE_STRING = enum.auto()
|
|
||||||
INSIDE_SINGLE_QUOTED_STRING = enum.auto()
|
|
||||||
INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
|
|
||||||
INSIDE_MULTILINE_STRING = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
def split_kql(kql: str) -> list[str]:
|
|
||||||
"""
|
|
||||||
Custom function for splitting KQL statements.
|
|
||||||
"""
|
|
||||||
statements = []
|
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
|
||||||
statement_start = 0
|
|
||||||
query = kql if kql.endswith(";") else kql + ";"
|
|
||||||
for i, character in enumerate(query):
|
|
||||||
if state == KQLSplitState.OUTSIDE_STRING:
|
|
||||||
if character == ";":
|
|
||||||
statements.append(query[statement_start:i])
|
|
||||||
statement_start = i + 1
|
|
||||||
elif character == "'":
|
|
||||||
state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
|
|
||||||
elif character == '"':
|
|
||||||
state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
|
|
||||||
elif character == "`" and query[i - 2 : i] == "``":
|
|
||||||
state = KQLSplitState.INSIDE_MULTILINE_STRING
|
|
||||||
|
|
||||||
elif (
|
|
||||||
state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
|
|
||||||
and character == "'"
|
|
||||||
and query[i - 1] != "\\"
|
|
||||||
):
|
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
|
||||||
|
|
||||||
elif (
|
|
||||||
state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
|
|
||||||
and character == '"'
|
|
||||||
and query[i - 1] != "\\"
|
|
||||||
):
|
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
|
||||||
|
|
||||||
elif (
|
|
||||||
state == KQLSplitState.INSIDE_MULTILINE_STRING
|
|
||||||
and character == "`"
|
|
||||||
and query[i - 2 : i] == "``"
|
|
||||||
):
|
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
|
||||||
|
|
||||||
return statements
|
|
||||||
|
|
||||||
|
|
||||||
class KustoKQLStatement(BaseSQLStatement[str]):
|
|
||||||
"""
|
|
||||||
Special class for Kusto KQL.
|
|
||||||
|
|
||||||
Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
|
|
||||||
like this:
|
|
||||||
|
|
||||||
StormEvents
|
|
||||||
| summarize PropertyDamage = sum(DamageProperty) by State
|
|
||||||
| join kind=innerunique PopulationData on State
|
|
||||||
| project State, PropertyDamagePerCapita = PropertyDamage / Population
|
|
||||||
| sort by PropertyDamagePerCapita
|
|
||||||
|
|
||||||
See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
|
|
||||||
details about it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def split_query(
|
|
||||||
cls,
|
|
||||||
query: str,
|
|
||||||
engine: str,
|
|
||||||
) -> list[KustoKQLStatement]:
|
|
||||||
"""
|
|
||||||
Split a query at semi-colons.
|
|
||||||
|
|
||||||
Since we don't have a parser, we use a simple state machine based function. See
|
|
||||||
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
|
|
||||||
for more information.
|
|
||||||
"""
|
|
||||||
return [cls(statement, engine) for statement in split_kql(query)]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _parse_statement(
|
|
||||||
cls,
|
|
||||||
statement: str,
|
|
||||||
engine: str,
|
|
||||||
) -> str:
|
|
||||||
if engine != "kustokql":
|
|
||||||
raise SupersetParseError(f"Invalid engine: {engine}")
|
|
||||||
|
|
||||||
statements = split_kql(statement)
|
|
||||||
if len(statements) != 1:
|
|
||||||
raise SupersetParseError("SQLStatement should have exactly one statement")
|
|
||||||
|
|
||||||
return statements[0].strip()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
|
|
||||||
"""
|
|
||||||
Extract all tables referenced in the statement.
|
|
||||||
|
|
||||||
StormEvents
|
|
||||||
| where InjuriesDirect + InjuriesIndirect > 50
|
|
||||||
| join (PopulationData) on State
|
|
||||||
| project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect
|
|
||||||
|
|
||||||
"""
|
|
||||||
logger.warning(
|
|
||||||
"Kusto KQL doesn't support table extraction. This means that data access "
|
|
||||||
"roles will not be enforced by Superset in the database."
|
|
||||||
)
|
|
||||||
return set()
|
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
|
||||||
"""
|
|
||||||
Pretty-format the SQL statement.
|
|
||||||
"""
|
|
||||||
return self._parsed
|
|
||||||
|
|
||||||
def get_settings(self) -> dict[str, str | bool]:
|
|
||||||
"""
|
|
||||||
Return the settings for the SQL statement.
|
|
||||||
|
|
||||||
>>> statement = KustoKQLStatement("set querytrace;")
|
|
||||||
>>> statement.get_settings()
|
|
||||||
{"querytrace": True}
|
|
||||||
|
|
||||||
"""
|
|
||||||
set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
|
|
||||||
if match := re.match(set_regex, self._parsed, re.IGNORECASE):
|
|
||||||
return {match.group("name"): match.group("value") or True}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def is_mutating(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the statement mutates data (DDL/DML).
|
|
||||||
|
|
||||||
:return: True if the statement mutates data.
|
|
||||||
"""
|
|
||||||
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
|
||||||
|
|
||||||
|
|
||||||
class SQLScript:
|
|
||||||
"""
|
|
||||||
A SQL script, with 0+ statements.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Special engines that can't be parsed using sqlglot. Supporting non-SQL engines
|
|
||||||
# adds a lot of complexity to Superset, so we should avoid adding new engines to
|
|
||||||
# this data structure.
|
|
||||||
special_engines = {
|
|
||||||
"kustokql": KustoKQLStatement,
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
engine: str,
|
|
||||||
):
|
|
||||||
statement_class = self.special_engines.get(engine, SQLStatement)
|
|
||||||
self.statements = statement_class.split_query(query, engine)
|
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
|
||||||
"""
|
|
||||||
Pretty-format the SQL query.
|
|
||||||
"""
|
|
||||||
return ";\n".join(statement.format(comments) for statement in self.statements)
|
|
||||||
|
|
||||||
def get_settings(self) -> dict[str, str | bool]:
|
|
||||||
"""
|
|
||||||
Return the settings for the SQL query.
|
|
||||||
|
|
||||||
>>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
|
|
||||||
>>> statement.get_settings()
|
|
||||||
{"foo": "'baz'"}
|
|
||||||
|
|
||||||
"""
|
|
||||||
settings: dict[str, str | bool] = {}
|
|
||||||
for statement in self.statements:
|
|
||||||
settings.update(statement.get_settings())
|
|
||||||
|
|
||||||
return settings
|
|
||||||
|
|
||||||
def has_mutation(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the script contains mutating statements.
|
|
||||||
|
|
||||||
:return: True if the script contains mutating statements
|
|
||||||
"""
|
|
||||||
return any(statement.is_mutating() for statement in self.statements)
|
|
||||||
|
|
||||||
|
|
||||||
class ParsedQuery:
|
class ParsedQuery:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sql_statement: str,
|
sql_statement: str,
|
||||||
strip_comments: bool = False,
|
strip_comments: bool = False,
|
||||||
engine: str | None = None,
|
engine: str = "base",
|
||||||
):
|
):
|
||||||
if strip_comments:
|
if strip_comments:
|
||||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||||
|
|
||||||
self.sql: str = sql_statement
|
self.sql: str = sql_statement
|
||||||
|
self._engine = engine
|
||||||
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||||
self._tables: set[Table] = set()
|
self._tables: set[Table] = set()
|
||||||
self._alias_names: set[str] = set()
|
self._alias_names: set[str] = set()
|
||||||
@@ -854,24 +314,18 @@ class ParsedQuery:
|
|||||||
Note: this uses sqlglot, since it's better at catching more edge cases.
|
Note: this uses sqlglot, since it's better at catching more edge cases.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
statements = parse(self.stripped(), dialect=self._dialect)
|
statements = [
|
||||||
except SqlglotError as ex:
|
statement._parsed # pylint: disable=protected-access
|
||||||
|
for statement in SQLScript(self.stripped(), self._engine).statements
|
||||||
|
]
|
||||||
|
except SupersetParseError as ex:
|
||||||
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
|
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
|
||||||
|
|
||||||
message = (
|
|
||||||
"Error parsing near '{highlight}' at line {line}:{col}".format( # pylint: disable=consider-using-f-string
|
|
||||||
**ex.errors[0]
|
|
||||||
)
|
|
||||||
if isinstance(ex, ParseError)
|
|
||||||
else str(ex)
|
|
||||||
)
|
|
||||||
|
|
||||||
raise SupersetSecurityException(
|
raise SupersetSecurityException(
|
||||||
SupersetError(
|
SupersetError(
|
||||||
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
|
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
|
||||||
message=__(
|
message=__(
|
||||||
"You may have an error in your SQL statement. {message}"
|
"You may have an error in your SQL statement. {message}"
|
||||||
).format(message=message),
|
).format(message=ex.error.message),
|
||||||
level=ErrorLevel.ERROR,
|
level=ErrorLevel.ERROR,
|
||||||
)
|
)
|
||||||
) from ex
|
) from ex
|
||||||
@@ -883,77 +337,6 @@ class ParsedQuery:
|
|||||||
if statement
|
if statement
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
|
|
||||||
"""
|
|
||||||
Extract all table references in a single statement.
|
|
||||||
|
|
||||||
Please note that this is not trivial; consider the following queries:
|
|
||||||
|
|
||||||
DESCRIBE some_table;
|
|
||||||
SHOW PARTITIONS FROM some_table;
|
|
||||||
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
|
|
||||||
|
|
||||||
See the unit tests for other tricky cases.
|
|
||||||
"""
|
|
||||||
sources: Iterable[exp.Table]
|
|
||||||
|
|
||||||
if isinstance(statement, exp.Describe):
|
|
||||||
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
|
|
||||||
# query for all tables.
|
|
||||||
sources = statement.find_all(exp.Table)
|
|
||||||
elif isinstance(statement, exp.Command):
|
|
||||||
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
|
|
||||||
# `SELECT` statement in order to extract tables.
|
|
||||||
if not (literal := statement.find(exp.Literal)):
|
|
||||||
return set()
|
|
||||||
|
|
||||||
try:
|
|
||||||
pseudo_query = parse_one(
|
|
||||||
f"SELECT {literal.this}",
|
|
||||||
dialect=self._dialect,
|
|
||||||
)
|
|
||||||
sources = pseudo_query.find_all(exp.Table)
|
|
||||||
except SqlglotError:
|
|
||||||
return set()
|
|
||||||
else:
|
|
||||||
sources = [
|
|
||||||
source
|
|
||||||
for scope in traverse_scope(statement)
|
|
||||||
for source in scope.sources.values()
|
|
||||||
if isinstance(source, exp.Table) and not self._is_cte(source, scope)
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
Table(
|
|
||||||
source.name,
|
|
||||||
source.db if source.db != "" else None,
|
|
||||||
source.catalog if source.catalog != "" else None,
|
|
||||||
)
|
|
||||||
for source in sources
|
|
||||||
}
|
|
||||||
|
|
||||||
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
|
|
||||||
"""
|
|
||||||
Is the source a CTE?
|
|
||||||
|
|
||||||
CTEs in the parent scope look like tables (and are represented by
|
|
||||||
exp.Table objects), but should not be considered as such;
|
|
||||||
otherwise a user with access to table `foo` could access any table
|
|
||||||
with a query like this:
|
|
||||||
|
|
||||||
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
|
|
||||||
|
|
||||||
"""
|
|
||||||
parent_sources = scope.parent.sources if scope.parent else {}
|
|
||||||
ctes_in_scope = {
|
|
||||||
name
|
|
||||||
for name, parent_scope in parent_sources.items()
|
|
||||||
if isinstance(parent_scope, Scope)
|
|
||||||
and parent_scope.scope_type == ScopeType.CTE
|
|
||||||
}
|
|
||||||
|
|
||||||
return source.name in ctes_in_scope
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def limit(self) -> int | None:
|
def limit(self) -> int | None:
|
||||||
return self._limit
|
return self._limit
|
||||||
|
|||||||
@@ -35,8 +35,8 @@ from superset.daos.query import QueryDAO
|
|||||||
from superset.extensions import event_logger
|
from superset.extensions import event_logger
|
||||||
from superset.jinja_context import get_template_processor
|
from superset.jinja_context import get_template_processor
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
|
from superset.sql.parse import SQLScript
|
||||||
from superset.sql_lab import get_sql_results
|
from superset.sql_lab import get_sql_results
|
||||||
from superset.sql_parse import SQLScript
|
|
||||||
from superset.sqllab.command_status import SqlJsonExecutionStatus
|
from superset.sqllab.command_status import SqlJsonExecutionStatus
|
||||||
from superset.sqllab.exceptions import (
|
from superset.sqllab.exceptions import (
|
||||||
QueryIsForbiddenToAccessException,
|
QueryIsForbiddenToAccessException,
|
||||||
|
|||||||
@@ -107,7 +107,6 @@ def cache_dashboard_thumbnail(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
|
||||||
@celery_app.task(name="cache_dashboard_screenshot", soft_time_limit=300)
|
@celery_app.task(name="cache_dashboard_screenshot", soft_time_limit=300)
|
||||||
def cache_dashboard_screenshot(
|
def cache_dashboard_screenshot(
|
||||||
dashboard_id: int,
|
dashboard_id: int,
|
||||||
|
|||||||
16
tests/unit_tests/sql/__init__.py
Normal file
16
tests/unit_tests/sql/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
920
tests/unit_tests/sql/parse_tests.py
Normal file
920
tests/unit_tests/sql/parse_tests.py
Normal file
@@ -0,0 +1,920 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from superset.exceptions import SupersetParseError
|
||||||
|
from superset.sql.parse import (
|
||||||
|
extract_tables_from_statement,
|
||||||
|
KustoKQLStatement,
|
||||||
|
split_kql,
|
||||||
|
SQLGLOT_DIALECTS,
|
||||||
|
SQLScript,
|
||||||
|
SQLStatement,
|
||||||
|
Table,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_table() -> None:
|
||||||
|
"""
|
||||||
|
Test the `Table` class and its string conversion.
|
||||||
|
|
||||||
|
Special characters in the table, schema, or catalog name should be escaped correctly.
|
||||||
|
"""
|
||||||
|
assert str(Table("tbname")) == "tbname"
|
||||||
|
assert str(Table("tbname", "schemaname")) == "schemaname.tbname"
|
||||||
|
assert (
|
||||||
|
str(Table("tbname", "schemaname", "catalogname"))
|
||||||
|
== "catalogname.schemaname.tbname"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
str(Table("table.name", "schema/name", "catalog\nname"))
|
||||||
|
== "catalog%0Aname.schema%2Fname.table%2Ename"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]:
|
||||||
|
"""
|
||||||
|
Helper function to extract tables from SQL.
|
||||||
|
"""
|
||||||
|
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||||
|
return {
|
||||||
|
table
|
||||||
|
for statement in SQLScript(sql, engine).statements
|
||||||
|
for table in extract_tables_from_statement(statement._parsed, dialect)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_from_sql() -> None:
    """
    Check table extraction on a range of simple single-statement queries.
    """
    cases = [
        # bare table name, with and without an alias
        ("SELECT * FROM tbname", {Table("tbname")}),
        ("SELECT * FROM tbname foo", {Table("tbname")}),
        ("SELECT * FROM tbname AS foo", {Table("tbname")}),
        # underscore
        ("SELECT * FROM tb_name", {Table("tb_name")}),
        # quotes
        ('SELECT * FROM "tbname"', {Table("tbname")}),
        # unicode
        (
            'SELECT * FROM "tb_name" WHERE city = "Lübeck"',
            {Table("tb_name")},
        ),
        # columns
        ("SELECT field1, field2 FROM tb_name", {Table("tb_name")}),
        ("SELECT t1.f1, t2.f2 FROM t1, t2", {Table("t1"), Table("t2")}),
        # named table
        (
            "SELECT a.date, a.field FROM left_table a LIMIT 10",
            {Table("left_table")},
        ),
        # subquery whose alias is the same as the table it reads from
        (
            "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
            {Table("forbidden_table")},
        ),
        (
            "select * from (select * from forbidden_table) forbidden_table",
            {Table("forbidden_table")},
        ),
    ]
    for sql, expected in cases:
        assert extract_tables_from_sql(sql) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_subselect() -> None:
    """
    Check that tables referenced inside subselects are found.
    """
    # subselect in the FROM clause next to a regular table
    subselect_and_table = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(subselect_and_table) == {
        Table("t1", "s1"),
        Table("t2", "s2"),
    }

    # subselect alone: its alias (`sub`) is not reported
    subselect_only = """
SELECT sub.*
FROM (
SELECT *
FROM s1.t1
WHERE day_of_week = 'Friday'
) sub
WHERE sub.resolution = 'NONE'
"""
    assert extract_tables_from_sql(subselect_only) == {Table("t1", "s1")}

    # subselects nested several levels deep inside predicates
    nested_predicates = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT COUNT(*) /* no hint */ FROM t2
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t2.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    assert extract_tables_from_sql(nested_predicates) == {
        Table("t1"),
        Table("t2"),
        Table("t3"),
        Table("t4"),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_select_in_expression() -> None:
    """
    Check that a `SELECT` used as a projected expression contributes its tables.
    """
    expected = {Table("t1"), Table("t2")}
    queries = (
        "SELECT f1, (SELECT count(1) FROM t2) FROM t1",
        "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1",
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_parenthesis() -> None:
    """
    Check that a parenthesized expression in the projection is not taken for a table.
    """
    tables = extract_tables_from_sql("SELECT f1, (x + y) AS f2 FROM t1")
    assert tables == {Table("t1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_with_schema() -> None:
    """
    Check that schema-qualified names are split into schema and table.

    Quoting and aliasing must not change the result.
    """
    expected = {Table("tbname", "schemaname")}
    queries = (
        "SELECT * FROM schemaname.tbname",
        'SELECT * FROM "schemaname"."tbname"',
        'SELECT * FROM "schemaname"."tbname" foo',
        'SELECT * FROM "schemaname"."tbname" AS foo',
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_union() -> None:
    """
    Check that both operands of a set operation contribute their tables.

    Covers `UNION`, `UNION ALL` and `INTERSECT ALL`.
    """
    expected = {Table("t1"), Table("t2")}
    queries = (
        "SELECT * FROM t1 UNION SELECT * FROM t2",
        "SELECT * FROM t1 UNION ALL SELECT * FROM t2",
        "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2",
    )
    for sql in queries:
        assert extract_tables_from_sql(sql) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_select_from_values() -> None:
    """
    Check that selecting from a `VALUES` list reports no tables.
    """
    tables = extract_tables_from_sql("SELECT * FROM VALUES (13, 42)")
    assert tables == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_select_array() -> None:
    """
    Check that an `ARRAY[...]` literal in the projection does not disturb extraction.
    """
    sql = """
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_select_if() -> None:
    """
    Check that an `IF(...)` call with array indexing does not disturb extraction.
    """
    sql = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
    assert extract_tables_from_sql(sql) == {Table("t1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_with_catalog() -> None:
    """
    Check that a three-part name is split into catalog, schema and table.
    """
    tables = extract_tables_from_sql("SELECT * FROM catalogname.schemaname.tbname")
    assert tables == {Table("tbname", "schemaname", "catalogname")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_illdefined() -> None:
    """
    Test that ill-defined table references raise `SupersetParseError`.

    Dangling dots and an unterminated quoted identifier must fail with the exact
    error messages listed below, instead of yielding any tables. A catalog followed
    by an empty schema and a table name is the one form that still parses.
    """
    invalid = [
        ("SELECT * FROM schemaname.", "Error parsing near '.' at line 1:25"),
        (
            "SELECT * FROM catalogname.schemaname.",
            "Error parsing near '.' at line 1:37",
        ),
        ("SELECT * FROM catalogname..", "Error parsing near '.' at line 1:27"),
        # unterminated quote: no position is reported in the message
        ('SELECT * FROM "tbname', "Unable to parse script"),
    ]
    for sql, message in invalid:
        with pytest.raises(SupersetParseError) as excinfo:
            extract_tables_from_sql(sql)
        assert str(excinfo.value) == message

    # odd edge case that works
    assert extract_tables_from_sql("SELECT * FROM catalogname..tbname") == {
        Table(table="tbname", schema=None, catalog="catalogname")
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_show_tables_from() -> None:
    """
    Check that MySQL `SHOW TABLES FROM` reports no tables.

    The name following `FROM` is not returned as a table.
    """
    tables = extract_tables_from_sql("SHOW TABLES FROM s1 like '%order%'", "mysql")
    assert tables == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_show_columns_from() -> None:
    """
    Check that `SHOW COLUMNS FROM` reports the inspected table.
    """
    tables = extract_tables_from_sql("SHOW COLUMNS FROM t1")
    assert tables == {Table("t1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_where_subquery() -> None:
    """
    Check that tables used by a subquery in the `WHERE` clause are found.

    Covers scalar comparison, `IN` and `EXISTS` subqueries.
    """
    expected = {Table("t1"), Table("t2")}

    scalar_subquery = """
SELECT name
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
    in_subquery = """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
    exists_subquery = """
SELECT name
FROM t1
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
    for sql in (scalar_subquery, in_subquery, exists_subquery):
        assert extract_tables_from_sql(sql) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_describe() -> None:
    """
    Check that `DESCRIBE` reports the described table.
    """
    tables = extract_tables_from_sql("DESCRIBE t1")
    assert tables == {Table("t1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_show_partitions() -> None:
    """
    Check that `SHOW PARTITIONS FROM` reports the inspected table.
    """
    sql = """
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC
"""
    assert extract_tables_from_sql(sql) == {Table("orders")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_join() -> None:
    """
    Check that both sides of a join are reported, for several join types.
    """
    plain_join = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
    assert extract_tables_from_sql(plain_join) == {Table("t1"), Table("t2")}

    # The same query is run once per join keyword. The right-hand side is a
    # subquery, so `right_table` has to be found inside it.
    join_types = ("JOIN", "LEFT INNER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN")
    for join_type in join_types:
        sql = f"""
SELECT a.date, b.name
FROM left_table a
{join_type} (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.date = b.date
"""
        assert extract_tables_from_sql(sql) == {
            Table("left_table"),
            Table("right_table"),
        }
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_semi_join() -> None:
    """
    Check that `LEFT SEMI JOIN` against a subquery reports both tables.
    """
    sql = """
SELECT a.date, b.name
FROM left_table a
LEFT SEMI JOIN (
SELECT
CAST((b.year) as VARCHAR) date,
name
FROM right_table
) b
ON a.data = b.date
"""
    expected = {Table("left_table"), Table("right_table")}
    assert extract_tables_from_sql(sql) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_combinations() -> None:
    """
    Check extraction on queries that combine several kinds of nesting.
    """
    # union, join and correlated subqueries mixed together; the derived-table
    # alias `tmp_join` must not be reported
    mixed = """
SELECT * FROM t1
WHERE s11 > ANY (
SELECT * FROM t1 UNION ALL SELECT * FROM (
SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a
) tmp_join
WHERE NOT EXISTS (
SELECT * FROM t3
WHERE ROW(5*t3.s1,77)=(
SELECT 50,11*s1 FROM t4
)
)
)
"""
    assert extract_tables_from_sql(mixed) == {
        Table("t1"),
        Table("t3"),
        Table("t4"),
        Table("t6"),
    }

    # derived tables nested three levels deep; only the innermost real table counts
    deeply_nested = """
SELECT * FROM (
SELECT * FROM (
SELECT * FROM (
SELECT * FROM EmployeeS
) AS S1
) AS S2
) AS S3
"""
    assert extract_tables_from_sql(deeply_nested) == {Table("EmployeeS")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_with() -> None:
    """
    Check extraction from `WITH` clauses (common table expressions).
    """
    # three independent CTEs: every underlying table is reported, the CTE names
    # themselves are not
    independent_ctes = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM t2),
z AS (SELECT b AS c FROM t3)
SELECT c FROM z
"""
    assert extract_tables_from_sql(independent_ctes) == {
        Table("t1"),
        Table("t2"),
        Table("t3"),
    }

    # CTEs chained on each other: only the table at the bottom of the chain counts
    chained_ctes = """
WITH
x AS (SELECT a FROM t1),
y AS (SELECT a AS b FROM x),
z AS (SELECT b AS c FROM y)
SELECT c FROM z
"""
    assert extract_tables_from_sql(chained_ctes) == {Table("t1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_reusing_aliases() -> None:
    """
    Check that CTE names are followed and never reported as tables.
    """
    # `q1` refers to `q2`, which is declared after it; only `src` is a real table
    forward_reference = """
with q1 as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from q1) a
"""
    assert extract_tables_from_sql(forward_reference) == {Table("src")}

    # weird query with circular dependency: `src` is itself a CTE name here, so
    # no table is reported at all
    circular = """
with src as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from src) a
"""
    assert extract_tables_from_sql(circular) == set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_multistatement() -> None:
    """
    Check that tables from every statement in a script are merged into one set.
    """
    both = {Table("t1"), Table("t2")}

    # with and without a trailing semicolon
    assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2") == both
    assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2;") == both

    # a Hive command without tables ahead of the query
    hive_tables = extract_tables_from_sql(
        "ADD JAR file:///hive.jar; SELECT * FROM t1;",
        engine="hive",
    )
    assert hive_tables == {Table("t1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_complex() -> None:
    """
    Check extraction on a few larger, realistic queries.
    """
    # aggregate over a derived table that joins three tables and filters with an
    # `IN` subquery
    aggregate_over_joins = """
SELECT sum(m_examples) AS "sum__m_example"
FROM (
SELECT
COUNT(DISTINCT id_userid) AS m_examples,
some_more_info
FROM my_b_table b
JOIN my_t_table t ON b.ds=t.ds
JOIN my_l_table l ON b.uid=l.uid
WHERE
b.rid IN (
SELECT other_col
FROM inner_table
)
AND l.bla IN ('x', 'y')
GROUP BY 2
ORDER BY 2 ASC
) AS "meh"
ORDER BY "sum__m_example" DESC
LIMIT 10;
"""
    assert extract_tables_from_sql(aggregate_over_joins) == {
        Table("my_l_table"),
        Table("my_b_table"),
        Table("my_t_table"),
        Table("inner_table"),
    }

    # comma-separated FROM list with aliases
    implicit_join = """
SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    assert extract_tables_from_sql(implicit_join) == {
        Table("table_a"),
        Table("table_b"),
        Table("table_c"),
    }

    # CTEs declared inside a derived table, followed by a chain of UNION ALLs; the
    # CTE names (`bla`, `rb`) and aliases (`foo`, `not_table`) are not reported
    ctes_in_subquery = """
SELECT somecol AS somecol
FROM (
WITH bla AS (
SELECT col_a
FROM a
WHERE
1=1
AND column_of_choice NOT IN (
SELECT interesting_col
FROM b
)
),
rb AS (
SELECT yet_another_column
FROM (
SELECT a
FROM c
GROUP BY the_other_col
) not_table
LEFT JOIN bla foo
ON foo.prop = not_table.bad_col0
WHERE 1=1
GROUP BY
not_table.bad_col1 ,
not_table.bad_col2 ,
ORDER BY not_table.bad_col_3 DESC ,
not_table.bad_col4 ,
not_table.bad_col5
)
SELECT random_col
FROM d
WHERE 1=1
UNION ALL SELECT even_more_cols
FROM e
WHERE 1=1
UNION ALL SELECT lets_go_deeper
FROM f
WHERE 1=1
WHERE 2=2
GROUP BY last_col
LIMIT 50000
)
"""
    assert extract_tables_from_sql(ctes_in_subquery) == {
        Table("a"),
        Table("b"),
        Table("c"),
        Table("d"),
        Table("e"),
        Table("f"),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_mixed_from_clause() -> None:
    """
    Check a `FROM` list that mixes plain tables with a subselect.
    """
    sql = """
SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id
"""
    expected = {Table("table_a"), Table("table_b"), Table("table_c")}
    assert extract_tables_from_sql(sql) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_nested_select() -> None:
    """
    Check that a `SELECT` buried inside function-call arguments is still found.

    Both queries are parsed as MySQL and read from `INFORMATION_SCHEMA.COLUMNS`.
    """
    expected = {Table("COLUMNS", "INFORMATION_SCHEMA")}

    table_names_query = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
    assert extract_tables_from_sql(table_names_query, "mysql") == expected

    column_names_query = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
"""
    assert extract_tables_from_sql(column_names_query, "mysql") == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_complex_cte_with_prefix() -> None:
    """
    Check a CTE declared with an explicit column list and a `CTE__` name prefix.

    Only the table read inside the CTE is reported, not the CTE itself.
    """
    sql = """
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
AS (
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
FROM SalesOrderHeader
WHERE SalesPersonID IS NOT NULL
)
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
FROM CTE__test
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
    assert extract_tables_from_sql(sql) == {Table("SalesOrderHeader")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
    """
    Check that a CTE named after a keyword (`match`) is handled like any other.
    """
    sql = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
    assert extract_tables_from_sql(sql) == {Table("foo")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlscript() -> None:
    """
    Exercise `SQLScript`: statement splitting, formatting and settings.
    """
    select_script = SQLScript("SELECT 1; SELECT 2;", "sqlite")

    assert len(select_script.statements) == 2
    assert select_script.format() == "SELECT\n 1;\nSELECT\n 2"
    assert select_script.statements[0].format() == "SELECT\n 1"

    # when the same setting is assigned twice, the later value is the one reported
    set_script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
    assert set_script.get_settings() == {"a": "2"}

    # a KQL `set` without a value is reported as `True`
    kql_script = SQLScript(
        """set querytrace;
Events | take 100""",
        "kustokql",
    )
    assert kql_script.get_settings() == {"querytrace": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlstatement() -> None:
    """
    Exercise `SQLStatement`: referenced tables, formatting and settings.
    """
    union = SQLStatement(
        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
        "sqlite",
    )
    expected_tables = {
        Table(table="table1", schema=None, catalog=None),
        Table(table="table2", schema=None, catalog=None),
    }
    expected_sql = "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"

    assert union.tables == expected_tables
    assert union.format() == expected_sql

    setting = SQLStatement("SET a=1", "sqlite")
    assert setting.get_settings() == {"a": "1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_kustokqlstatement_split_script() -> None:
    """
    Check `KustoKQLStatement.split_script` on a multi-statement KQL script.

    The script has three `let` statements followed by one query, so four statements
    are expected.
    """
    kql = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    statements = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(statements) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_kustokqlstatement_with_program() -> None:
    """
    Check `KustoKQLStatement.split_script` when the KQL embeds a program.

    The semicolon inside the triple-backtick literal must not split the script.
    """
    kql = """
print program = ```
public class Program {
public static void Main() {
System.Console.WriteLine("Hello!");
}
}```
"""
    statements = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(statements) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_kustokqlstatement_with_set() -> None:
    """
    Check `KustoKQLStatement.split_script` when the KQL starts with a `set` command.

    The `set` command and the query come back as two separate statements, each
    formatted without the separator or surrounding whitespace.
    """
    kql = """
set querytrace;
Events | take 100
"""
    statements = KustoKQLStatement.split_script(kql, "kustokql")

    assert len(statements) == 2
    formatted = [statements[0].format(), statements[1].format()]
    assert formatted[0] == "set querytrace"
    assert formatted[1] == "Events | take 100"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "kql,statements",
    [
        ('print banner=strcat("Hello", ", ", "World!")', 1),
        (r"print 'O\'Malley\'s'", 1),
        (r"print 'O\'Mal;ley\'s'", 1),
        ("print ```foo;\nbar;\nbaz;```\n", 1),
    ],
)
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
    """
    Check that KQL string literals never cause a spurious statement split.

    Covers double-quoted strings, escaped single quotes and triple-backtick
    literals, including ones that contain `;`.
    """
    parts = KustoKQLStatement.split_script(kql, "kustokql")
    assert len(parts) == statements
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_kql() -> None:
    """
    Check the fragments returned by `split_kql` for a multi-statement script.

    Each fragment keeps its surrounding whitespace, including the leading newline,
    while the separating semicolons are dropped.
    """
    script = """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp);
let cachedResult = materialize(materializedScope);
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
    expected = [
        """
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
        """
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
        """
let cachedResult = materialize(materializedScope)""",
        """
cachedResult
| project Page, Day1 = Day
| join kind = inner
(
cachedResult
| project Page, Day2 = Day
)
on Page
| where Day2 > Day1
| summarize count() by Day1, Day2
| join kind = inner
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
    ]
    assert split_kql(script) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("engine", "sql", "expected"),
    [
        # SQLite tests
        ("sqlite", "SELECT 1", False),
        ("sqlite", "INSERT INTO foo VALUES (1)", True),
        ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
        ("sqlite", "DELETE FROM foo WHERE id = 1", True),
        ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
        ("sqlite", "DROP TABLE foo", True),
        ("sqlite", "EXPLAIN SELECT * FROM foo", False),
        ("sqlite", "PRAGMA table_info(foo)", False),
        # PostgreSQL tests
        ("postgresql", "SELECT 1", False),
        ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
        ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
        ("postgresql", "DELETE FROM foo WHERE id = 1", True),
        ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
        ("postgresql", "DROP TABLE foo", True),
        ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
        # `EXPLAIN ANALYZE` wrapping a DML statement still counts as a mutation
        ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
        ("postgresql", "SHOW search_path", False),
        ("postgresql", "SET search_path TO public", False),
        # Read-only CTE. The engine is spelled "postgresql" like every other case
        # in this section, so the same engine name is exercised throughout.
        (
            "postgresql",
            """
with source as (
select 1 as one
)
select * from source
""",
            False,
        ),
        # Trino tests
        ("trino", "SELECT 1", False),
        ("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
        ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
        ("trino", "DELETE FROM foo WHERE id = 1", True),
        ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
        ("trino", "DROP TABLE foo", True),
        ("trino", "EXPLAIN SELECT * FROM foo", False),
        ("trino", "SHOW SCHEMAS", False),
        ("trino", "SET SESSION optimization_level = '3'", False),
        # Kusto KQL tests
        ("kustokql", "tbl | limit 100", False),
        ("kustokql", "let foo = 1; tbl | where bar == foo", False),
        ("kustokql", ".show tables", False),
        ("kustokql", "print 1", False),
        ("kustokql", "set querytrace; Events | take 100", False),
        ("kustokql", ".drop table foo", True),
        ("kustokql", ".set-or-append table foo <| bar", True),
    ],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
    """
    Test the `has_mutation` method.

    For each engine, scripts that only read or inspect data must report `False`,
    while scripts containing DML/DDL (or KQL control commands that change data)
    must report `True`.
    """
    assert SQLScript(sql, engine).has_mutation() == expected
|
||||||
@@ -30,6 +30,7 @@ from superset.exceptions import (
|
|||||||
QueryClauseValidationException,
|
QueryClauseValidationException,
|
||||||
SupersetSecurityException,
|
SupersetSecurityException,
|
||||||
)
|
)
|
||||||
|
from superset.sql.parse import Table
|
||||||
from superset.sql_parse import (
|
from superset.sql_parse import (
|
||||||
add_table_name,
|
add_table_name,
|
||||||
check_sql_functions_exist,
|
check_sql_functions_exist,
|
||||||
@@ -39,18 +40,13 @@ from superset.sql_parse import (
|
|||||||
has_table_query,
|
has_table_query,
|
||||||
insert_rls_as_subquery,
|
insert_rls_as_subquery,
|
||||||
insert_rls_in_predicate,
|
insert_rls_in_predicate,
|
||||||
KustoKQLStatement,
|
|
||||||
ParsedQuery,
|
ParsedQuery,
|
||||||
sanitize_clause,
|
sanitize_clause,
|
||||||
split_kql,
|
|
||||||
SQLScript,
|
|
||||||
SQLStatement,
|
|
||||||
strip_comments_from_sql,
|
strip_comments_from_sql,
|
||||||
Table,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
|
def extract_tables(query: str, engine: str = "base") -> set[Table]:
|
||||||
"""
|
"""
|
||||||
Helper function to extract tables referenced in a query.
|
Helper function to extract tables referenced in a query.
|
||||||
"""
|
"""
|
||||||
@@ -285,7 +281,7 @@ def test_extract_tables_illdefined() -> None:
|
|||||||
extract_tables('SELECT * FROM "tbname')
|
extract_tables('SELECT * FROM "tbname')
|
||||||
assert (
|
assert (
|
||||||
str(excinfo.value)
|
str(excinfo.value)
|
||||||
== "You may have an error in your SQL statement. Error tokenizing 'SELECT * FROM \"tbnam'"
|
== "You may have an error in your SQL statement. Unable to parse script"
|
||||||
)
|
)
|
||||||
|
|
||||||
# odd edge case that works
|
# odd edge case that works
|
||||||
@@ -1834,49 +1830,6 @@ SELECT * FROM t"""
|
|||||||
assert ParsedQuery("USE foo; SELECT * FROM bar").is_select()
|
assert ParsedQuery("USE foo; SELECT * FROM bar").is_select()
|
||||||
|
|
||||||
|
|
||||||
def test_sqlquery() -> None:
|
|
||||||
"""
|
|
||||||
Test the `SQLScript` class.
|
|
||||||
"""
|
|
||||||
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
|
|
||||||
|
|
||||||
assert len(script.statements) == 2
|
|
||||||
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
|
|
||||||
assert script.statements[0].format() == "SELECT\n 1"
|
|
||||||
|
|
||||||
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
|
|
||||||
assert script.get_settings() == {"a": "2"}
|
|
||||||
|
|
||||||
query = SQLScript(
|
|
||||||
"""set querytrace;
|
|
||||||
Events | take 100""",
|
|
||||||
"kustokql",
|
|
||||||
)
|
|
||||||
assert query.get_settings() == {"querytrace": True}
|
|
||||||
|
|
||||||
|
|
||||||
def test_sqlstatement() -> None:
|
|
||||||
"""
|
|
||||||
Test the `SQLStatement` class.
|
|
||||||
"""
|
|
||||||
statement = SQLStatement(
|
|
||||||
"SELECT * FROM table1 UNION ALL SELECT * FROM table2",
|
|
||||||
"sqlite",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert statement.tables == {
|
|
||||||
Table(table="table1", schema=None, catalog=None),
|
|
||||||
Table(table="table2", schema=None, catalog=None),
|
|
||||||
}
|
|
||||||
assert (
|
|
||||||
statement.format()
|
|
||||||
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
|
|
||||||
)
|
|
||||||
|
|
||||||
statement = SQLStatement("SET a=1", "sqlite")
|
|
||||||
assert statement.get_settings() == {"a": "1"}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"engine",
|
"engine",
|
||||||
[
|
[
|
||||||
@@ -1924,194 +1877,3 @@ def test_extract_tables_from_jinja_sql(
|
|||||||
)
|
)
|
||||||
== expected
|
== expected
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_kustokqlstatement_split_query() -> None:
|
|
||||||
"""
|
|
||||||
Test the `KustoKQLStatement` split method.
|
|
||||||
"""
|
|
||||||
statements = KustoKQLStatement.split_query(
|
|
||||||
"""
|
|
||||||
let totalPagesPerDay = PageViews
|
|
||||||
| summarize by Page, Day = startofday(Timestamp)
|
|
||||||
| summarize count() by Day;
|
|
||||||
let materializedScope = PageViews
|
|
||||||
| summarize by Page, Day = startofday(Timestamp);
|
|
||||||
let cachedResult = materialize(materializedScope);
|
|
||||||
cachedResult
|
|
||||||
| project Page, Day1 = Day
|
|
||||||
| join kind = inner
|
|
||||||
(
|
|
||||||
cachedResult
|
|
||||||
| project Page, Day2 = Day
|
|
||||||
)
|
|
||||||
on Page
|
|
||||||
| where Day2 > Day1
|
|
||||||
| summarize count() by Day1, Day2
|
|
||||||
| join kind = inner
|
|
||||||
totalPagesPerDay
|
|
||||||
on $left.Day1 == $right.Day
|
|
||||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
|
||||||
""",
|
|
||||||
"kustokql",
|
|
||||||
)
|
|
||||||
assert len(statements) == 4
|
|
||||||
|
|
||||||
|
|
||||||
def test_kustokqlstatement_with_program() -> None:
|
|
||||||
"""
|
|
||||||
Test the `KustoKQLStatement` split method when the KQL has a program.
|
|
||||||
"""
|
|
||||||
statements = KustoKQLStatement.split_query(
|
|
||||||
"""
|
|
||||||
print program = ```
|
|
||||||
public class Program {
|
|
||||||
public static void Main() {
|
|
||||||
System.Console.WriteLine("Hello!");
|
|
||||||
}
|
|
||||||
}```
|
|
||||||
""",
|
|
||||||
"kustokql",
|
|
||||||
)
|
|
||||||
assert len(statements) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_kustokqlstatement_with_set() -> None:
|
|
||||||
"""
|
|
||||||
Test the `KustoKQLStatement` split method when the KQL has a set command.
|
|
||||||
"""
|
|
||||||
statements = KustoKQLStatement.split_query(
|
|
||||||
"""
|
|
||||||
set querytrace;
|
|
||||||
Events | take 100
|
|
||||||
""",
|
|
||||||
"kustokql",
|
|
||||||
)
|
|
||||||
assert len(statements) == 2
|
|
||||||
assert statements[0].format() == "set querytrace"
|
|
||||||
assert statements[1].format() == "Events | take 100"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"kql,statements",
|
|
||||||
[
|
|
||||||
('print banner=strcat("Hello", ", ", "World!")', 1),
|
|
||||||
(r"print 'O\'Malley\'s'", 1),
|
|
||||||
(r"print 'O\'Mal;ley\'s'", 1),
|
|
||||||
("print ```foo;\nbar;\nbaz;```\n", 1),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
|
|
||||||
assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements
|
|
||||||
|
|
||||||
|
|
||||||
def test_split_kql() -> None:
|
|
||||||
"""
|
|
||||||
Test the `split_kql` function.
|
|
||||||
"""
|
|
||||||
kql = """
|
|
||||||
let totalPagesPerDay = PageViews
|
|
||||||
| summarize by Page, Day = startofday(Timestamp)
|
|
||||||
| summarize count() by Day;
|
|
||||||
let materializedScope = PageViews
|
|
||||||
| summarize by Page, Day = startofday(Timestamp);
|
|
||||||
let cachedResult = materialize(materializedScope);
|
|
||||||
cachedResult
|
|
||||||
| project Page, Day1 = Day
|
|
||||||
| join kind = inner
|
|
||||||
(
|
|
||||||
cachedResult
|
|
||||||
| project Page, Day2 = Day
|
|
||||||
)
|
|
||||||
on Page
|
|
||||||
| where Day2 > Day1
|
|
||||||
| summarize count() by Day1, Day2
|
|
||||||
| join kind = inner
|
|
||||||
totalPagesPerDay
|
|
||||||
on $left.Day1 == $right.Day
|
|
||||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
|
||||||
"""
|
|
||||||
assert split_kql(kql) == [
|
|
||||||
"""
|
|
||||||
let totalPagesPerDay = PageViews
|
|
||||||
| summarize by Page, Day = startofday(Timestamp)
|
|
||||||
| summarize count() by Day""",
|
|
||||||
"""
|
|
||||||
let materializedScope = PageViews
|
|
||||||
| summarize by Page, Day = startofday(Timestamp)""",
|
|
||||||
"""
|
|
||||||
let cachedResult = materialize(materializedScope)""",
|
|
||||||
"""
|
|
||||||
cachedResult
|
|
||||||
| project Page, Day1 = Day
|
|
||||||
| join kind = inner
|
|
||||||
(
|
|
||||||
cachedResult
|
|
||||||
| project Page, Day2 = Day
|
|
||||||
)
|
|
||||||
on Page
|
|
||||||
| where Day2 > Day1
|
|
||||||
| summarize count() by Day1, Day2
|
|
||||||
| join kind = inner
|
|
||||||
totalPagesPerDay
|
|
||||||
on $left.Day1 == $right.Day
|
|
||||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
|
||||||
""",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("engine", "sql", "expected"),
|
|
||||||
[
|
|
||||||
# SQLite tests
|
|
||||||
("sqlite", "SELECT 1", False),
|
|
||||||
("sqlite", "INSERT INTO foo VALUES (1)", True),
|
|
||||||
("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True),
|
|
||||||
("sqlite", "DELETE FROM foo WHERE id = 1", True),
|
|
||||||
("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True),
|
|
||||||
("sqlite", "DROP TABLE foo", True),
|
|
||||||
("sqlite", "EXPLAIN SELECT * FROM foo", False),
|
|
||||||
("sqlite", "PRAGMA table_info(foo)", False),
|
|
||||||
("postgresql", "SELECT 1", False),
|
|
||||||
("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True),
|
|
||||||
("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True),
|
|
||||||
("postgresql", "DELETE FROM foo WHERE id = 1", True),
|
|
||||||
("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True),
|
|
||||||
("postgresql", "DROP TABLE foo", True),
|
|
||||||
("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False),
|
|
||||||
("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True),
|
|
||||||
("postgresql", "SHOW search_path", False),
|
|
||||||
("postgresql", "SET search_path TO public", False),
|
|
||||||
(
|
|
||||||
"postgres",
|
|
||||||
"""
|
|
||||||
with source as (
|
|
||||||
select 1 as one
|
|
||||||
)
|
|
||||||
select * from source
|
|
||||||
""",
|
|
||||||
False,
|
|
||||||
),
|
|
||||||
("trino", "SELECT 1", False),
|
|
||||||
("trino", "INSERT INTO foo VALUES (1, 'bar')", True),
|
|
||||||
("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True),
|
|
||||||
("trino", "DELETE FROM foo WHERE id = 1", True),
|
|
||||||
("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True),
|
|
||||||
("trino", "DROP TABLE foo", True),
|
|
||||||
("trino", "EXPLAIN SELECT * FROM foo", False),
|
|
||||||
("trino", "SHOW SCHEMAS", False),
|
|
||||||
("trino", "SET SESSION optimization_level = '3'", False),
|
|
||||||
("kustokql", "tbl | limit 100", False),
|
|
||||||
("kustokql", "let foo = 1; tbl | where bar == foo", False),
|
|
||||||
("kustokql", ".show tables", False),
|
|
||||||
("kustokql", "print 1", False),
|
|
||||||
("kustokql", "set querytrace; Events | take 100", False),
|
|
||||||
("kustokql", ".drop table foo", True),
|
|
||||||
("kustokql", ".set-or-append table foo <| bar", True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
|
|
||||||
"""
|
|
||||||
Test the `has_mutation` method.
|
|
||||||
"""
|
|
||||||
assert SQLScript(sql, engine).has_mutation() == expected
|
|
||||||
|
|||||||
Reference in New Issue
Block a user