mirror of
https://github.com/apache/superset.git
synced 2026-04-28 12:34:23 +00:00
Compare commits
5 Commits
supersetbo
...
remove-mor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
754807b87d | ||
|
|
782f94fe8d | ||
|
|
8e0c00a82e | ||
|
|
c0c8802de9 | ||
|
|
cd3209a600 |
@@ -36,7 +36,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetErrorException
|
||||
from superset.extensions import db
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -70,10 +70,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
|
||||
table.normalize_columns = self._base_model.normalize_columns
|
||||
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
|
||||
table.is_sqllab_view = True
|
||||
table.sql = ParsedQuery(
|
||||
self._base_model.sql,
|
||||
engine=database.db_engine_spec.engine,
|
||||
).stripped()
|
||||
table.sql = self._base_model.sql.strip().strip(";")
|
||||
db.session.add(table)
|
||||
cols = []
|
||||
for config_ in self._base_model.columns:
|
||||
|
||||
@@ -1778,7 +1778,7 @@ GUEST_TOKEN_VALIDATOR_HOOK = None
|
||||
# def DATASET_HEALTH_CHECK(datasource: SqlaTable) -> Optional[str]:
|
||||
# if (
|
||||
# datasource.sql and
|
||||
# len(sql_parse.ParsedQuery(datasource.sql, strip_comments=True).tables) == 1
|
||||
# len(SQLScript(datasource.sql).tables) == 1
|
||||
# ):
|
||||
# return (
|
||||
# "This virtual dataset queries only one table and therefore could be "
|
||||
|
||||
@@ -67,7 +67,7 @@ from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.schema import UniqueConstraint
|
||||
from sqlalchemy.sql import column, ColumnElement, literal_column, table
|
||||
from sqlalchemy.sql.elements import ColumnClause, TextClause
|
||||
from sqlalchemy.sql.expression import Label, TextAsFrom
|
||||
from sqlalchemy.sql.expression import Label
|
||||
from sqlalchemy.sql.selectable import Alias, TableClause
|
||||
|
||||
from superset import app, db, is_feature_enabled, security_manager
|
||||
@@ -104,7 +104,7 @@ from superset.models.helpers import (
|
||||
QueryResult,
|
||||
)
|
||||
from superset.models.slice import Slice
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import (
|
||||
AdhocColumn,
|
||||
AdhocMetric,
|
||||
@@ -1469,34 +1469,13 @@ class SqlaTable(
|
||||
return tbl
|
||||
|
||||
def get_from_clause(
|
||||
self, template_processor: BaseTemplateProcessor | None = None
|
||||
self,
|
||||
template_processor: BaseTemplateProcessor | None = None,
|
||||
) -> tuple[TableClause | Alias, str | None]:
|
||||
"""
|
||||
Return where to select the columns and metrics from. Either a physical table
|
||||
or a virtual table with its own subquery. If the FROM is referencing a
|
||||
CTE, the CTE is returned as the second value in the return tuple.
|
||||
"""
|
||||
if not self.is_virtual:
|
||||
return self.get_sqla_table(), None
|
||||
|
||||
from_sql = self.get_rendered_sql(template_processor) + "\n"
|
||||
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
||||
if not (
|
||||
parsed_query.is_unknown()
|
||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
_("Virtual dataset query must be read-only")
|
||||
)
|
||||
|
||||
cte = self.db_engine_spec.get_cte_query(from_sql)
|
||||
from_clause = (
|
||||
table(self.db_engine_spec.cte_alias)
|
||||
if cte
|
||||
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
|
||||
)
|
||||
|
||||
return from_clause, cte
|
||||
return super().get_from_clause(template_processor)
|
||||
|
||||
def adhoc_metric_to_sqla(
|
||||
self,
|
||||
|
||||
@@ -63,7 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
|
||||
from superset.databases.utils import get_table_metadata, make_url_safe
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
|
||||
from superset.sql.parse import SQLScript, Table
|
||||
from superset.sql.parse import BaseSQLStatement, SQLScript, Table
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
@@ -1737,18 +1737,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def process_statement(cls, statement: str, database: Database) -> str:
|
||||
def process_statement(
|
||||
cls,
|
||||
statement: BaseSQLStatement[Any],
|
||||
database: Database,
|
||||
) -> str:
|
||||
"""
|
||||
Process a SQL statement by stripping and mutating it.
|
||||
Process a SQL statement by mutating it.
|
||||
|
||||
:param statement: A single SQL statement
|
||||
:param database: Database instance
|
||||
:return: The SQL statement after config-based mutation
|
||||
"""
|
||||
parsed_query = ParsedQuery(statement, engine=cls.engine)
|
||||
sql = parsed_query.stripped()
|
||||
|
||||
return database.mutate_sql_based_on_config(sql, is_split=True)
|
||||
return database.mutate_sql_based_on_config(str(statement), is_split=True)
|
||||
|
||||
@classmethod
|
||||
def estimate_query_cost( # pylint: disable=too-many-arguments
|
||||
@@ -1773,8 +1774,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"Database does not support cost estimation"
|
||||
)
|
||||
|
||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||
statements = parsed_query.get_statements()
|
||||
parsed_script = SQLScript(sql, engine=cls.engine)
|
||||
|
||||
with database.get_raw_connection(
|
||||
catalog=catalog,
|
||||
@@ -1788,7 +1788,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
cls.process_statement(statement, database),
|
||||
cursor,
|
||||
)
|
||||
for statement in statements
|
||||
for statement in parsed_script.statements
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@@ -2056,15 +2056,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
logger.error(ex, exc_info=True)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
|
||||
return (
|
||||
parsed_query.is_select()
|
||||
or parsed_query.is_explain()
|
||||
or parsed_query.is_show()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||
"""
|
||||
@@ -2178,10 +2169,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse_sql(cls, sql: str) -> list[str]:
|
||||
return [str(s).strip(" ;") for s in sqlparse.parse(sql)]
|
||||
|
||||
@classmethod
|
||||
def get_impersonation_key(cls, user: User | None) -> Any:
|
||||
"""
|
||||
|
||||
@@ -36,7 +36,6 @@ from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.sql import sqltypes
|
||||
|
||||
from superset import sql_parse
|
||||
from superset.constants import TimeGrain
|
||||
from superset.databases.schemas import encrypted_field_properties, EncryptedString
|
||||
from superset.databases.utils import make_url_safe
|
||||
@@ -44,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
|
||||
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
|
||||
from superset.errors import SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.sql.parse import SQLScript
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import core as utils, json
|
||||
@@ -449,8 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
||||
if not cls.get_allow_cost_estimate(extra):
|
||||
raise SupersetException("Database does not support cost estimation")
|
||||
|
||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||
statements = parsed_query.get_statements()
|
||||
parsed_script = SQLScript(sql, engine=cls.engine)
|
||||
|
||||
with cls.get_engine(
|
||||
database,
|
||||
@@ -463,7 +462,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
||||
cls.process_statement(statement, database),
|
||||
client,
|
||||
)
|
||||
for statement in statements
|
||||
for statement in parsed_script.statements
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -45,7 +45,7 @@ from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.extensions import cache_manager
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -605,15 +605,6 @@ class HiveEngineSpec(PrestoEngineSpec):
|
||||
# otherwise, return no function names to prevent errors
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
|
||||
return (
|
||||
super().is_readonly_query(parsed_query)
|
||||
or parsed_query.is_set()
|
||||
or parsed_query.is_show()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def has_implicit_cancel(cls) -> bool:
|
||||
"""
|
||||
|
||||
@@ -104,11 +104,6 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
|
||||
return parsed_query.sql.lower().startswith("select")
|
||||
|
||||
|
||||
class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
limit_method = LimitMethod.WRAP_SQL
|
||||
@@ -158,23 +153,6 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||
"""
|
||||
Pessimistic readonly, 100% sure statement won't mutate anything.
|
||||
"""
|
||||
return KustoKqlEngineSpec.is_select_query(
|
||||
parsed_query
|
||||
) or parsed_query.sql.startswith(".show")
|
||||
|
||||
@classmethod
|
||||
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||
return not parsed_query.sql.startswith(".")
|
||||
|
||||
@classmethod
|
||||
def parse_sql(cls, sql: str) -> list[str]:
|
||||
"""
|
||||
Kusto supports a single query statement, but it could include sub queries
|
||||
and variables declared via let keyword.
|
||||
"""
|
||||
return [sql]
|
||||
|
||||
@@ -74,6 +74,7 @@ from superset.extensions import (
|
||||
)
|
||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql.parse import SQLScript
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
|
||||
from superset.utils import cache as cache_util, core as utils, json
|
||||
@@ -674,7 +675,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||
schema: str | None = None,
|
||||
mutator: Callable[[pd.DataFrame], None] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
sqls = self.db_engine_spec.parse_sql(sql)
|
||||
parsed_script = SQLScript(sql, engine=self.db_engine_spec.engine)
|
||||
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
|
||||
engine_url = engine.url
|
||||
|
||||
@@ -691,8 +692,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||
with self.get_raw_connection(catalog=catalog, schema=schema) as conn:
|
||||
cursor = conn.cursor()
|
||||
df = None
|
||||
for i, sql_ in enumerate(sqls):
|
||||
sql_ = self.mutate_sql_based_on_config(sql_, is_split=True)
|
||||
for i, statement in enumerate(parsed_script.statements):
|
||||
# pylint: disable=protected-access
|
||||
sql_ = self.mutate_sql_based_on_config(statement._sql, is_split=True)
|
||||
_log_query(sql_)
|
||||
with event_logger.log_context(
|
||||
action="execute_sql",
|
||||
@@ -700,7 +702,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||
object_ref=__name__,
|
||||
):
|
||||
self.db_engine_spec.execute(cursor, sql_, self)
|
||||
if i < len(sqls) - 1:
|
||||
if i < len(parsed_script.statements) - 1:
|
||||
# If it's not the last, we don't keep the results
|
||||
cursor.fetchall()
|
||||
else:
|
||||
|
||||
@@ -72,7 +72,6 @@ from superset.sql.parse import SQLScript
|
||||
from superset.sql_parse import (
|
||||
has_table_query,
|
||||
insert_rls_in_predicate,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
)
|
||||
from superset.superset_typing import (
|
||||
@@ -1039,6 +1038,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
"""
|
||||
Render sql with template engine (Jinja).
|
||||
"""
|
||||
if not self.sql:
|
||||
return ""
|
||||
|
||||
sql = self.sql.strip("\t\r\n; ")
|
||||
if template_processor:
|
||||
try:
|
||||
@@ -1072,13 +1074,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
||||
or a virtual table with its own subquery. If the FROM is referencing a
|
||||
CTE, the CTE is returned as the second value in the return tuple.
|
||||
"""
|
||||
|
||||
from_sql = self.get_rendered_sql(template_processor) + "\n"
|
||||
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
||||
if not (
|
||||
parsed_query.is_unknown()
|
||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
||||
):
|
||||
parsed_script = SQLScript(from_sql, engine=self.db_engine_spec.engine)
|
||||
if parsed_script.has_mutation():
|
||||
raise QueryObjectValidationError(
|
||||
_("Virtual dataset query must be read-only")
|
||||
)
|
||||
|
||||
@@ -20,10 +20,11 @@ from __future__ import annotations
|
||||
import enum
|
||||
import logging
|
||||
import re
|
||||
import string
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic, Iterator, TypeVar
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
@@ -226,6 +227,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_select(self) -> bool:
|
||||
"""
|
||||
Check if the statement is a `SELECT` statement.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.format()
|
||||
|
||||
@@ -382,6 +389,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
|
||||
return False
|
||||
|
||||
def is_select(self) -> bool:
|
||||
"""
|
||||
Check if the statement is a `SELECT` statement.
|
||||
"""
|
||||
return isinstance(self._parsed, exp.Select)
|
||||
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
@@ -431,60 +444,115 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
}
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
class KQLTokenizeState(enum.Enum):
|
||||
"""
|
||||
State machine for splitting a KQL script.
|
||||
State machine for tokenizing a KQL script.
|
||||
|
||||
The state machine keeps track of whether we're inside a string or not, so we
|
||||
don't split the script on a semicolon that's part of a string.
|
||||
"""
|
||||
|
||||
OUTSIDE_STRING = enum.auto()
|
||||
OUTSIDE = enum.auto()
|
||||
INSIDE_SINGLE_QUOTED_STRING = enum.auto()
|
||||
INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
|
||||
INSIDE_MULTILINE_STRING = enum.auto()
|
||||
INSIDE_SINGLE_QUOTED_IDENTIFIER = enum.auto()
|
||||
INSIDE_DOUBLE_QUOTED_IDENTIFIER = enum.auto()
|
||||
|
||||
|
||||
def tokenize_kql(kql: str) -> Iterator[str]:
|
||||
"""
|
||||
Tokenize a KQL script.
|
||||
"""
|
||||
valid_identifier_chars = set(string.ascii_letters + string.digits + "_")
|
||||
valid_quoted_identifier_chars = valid_identifier_chars | set(" .-")
|
||||
|
||||
script = kql if kql.endswith(";") else kql + ";"
|
||||
|
||||
cursor = 0
|
||||
while cursor < len(script):
|
||||
rest = script[cursor:]
|
||||
|
||||
# quoted identifiers
|
||||
if rest[:2] in {"['", '["'}:
|
||||
match = "']" if rest[:2] == "['" else '"]'
|
||||
if match not in rest[2:]:
|
||||
raise SupersetParseError(
|
||||
script,
|
||||
"kustokql",
|
||||
message="Unclosed quoted identifier",
|
||||
)
|
||||
|
||||
token = rest[: rest.index(match, 2) + 2]
|
||||
|
||||
if any(char not in valid_quoted_identifier_chars for char in token[2:-2]):
|
||||
raise SupersetParseError(
|
||||
script,
|
||||
"kustokql",
|
||||
message="Invalid quoted identifier",
|
||||
)
|
||||
|
||||
yield token
|
||||
cursor += len(token)
|
||||
|
||||
# multi-line strings
|
||||
elif rest[:3] == "```":
|
||||
if "```" not in rest[3:]:
|
||||
raise SupersetParseError(
|
||||
script,
|
||||
"kustokql",
|
||||
message="Unclosed multi-line string",
|
||||
)
|
||||
|
||||
token = rest[: rest.index("```", 3) + 3]
|
||||
yield token
|
||||
cursor += len(token)
|
||||
|
||||
# single-line strings (single- or double-quoted)
|
||||
elif rest[0] in {'"', "'"}:
|
||||
match = rest[0]
|
||||
# find first unescaped quote
|
||||
start = 1
|
||||
while True:
|
||||
if match not in rest[start:]:
|
||||
raise SupersetParseError(
|
||||
script,
|
||||
"kustokql",
|
||||
message="Unclosed string",
|
||||
)
|
||||
index = rest.index(match, start)
|
||||
if rest[index - 1] != "\\":
|
||||
break
|
||||
|
||||
start = index + 1
|
||||
|
||||
token = rest[: index + 1]
|
||||
yield token
|
||||
cursor += len(token)
|
||||
|
||||
# identifiers and keywords
|
||||
else:
|
||||
for i, char in enumerate(rest):
|
||||
if char not in valid_identifier_chars:
|
||||
if i > 0:
|
||||
yield rest[:i]
|
||||
yield char
|
||||
cursor += i + 1
|
||||
break
|
||||
|
||||
|
||||
def split_kql(kql: str) -> list[str]:
|
||||
"""
|
||||
Custom function for splitting KQL statements.
|
||||
"""
|
||||
statements = []
|
||||
state = KQLSplitState.OUTSIDE_STRING
|
||||
statement_start = 0
|
||||
script = kql if kql.endswith(";") else kql + ";"
|
||||
for i, character in enumerate(script):
|
||||
if state == KQLSplitState.OUTSIDE_STRING:
|
||||
if character == ";":
|
||||
statements.append(script[statement_start:i])
|
||||
statement_start = i + 1
|
||||
elif character == "'":
|
||||
state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
|
||||
elif character == '"':
|
||||
state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
|
||||
elif character == "`" and script[i - 2 : i] == "``":
|
||||
state = KQLSplitState.INSIDE_MULTILINE_STRING
|
||||
|
||||
elif (
|
||||
state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
|
||||
and character == "'"
|
||||
and script[i - 1] != "\\"
|
||||
):
|
||||
state = KQLSplitState.OUTSIDE_STRING
|
||||
|
||||
elif (
|
||||
state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
|
||||
and character == '"'
|
||||
and script[i - 1] != "\\"
|
||||
):
|
||||
state = KQLSplitState.OUTSIDE_STRING
|
||||
|
||||
elif (
|
||||
state == KQLSplitState.INSIDE_MULTILINE_STRING
|
||||
and character == "`"
|
||||
and script[i - 2 : i] == "``"
|
||||
):
|
||||
state = KQLSplitState.OUTSIDE_STRING
|
||||
statements: list[str] = []
|
||||
statement: list[str] = []
|
||||
for token in tokenize_kql(kql):
|
||||
if token == ";":
|
||||
statements.append("".join(statement))
|
||||
statement = []
|
||||
else:
|
||||
statement.append(token)
|
||||
|
||||
return statements
|
||||
|
||||
@@ -506,6 +574,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
details about it.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
statement: str,
|
||||
engine: str = "kustokql",
|
||||
ast: str | None = None,
|
||||
):
|
||||
super().__init__(statement, engine, ast)
|
||||
|
||||
@classmethod
|
||||
def split_script(
|
||||
cls,
|
||||
@@ -588,6 +664,56 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
"""
|
||||
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
||||
|
||||
def is_select(self) -> bool:
|
||||
"""
|
||||
Check if the statement is a `SELECT` statement.
|
||||
"""
|
||||
if not self._parsed or self.is_mutating():
|
||||
return False
|
||||
|
||||
# strip comments
|
||||
kql = "\n".join(
|
||||
line
|
||||
for line in self._parsed.split("\n")
|
||||
if not line.strip().startswith("//")
|
||||
).strip()
|
||||
|
||||
first_token = next(tokenize_kql(kql), None)
|
||||
if not first_token:
|
||||
return False
|
||||
|
||||
return first_token == "|" or self._is_identifier(first_token)
|
||||
|
||||
@staticmethod
|
||||
def _is_identifier(identifier: str) -> bool:
|
||||
"""
|
||||
Validates if a given string is a valid KQL identifier.
|
||||
|
||||
From the documentation:
|
||||
|
||||
Identifiers are case-sensitive. Database names are case-insensitive, and
|
||||
therefore an exception to this rule.
|
||||
Identifiers must be between 1 and 1024 characters long.
|
||||
Identifiers may contain letters, digits, and underscores (_).
|
||||
Identifiers may contain certain special characters: spaces, dots (.), and
|
||||
dashes (-). For information on how to reference identifiers with special
|
||||
characters, see Reference identifiers in queries.
|
||||
|
||||
"""
|
||||
valid_chars = set(string.ascii_letters + string.digits + "_")
|
||||
|
||||
# Identifier names that (1) include special characters, (2) are language
|
||||
# keywords, or (3) are literals must be enclosed using [' and '] or [" and "].
|
||||
if (identifier.startswith("['") and identifier.endswith("']")) or (
|
||||
identifier.startswith('["') and identifier.endswith('"]')
|
||||
):
|
||||
identifier = identifier[2:-2]
|
||||
valid_chars.update(" .-")
|
||||
|
||||
return 1 <= len(identifier) <= 1024 and all(
|
||||
char in valid_chars for char in identifier
|
||||
)
|
||||
|
||||
|
||||
class SQLScript:
|
||||
"""
|
||||
@@ -642,6 +768,24 @@ class SQLScript:
|
||||
"""
|
||||
return any(statement.is_mutating() for statement in self.statements)
|
||||
|
||||
def is_valid_ctas(self) -> bool:
|
||||
"""
|
||||
Check if the script contains a valid CTAS statement.
|
||||
|
||||
CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last
|
||||
statement is a `SELECT`.
|
||||
"""
|
||||
return self.statements[-1].is_select()
|
||||
|
||||
def is_valid_cvas(self) -> bool:
|
||||
"""
|
||||
Check if the script contains a valid CVAS statement.
|
||||
|
||||
CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single
|
||||
`SELECT` statement.
|
||||
"""
|
||||
return len(self.statements) == 1 and self.statements[0].is_select()
|
||||
|
||||
|
||||
def extract_tables_from_statement(
|
||||
statement: exp.Expression,
|
||||
@@ -650,7 +794,7 @@ def extract_tables_from_statement(
|
||||
"""
|
||||
Extract all table references in a single statement.
|
||||
|
||||
Please not that this is not trivial; consider the following queries:
|
||||
Please note that this is not trivial; consider the following queries:
|
||||
|
||||
DESCRIBE some_table;
|
||||
SHOW PARTITIONS FROM some_table;
|
||||
|
||||
@@ -20,11 +20,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
from contextlib import closing
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from superset import app
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sql.parse import SQLScript, SQLStatement
|
||||
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
||||
from superset.utils.core import QuerySource
|
||||
|
||||
@@ -46,17 +46,15 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
@classmethod
|
||||
def validate_statement(
|
||||
cls,
|
||||
statement: str,
|
||||
statement: SQLStatement,
|
||||
database: Database,
|
||||
cursor: Any,
|
||||
) -> SQLValidationAnnotation | None:
|
||||
# pylint: disable=too-many-locals
|
||||
db_engine_spec = database.db_engine_spec
|
||||
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
|
||||
sql = parsed_query.stripped()
|
||||
|
||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||
sql = database.mutate_sql_based_on_config(sql)
|
||||
sql = database.mutate_sql_based_on_config(str(statement))
|
||||
|
||||
# Transform the final statement to an explain call before sending it on
|
||||
# to presto to validate
|
||||
@@ -155,10 +153,9 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
||||
VALIDATE) SELECT 1 FROM default.mytable".
|
||||
"""
|
||||
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
|
||||
statements = parsed_query.get_statements()
|
||||
parsed_script = SQLScript(sql, engine=database.db_engine_spec.engine)
|
||||
|
||||
logger.info("Validating %i statement(s)", len(statements))
|
||||
logger.info("Validating %i statement(s)", len(parsed_script.statements))
|
||||
# todo(hughhh): update this to use new database.get_raw_connection()
|
||||
# this function keeps stalling CI
|
||||
with database.get_sqla_engine(
|
||||
@@ -171,8 +168,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||
annotations: list[SQLValidationAnnotation] = []
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
cursor = conn.cursor()
|
||||
for statement in parsed_query.get_statements():
|
||||
annotation = cls.validate_statement(statement, database, cursor)
|
||||
for statement in parsed_script.statements:
|
||||
annotation = cls.validate_statement(
|
||||
cast(SQLStatement, statement),
|
||||
database,
|
||||
cursor,
|
||||
)
|
||||
if annotation:
|
||||
annotations.append(annotation)
|
||||
logger.debug("Validation found %i error(s)", len(annotations))
|
||||
|
||||
@@ -26,7 +26,6 @@ from jinja2.meta import find_undeclared_variables
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.sql_lab.execute import SqlQueryRender
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sqllab.exceptions import SqlLabException
|
||||
from superset.utils import core as utils
|
||||
|
||||
@@ -58,12 +57,9 @@ class SqlQueryRenderImpl(SqlQueryRender):
|
||||
database=query_model.database, query=query_model
|
||||
)
|
||||
|
||||
parsed_query = ParsedQuery(
|
||||
query_model.sql,
|
||||
engine=query_model.database.db_engine_spec.engine,
|
||||
)
|
||||
rendered_query = sql_template_processor.process_template(
|
||||
parsed_query.stripped(), **execution_context.template_params
|
||||
query_model.sql.strip().strip(";"),
|
||||
**execution_context.template_params,
|
||||
)
|
||||
self._validate(execution_context, rendered_query, sql_template_processor)
|
||||
return rendered_query
|
||||
|
||||
@@ -30,7 +30,7 @@ from superset.db_engine_specs.base import (
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
from tests.integration_tests.test_app import app
|
||||
@@ -310,20 +310,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
||||
)
|
||||
|
||||
|
||||
def test_is_readonly():
|
||||
def is_readonly(sql: str) -> bool:
|
||||
return BaseEngineSpec.is_readonly_query(ParsedQuery(sql))
|
||||
|
||||
assert is_readonly("SHOW LOCKS test EXTENDED")
|
||||
assert not is_readonly("SET hivevar:desc='Legislators'")
|
||||
assert not is_readonly("UPDATE t1 SET col1 = NULL")
|
||||
assert is_readonly("EXPLAIN SELECT 1")
|
||||
assert is_readonly("SELECT 1")
|
||||
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
|
||||
assert is_readonly("SHOW CATALOGS")
|
||||
assert is_readonly("SHOW TABLES")
|
||||
|
||||
|
||||
def test_time_grain_denylist():
|
||||
config = app.config.copy()
|
||||
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M", "SQLITE_NONEXISTENT_GRAIN"]
|
||||
|
||||
@@ -23,7 +23,7 @@ from sqlalchemy.sql import select
|
||||
|
||||
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_parse import Table
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
@@ -222,19 +222,6 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
||||
app.config = config
|
||||
|
||||
|
||||
def test_is_readonly():
|
||||
def is_readonly(sql: str) -> bool:
|
||||
return HiveEngineSpec.is_readonly_query(ParsedQuery(sql))
|
||||
|
||||
assert not is_readonly("UPDATE t1 SET col1 = NULL")
|
||||
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
|
||||
assert is_readonly("SHOW LOCKS test EXTENDED")
|
||||
assert is_readonly("SET hivevar:desc='Legislators'")
|
||||
assert is_readonly("EXPLAIN SELECT 1")
|
||||
assert is_readonly("SELECT 1")
|
||||
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema,upload_prefix",
|
||||
[("foo", "EXTERNAL_HIVE_TABLES/1/foo/"), (None, "EXTERNAL_HIVE_TABLES/1/")],
|
||||
|
||||
@@ -25,7 +25,7 @@ from sqlalchemy.sql import select
|
||||
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
@@ -1172,19 +1172,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
]
|
||||
|
||||
|
||||
def test_is_readonly():
|
||||
def is_readonly(sql: str) -> bool:
|
||||
return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))
|
||||
|
||||
assert not is_readonly("SET hivevar:desc='Legislators'")
|
||||
assert not is_readonly("UPDATE t1 SET col1 = NULL")
|
||||
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
|
||||
assert is_readonly("SHOW LOCKS test EXTENDED")
|
||||
assert is_readonly("EXPLAIN SELECT 1")
|
||||
assert is_readonly("SELECT 1")
|
||||
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
|
||||
|
||||
|
||||
def test_get_catalog_names(app_context: AppContext) -> None:
|
||||
"""
|
||||
Test the ``get_catalog_names`` method.
|
||||
|
||||
@@ -49,32 +49,6 @@ def test_get_text_clause_with_colon() -> None:
|
||||
assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')"
|
||||
|
||||
|
||||
def test_parse_sql_single_statement() -> None:
|
||||
"""
|
||||
`parse_sql` should properly strip leading and trailing spaces and semicolons
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
queries = BaseEngineSpec.parse_sql(" SELECT foo FROM tbl ; ")
|
||||
assert queries == ["SELECT foo FROM tbl"]
|
||||
|
||||
|
||||
def test_parse_sql_multi_statement() -> None:
|
||||
"""
|
||||
For string with multiple SQL-statements `parse_sql` method should return list
|
||||
where each element represents the single SQL-statement
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
queries = BaseEngineSpec.parse_sql("SELECT foo FROM tbl1; SELECT bar FROM tbl2;")
|
||||
assert queries == [
|
||||
"SELECT foo FROM tbl1",
|
||||
"SELECT bar FROM tbl2",
|
||||
]
|
||||
|
||||
|
||||
def test_validate_db_uri(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Ensures that the `validate_database_uri` method invokes the validator correctly
|
||||
|
||||
@@ -24,91 +24,6 @@ from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
||||
from tests.unit_tests.fixtures.common import dttm # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,expected",
|
||||
[
|
||||
("SELECT foo FROM tbl", True),
|
||||
("SHOW TABLES", False),
|
||||
("EXPLAIN SELECT foo FROM tbl", False),
|
||||
("INSERT INTO tbl (foo) VALUES (1)", False),
|
||||
],
|
||||
)
|
||||
def test_sql_is_readonly_query(sql: str, expected: bool) -> None:
|
||||
"""
|
||||
Make sure that SQL dialect consider only SELECT statements as read-only
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.kusto import KustoSqlEngineSpec
|
||||
from superset.sql_parse import ParsedQuery
|
||||
|
||||
parsed_query = ParsedQuery(sql)
|
||||
is_readonly = KustoSqlEngineSpec.is_readonly_query(parsed_query)
|
||||
|
||||
assert expected == is_readonly
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kql,expected",
|
||||
[
|
||||
("tbl | limit 100", True),
|
||||
("let foo = 1; tbl | where bar == foo", True),
|
||||
(".show tables", False),
|
||||
],
|
||||
)
|
||||
def test_kql_is_select_query(kql: str, expected: bool) -> None:
|
||||
"""
|
||||
Make sure that KQL dialect consider only statements that do not start with "." (dot)
|
||||
as a SELECT statements
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
|
||||
from superset.sql_parse import ParsedQuery
|
||||
|
||||
parsed_query = ParsedQuery(kql)
|
||||
is_select = KustoKqlEngineSpec.is_select_query(parsed_query)
|
||||
|
||||
assert expected == is_select
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kql,expected",
|
||||
[
|
||||
("tbl | limit 100", True),
|
||||
("let foo = 1; tbl | where bar == foo", True),
|
||||
(".show tables", True),
|
||||
("print 1", True),
|
||||
("set querytrace; Events | take 100", True),
|
||||
(".drop table foo", False),
|
||||
(".set-or-append table foo <| bar", False),
|
||||
],
|
||||
)
|
||||
def test_kql_is_readonly_query(kql: str, expected: bool) -> None:
|
||||
"""
|
||||
Make sure that KQL dialect consider only SELECT statements as read-only
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
|
||||
from superset.sql_parse import ParsedQuery
|
||||
|
||||
parsed_query = ParsedQuery(kql)
|
||||
is_readonly = KustoKqlEngineSpec.is_readonly_query(parsed_query)
|
||||
|
||||
assert expected == is_readonly
|
||||
|
||||
|
||||
def test_kql_parse_sql() -> None:
|
||||
"""
|
||||
parse_sql method should always return a list with a single element
|
||||
which is an original query
|
||||
"""
|
||||
|
||||
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
|
||||
|
||||
queries = KustoKqlEngineSpec.parse_sql("let foo = 1; tbl | where bar == foo")
|
||||
|
||||
assert queries == ["let foo = 1; tbl | where bar == foo"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"target_type,expected_result",
|
||||
[
|
||||
|
||||
@@ -945,6 +945,32 @@ on $left.Day1 == $right.Day
|
||||
("kustokql", "set querytrace; Events | take 100", False),
|
||||
("kustokql", ".drop table foo", True),
|
||||
("kustokql", ".set-or-append table foo <| bar", True),
|
||||
("kustosql", "SELECT foo FROM tbl", False),
|
||||
("kustosql", "SHOW TABLES", False),
|
||||
("kustosql", "EXPLAIN SELECT foo FROM tbl", False),
|
||||
("kustosql", "INSERT INTO tbl (foo) VALUES (1)", True),
|
||||
("base", "SHOW LOCKS test EXTENDED", False),
|
||||
("base", "SET hivevar:desc='Legislators'", False),
|
||||
("base", "UPDATE t1 SET col1 = NULL", True),
|
||||
("base", "EXPLAIN SELECT 1", False),
|
||||
("base", "SELECT 1", False),
|
||||
("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
||||
("base", "SHOW CATALOGS", False),
|
||||
("base", "SHOW TABLES", False),
|
||||
("hive", "UPDATE t1 SET col1 = NULL", True),
|
||||
("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
|
||||
("hive", "SHOW LOCKS test EXTENDED", False),
|
||||
("hive", "SET hivevar:desc='Legislators'", False),
|
||||
("hive", "EXPLAIN SELECT 1", False),
|
||||
("hive", "SELECT 1", False),
|
||||
("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
||||
("presto", "SET hivevar:desc='Legislators'", False),
|
||||
("presto", "UPDATE t1 SET col1 = NULL", True),
|
||||
("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
|
||||
("presto", "SHOW LOCKS test EXTENDED", False),
|
||||
("presto", "EXPLAIN SELECT 1", False),
|
||||
("presto", "SELECT 1", False),
|
||||
("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
||||
],
|
||||
)
|
||||
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
|
||||
@@ -1042,9 +1068,106 @@ def test_custom_dialect(app: None) -> None:
|
||||
)
|
||||
def test_is_mutating(engine: str) -> None:
|
||||
"""
|
||||
Tests for `is_mutating`.
|
||||
Global tests for `is_mutating`, covering all supported engines.
|
||||
"""
|
||||
assert not SQLStatement(
|
||||
"with source as ( select 1 as one ) select * from source",
|
||||
engine=engine,
|
||||
).is_mutating()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"identifier,expected",
|
||||
[
|
||||
# Rule: Identifiers are case-sensitive
|
||||
("myTable", True),
|
||||
("MYTABLE", True),
|
||||
("MyTable", True),
|
||||
# Rule: Identifiers must be between 1 and 1024 characters long
|
||||
("a", True),
|
||||
("a" * 1024, True),
|
||||
("a" * 1025, False),
|
||||
("", False),
|
||||
# Rule: Identifiers may contain letters, digits, and underscores
|
||||
("My_Table_123", True),
|
||||
("123Table", True),
|
||||
# Rule: Identifiers may contain special characters: spaces, dots, dashes (when quoted)
|
||||
("['My Table']", True),
|
||||
("['My-Table']", True),
|
||||
("['My.Table']", True),
|
||||
("['Table-']", True),
|
||||
("['My Table Name']", True),
|
||||
("['My!Table']", False),
|
||||
("['MyTable ']", True),
|
||||
(" MyTable", False),
|
||||
("MyTable ", False),
|
||||
# Rule: Non-special identifiers don't require quoting
|
||||
("MyTable", True),
|
||||
("My-Table", False),
|
||||
# Invalid quoting
|
||||
("['Invalid]", False),
|
||||
("['Invalid'Name']", False),
|
||||
("['']", False),
|
||||
# Rule: Literal identifiers or language keywords
|
||||
("['select']", True),
|
||||
("select", True),
|
||||
],
|
||||
)
|
||||
def test_is_kql_identifier(identifier: str, expected: bool):
|
||||
"""
|
||||
Tests the _is_identifier method for various valid and invalid cases.
|
||||
"""
|
||||
assert KustoKQLStatement._is_identifier(identifier) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kql,expected",
|
||||
[
|
||||
# Simple SELECT-like statements (non-mutating queries)
|
||||
("MyTable | count", True),
|
||||
("MyTable", True),
|
||||
("| count", True),
|
||||
("tbl | limit 100", True),
|
||||
(".show tables", False),
|
||||
# With comments (ensure comments are stripped out)
|
||||
("// Comment only", False),
|
||||
("MyTable // trailing comment", True),
|
||||
("// leading comment\nMyTable", True),
|
||||
("MyTable\n// intermediate comment\n| count", True),
|
||||
# Mutating query (should return False)
|
||||
(".drop MyTable", False),
|
||||
(
|
||||
".update MyTable set Column1 = 100",
|
||||
False,
|
||||
),
|
||||
(".alter MyTable", False),
|
||||
# Edge cases for first token
|
||||
("", False),
|
||||
(" ", False),
|
||||
(".command", False),
|
||||
("['My Table']", True),
|
||||
# Complex multi-line queries
|
||||
(
|
||||
"""
|
||||
// Initial comment
|
||||
MyTable
|
||||
| where Column1 > 100
|
||||
""",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"""
|
||||
MyTable
|
||||
| where Column1 > 100
|
||||
| summarize by Column2
|
||||
// Final comment
|
||||
""",
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_kql_is_select(kql: str, expected: bool):
|
||||
"""
|
||||
Tests the is_select method for various valid and invalid cases.
|
||||
"""
|
||||
assert KustoKQLStatement(kql).is_select() == expected
|
||||
|
||||
Reference in New Issue
Block a user