mirror of
https://github.com/apache/superset.git
synced 2026-04-30 13:34:20 +00:00
Compare commits
5 Commits
fix-interm
...
remove-mor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
754807b87d | ||
|
|
782f94fe8d | ||
|
|
8e0c00a82e | ||
|
|
c0c8802de9 | ||
|
|
cd3209a600 |
@@ -36,7 +36,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
|||||||
from superset.exceptions import SupersetErrorException
|
from superset.exceptions import SupersetErrorException
|
||||||
from superset.extensions import db
|
from superset.extensions import db
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.sql_parse import ParsedQuery, Table
|
from superset.sql_parse import Table
|
||||||
from superset.utils.decorators import on_error, transaction
|
from superset.utils.decorators import on_error, transaction
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -70,10 +70,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
|
|||||||
table.normalize_columns = self._base_model.normalize_columns
|
table.normalize_columns = self._base_model.normalize_columns
|
||||||
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
|
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
|
||||||
table.is_sqllab_view = True
|
table.is_sqllab_view = True
|
||||||
table.sql = ParsedQuery(
|
table.sql = self._base_model.sql.strip().strip(";")
|
||||||
self._base_model.sql,
|
|
||||||
engine=database.db_engine_spec.engine,
|
|
||||||
).stripped()
|
|
||||||
db.session.add(table)
|
db.session.add(table)
|
||||||
cols = []
|
cols = []
|
||||||
for config_ in self._base_model.columns:
|
for config_ in self._base_model.columns:
|
||||||
|
|||||||
@@ -1778,7 +1778,7 @@ GUEST_TOKEN_VALIDATOR_HOOK = None
|
|||||||
# def DATASET_HEALTH_CHECK(datasource: SqlaTable) -> Optional[str]:
|
# def DATASET_HEALTH_CHECK(datasource: SqlaTable) -> Optional[str]:
|
||||||
# if (
|
# if (
|
||||||
# datasource.sql and
|
# datasource.sql and
|
||||||
# len(sql_parse.ParsedQuery(datasource.sql, strip_comments=True).tables) == 1
|
# len(SQLScript(datasource.sql).tables) == 1
|
||||||
# ):
|
# ):
|
||||||
# return (
|
# return (
|
||||||
# "This virtual dataset queries only one table and therefore could be "
|
# "This virtual dataset queries only one table and therefore could be "
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ from sqlalchemy.orm.mapper import Mapper
|
|||||||
from sqlalchemy.schema import UniqueConstraint
|
from sqlalchemy.schema import UniqueConstraint
|
||||||
from sqlalchemy.sql import column, ColumnElement, literal_column, table
|
from sqlalchemy.sql import column, ColumnElement, literal_column, table
|
||||||
from sqlalchemy.sql.elements import ColumnClause, TextClause
|
from sqlalchemy.sql.elements import ColumnClause, TextClause
|
||||||
from sqlalchemy.sql.expression import Label, TextAsFrom
|
from sqlalchemy.sql.expression import Label
|
||||||
from sqlalchemy.sql.selectable import Alias, TableClause
|
from sqlalchemy.sql.selectable import Alias, TableClause
|
||||||
|
|
||||||
from superset import app, db, is_feature_enabled, security_manager
|
from superset import app, db, is_feature_enabled, security_manager
|
||||||
@@ -104,7 +104,7 @@ from superset.models.helpers import (
|
|||||||
QueryResult,
|
QueryResult,
|
||||||
)
|
)
|
||||||
from superset.models.slice import Slice
|
from superset.models.slice import Slice
|
||||||
from superset.sql_parse import ParsedQuery, Table
|
from superset.sql_parse import Table
|
||||||
from superset.superset_typing import (
|
from superset.superset_typing import (
|
||||||
AdhocColumn,
|
AdhocColumn,
|
||||||
AdhocMetric,
|
AdhocMetric,
|
||||||
@@ -1469,34 +1469,13 @@ class SqlaTable(
|
|||||||
return tbl
|
return tbl
|
||||||
|
|
||||||
def get_from_clause(
|
def get_from_clause(
|
||||||
self, template_processor: BaseTemplateProcessor | None = None
|
self,
|
||||||
|
template_processor: BaseTemplateProcessor | None = None,
|
||||||
) -> tuple[TableClause | Alias, str | None]:
|
) -> tuple[TableClause | Alias, str | None]:
|
||||||
"""
|
|
||||||
Return where to select the columns and metrics from. Either a physical table
|
|
||||||
or a virtual table with it's own subquery. If the FROM is referencing a
|
|
||||||
CTE, the CTE is returned as the second value in the return tuple.
|
|
||||||
"""
|
|
||||||
if not self.is_virtual:
|
if not self.is_virtual:
|
||||||
return self.get_sqla_table(), None
|
return self.get_sqla_table(), None
|
||||||
|
|
||||||
from_sql = self.get_rendered_sql(template_processor) + "\n"
|
return super().get_from_clause(template_processor)
|
||||||
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
|
||||||
if not (
|
|
||||||
parsed_query.is_unknown()
|
|
||||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
|
||||||
):
|
|
||||||
raise QueryObjectValidationError(
|
|
||||||
_("Virtual dataset query must be read-only")
|
|
||||||
)
|
|
||||||
|
|
||||||
cte = self.db_engine_spec.get_cte_query(from_sql)
|
|
||||||
from_clause = (
|
|
||||||
table(self.db_engine_spec.cte_alias)
|
|
||||||
if cte
|
|
||||||
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
|
|
||||||
)
|
|
||||||
|
|
||||||
return from_clause, cte
|
|
||||||
|
|
||||||
def adhoc_metric_to_sqla(
|
def adhoc_metric_to_sqla(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
|
|||||||
from superset.databases.utils import get_table_metadata, make_url_safe
|
from superset.databases.utils import get_table_metadata, make_url_safe
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
|
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
|
||||||
from superset.sql.parse import SQLScript, Table
|
from superset.sql.parse import BaseSQLStatement, SQLScript, Table
|
||||||
from superset.sql_parse import ParsedQuery
|
from superset.sql_parse import ParsedQuery
|
||||||
from superset.superset_typing import (
|
from superset.superset_typing import (
|
||||||
OAuth2ClientConfig,
|
OAuth2ClientConfig,
|
||||||
@@ -1737,18 +1737,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def process_statement(cls, statement: str, database: Database) -> str:
|
def process_statement(
|
||||||
|
cls,
|
||||||
|
statement: BaseSQLStatement[Any],
|
||||||
|
database: Database,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Process a SQL statement by stripping and mutating it.
|
Process a SQL statement by mutating it.
|
||||||
|
|
||||||
:param statement: A single SQL statement
|
:param statement: A single SQL statement
|
||||||
:param database: Database instance
|
:param database: Database instance
|
||||||
:return: Dictionary with different costs
|
:return: Dictionary with different costs
|
||||||
"""
|
"""
|
||||||
parsed_query = ParsedQuery(statement, engine=cls.engine)
|
return database.mutate_sql_based_on_config(str(statement), is_split=True)
|
||||||
sql = parsed_query.stripped()
|
|
||||||
|
|
||||||
return database.mutate_sql_based_on_config(sql, is_split=True)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def estimate_query_cost( # pylint: disable=too-many-arguments
|
def estimate_query_cost( # pylint: disable=too-many-arguments
|
||||||
@@ -1773,8 +1774,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
"Database does not support cost estimation"
|
"Database does not support cost estimation"
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
parsed_script = SQLScript(sql, engine=cls.engine)
|
||||||
statements = parsed_query.get_statements()
|
|
||||||
|
|
||||||
with database.get_raw_connection(
|
with database.get_raw_connection(
|
||||||
catalog=catalog,
|
catalog=catalog,
|
||||||
@@ -1788,7 +1788,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
cls.process_statement(statement, database),
|
cls.process_statement(statement, database),
|
||||||
cursor,
|
cursor,
|
||||||
)
|
)
|
||||||
for statement in statements
|
for statement in parsed_script.statements
|
||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -2056,15 +2056,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
logger.error(ex, exc_info=True)
|
logger.error(ex, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
|
||||||
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
|
|
||||||
return (
|
|
||||||
parsed_query.is_select()
|
|
||||||
or parsed_query.is_explain()
|
|
||||||
or parsed_query.is_show()
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
|
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -2178,10 +2169,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse_sql(cls, sql: str) -> list[str]:
|
|
||||||
return [str(s).strip(" ;") for s in sqlparse.parse(sql)]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_impersonation_key(cls, user: User | None) -> Any:
|
def get_impersonation_key(cls, user: User | None) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ from sqlalchemy.engine.reflection import Inspector
|
|||||||
from sqlalchemy.engine.url import URL
|
from sqlalchemy.engine.url import URL
|
||||||
from sqlalchemy.sql import sqltypes
|
from sqlalchemy.sql import sqltypes
|
||||||
|
|
||||||
from superset import sql_parse
|
|
||||||
from superset.constants import TimeGrain
|
from superset.constants import TimeGrain
|
||||||
from superset.databases.schemas import encrypted_field_properties, EncryptedString
|
from superset.databases.schemas import encrypted_field_properties, EncryptedString
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
@@ -44,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
|
|||||||
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
|
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
|
||||||
from superset.errors import SupersetError, SupersetErrorType
|
from superset.errors import SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import SupersetException
|
from superset.exceptions import SupersetException
|
||||||
|
from superset.sql.parse import SQLScript
|
||||||
from superset.sql_parse import Table
|
from superset.sql_parse import Table
|
||||||
from superset.superset_typing import ResultSetColumnType
|
from superset.superset_typing import ResultSetColumnType
|
||||||
from superset.utils import core as utils, json
|
from superset.utils import core as utils, json
|
||||||
@@ -449,8 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||||||
if not cls.get_allow_cost_estimate(extra):
|
if not cls.get_allow_cost_estimate(extra):
|
||||||
raise SupersetException("Database does not support cost estimation")
|
raise SupersetException("Database does not support cost estimation")
|
||||||
|
|
||||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
parsed_script = SQLScript(sql, engine=cls.engine)
|
||||||
statements = parsed_query.get_statements()
|
|
||||||
|
|
||||||
with cls.get_engine(
|
with cls.get_engine(
|
||||||
database,
|
database,
|
||||||
@@ -463,7 +462,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||||||
cls.process_statement(statement, database),
|
cls.process_statement(statement, database),
|
||||||
client,
|
client,
|
||||||
)
|
)
|
||||||
for statement in statements
|
for statement in parsed_script.statements
|
||||||
]
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ from superset.db_engine_specs.presto import PrestoEngineSpec
|
|||||||
from superset.exceptions import SupersetException
|
from superset.exceptions import SupersetException
|
||||||
from superset.extensions import cache_manager
|
from superset.extensions import cache_manager
|
||||||
from superset.models.sql_lab import Query
|
from superset.models.sql_lab import Query
|
||||||
from superset.sql_parse import ParsedQuery, Table
|
from superset.sql_parse import Table
|
||||||
from superset.superset_typing import ResultSetColumnType
|
from superset.superset_typing import ResultSetColumnType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -605,15 +605,6 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||||||
# otherwise, return no function names to prevent errors
|
# otherwise, return no function names to prevent errors
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
|
||||||
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
|
|
||||||
return (
|
|
||||||
super().is_readonly_query(parsed_query)
|
|
||||||
or parsed_query.is_set()
|
|
||||||
or parsed_query.is_show()
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def has_implicit_cancel(cls) -> bool:
|
def has_implicit_cancel(cls) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -104,11 +104,6 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||||||
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
|
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
|
||||||
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
|
|
||||||
return parsed_query.sql.lower().startswith("select")
|
|
||||||
|
|
||||||
|
|
||||||
class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||||
limit_method = LimitMethod.WRAP_SQL
|
limit_method = LimitMethod.WRAP_SQL
|
||||||
@@ -158,23 +153,6 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
|
||||||
"""
|
|
||||||
Pessimistic readonly, 100% sure statement won't mutate anything.
|
|
||||||
"""
|
|
||||||
return KustoKqlEngineSpec.is_select_query(
|
|
||||||
parsed_query
|
|
||||||
) or parsed_query.sql.startswith(".show")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
|
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||||
return not parsed_query.sql.startswith(".")
|
return not parsed_query.sql.startswith(".")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse_sql(cls, sql: str) -> list[str]:
|
|
||||||
"""
|
|
||||||
Kusto supports a single query statement, but it could include sub queries
|
|
||||||
and variables declared via let keyword.
|
|
||||||
"""
|
|
||||||
return [sql]
|
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ from superset.extensions import (
|
|||||||
)
|
)
|
||||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
||||||
from superset.result_set import SupersetResultSet
|
from superset.result_set import SupersetResultSet
|
||||||
|
from superset.sql.parse import SQLScript
|
||||||
from superset.sql_parse import Table
|
from superset.sql_parse import Table
|
||||||
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
|
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
|
||||||
from superset.utils import cache as cache_util, core as utils, json
|
from superset.utils import cache as cache_util, core as utils, json
|
||||||
@@ -674,7 +675,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
schema: str | None = None,
|
schema: str | None = None,
|
||||||
mutator: Callable[[pd.DataFrame], None] | None = None,
|
mutator: Callable[[pd.DataFrame], None] | None = None,
|
||||||
) -> pd.DataFrame:
|
) -> pd.DataFrame:
|
||||||
sqls = self.db_engine_spec.parse_sql(sql)
|
parsed_script = SQLScript(sql, engine=self.db_engine_spec.engine)
|
||||||
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
|
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
|
||||||
engine_url = engine.url
|
engine_url = engine.url
|
||||||
|
|
||||||
@@ -691,8 +692,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
with self.get_raw_connection(catalog=catalog, schema=schema) as conn:
|
with self.get_raw_connection(catalog=catalog, schema=schema) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
df = None
|
df = None
|
||||||
for i, sql_ in enumerate(sqls):
|
for i, statement in enumerate(parsed_script.statements):
|
||||||
sql_ = self.mutate_sql_based_on_config(sql_, is_split=True)
|
# pylint: disable=protected-access
|
||||||
|
sql_ = self.mutate_sql_based_on_config(statement._sql, is_split=True)
|
||||||
_log_query(sql_)
|
_log_query(sql_)
|
||||||
with event_logger.log_context(
|
with event_logger.log_context(
|
||||||
action="execute_sql",
|
action="execute_sql",
|
||||||
@@ -700,7 +702,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||||||
object_ref=__name__,
|
object_ref=__name__,
|
||||||
):
|
):
|
||||||
self.db_engine_spec.execute(cursor, sql_, self)
|
self.db_engine_spec.execute(cursor, sql_, self)
|
||||||
if i < len(sqls) - 1:
|
if i < len(parsed_script.statements) - 1:
|
||||||
# If it's not the last, we don't keep the results
|
# If it's not the last, we don't keep the results
|
||||||
cursor.fetchall()
|
cursor.fetchall()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -72,7 +72,6 @@ from superset.sql.parse import SQLScript
|
|||||||
from superset.sql_parse import (
|
from superset.sql_parse import (
|
||||||
has_table_query,
|
has_table_query,
|
||||||
insert_rls_in_predicate,
|
insert_rls_in_predicate,
|
||||||
ParsedQuery,
|
|
||||||
sanitize_clause,
|
sanitize_clause,
|
||||||
)
|
)
|
||||||
from superset.superset_typing import (
|
from superset.superset_typing import (
|
||||||
@@ -1039,6 +1038,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||||||
"""
|
"""
|
||||||
Render sql with template engine (Jinja).
|
Render sql with template engine (Jinja).
|
||||||
"""
|
"""
|
||||||
|
if not self.sql:
|
||||||
|
return ""
|
||||||
|
|
||||||
sql = self.sql.strip("\t\r\n; ")
|
sql = self.sql.strip("\t\r\n; ")
|
||||||
if template_processor:
|
if template_processor:
|
||||||
try:
|
try:
|
||||||
@@ -1072,13 +1074,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||||||
or a virtual table with it's own subquery. If the FROM is referencing a
|
or a virtual table with it's own subquery. If the FROM is referencing a
|
||||||
CTE, the CTE is returned as the second value in the return tuple.
|
CTE, the CTE is returned as the second value in the return tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from_sql = self.get_rendered_sql(template_processor) + "\n"
|
from_sql = self.get_rendered_sql(template_processor) + "\n"
|
||||||
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
parsed_script = SQLScript(from_sql, engine=self.db_engine_spec.engine)
|
||||||
if not (
|
if parsed_script.has_mutation():
|
||||||
parsed_query.is_unknown()
|
|
||||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
|
||||||
):
|
|
||||||
raise QueryObjectValidationError(
|
raise QueryObjectValidationError(
|
||||||
_("Virtual dataset query must be read-only")
|
_("Virtual dataset query must be read-only")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,10 +20,11 @@ from __future__ import annotations
|
|||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import string
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, Iterator, TypeVar
|
||||||
|
|
||||||
import sqlglot
|
import sqlglot
|
||||||
import sqlparse
|
import sqlparse
|
||||||
@@ -226,6 +227,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def is_select(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the statement is a `SELECT` statement.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.format()
|
return self.format()
|
||||||
|
|
||||||
@@ -382,6 +389,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_select(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the statement is a `SELECT` statement.
|
||||||
|
"""
|
||||||
|
return isinstance(self._parsed, exp.Select)
|
||||||
|
|
||||||
def format(self, comments: bool = True) -> str:
|
def format(self, comments: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
Pretty-format the SQL statement.
|
Pretty-format the SQL statement.
|
||||||
@@ -431,60 +444,115 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class KQLSplitState(enum.Enum):
|
class KQLTokenizeState(enum.Enum):
|
||||||
"""
|
"""
|
||||||
State machine for splitting a KQL script.
|
State machine for tokenizing a KQL script.
|
||||||
|
|
||||||
The state machine keeps track of whether we're inside a string or not, so we
|
The state machine keeps track of whether we're inside a string or not, so we
|
||||||
don't split the script in a semi-colon that's part of a string.
|
don't split the script in a semi-colon that's part of a string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OUTSIDE_STRING = enum.auto()
|
OUTSIDE = enum.auto()
|
||||||
INSIDE_SINGLE_QUOTED_STRING = enum.auto()
|
INSIDE_SINGLE_QUOTED_STRING = enum.auto()
|
||||||
INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
|
INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
|
||||||
INSIDE_MULTILINE_STRING = enum.auto()
|
INSIDE_MULTILINE_STRING = enum.auto()
|
||||||
|
INSIDE_SINGLE_QUOTED_IDENTIFIER = enum.auto()
|
||||||
|
INSIDE_DOUBLE_QUOTED_IDENTIFIER = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_kql(kql: str) -> Iterator[str]:
|
||||||
|
"""
|
||||||
|
Tokenize a KQL script.
|
||||||
|
"""
|
||||||
|
valid_identifier_chars = set(string.ascii_letters + string.digits + "_")
|
||||||
|
valid_quoted_identifier_chars = valid_identifier_chars | set(" .-")
|
||||||
|
|
||||||
|
script = kql if kql.endswith(";") else kql + ";"
|
||||||
|
|
||||||
|
cursor = 0
|
||||||
|
while cursor < len(script):
|
||||||
|
rest = script[cursor:]
|
||||||
|
|
||||||
|
# quoted identifiers
|
||||||
|
if rest[:2] in {"['", '["'}:
|
||||||
|
match = "']" if rest[:2] == "['" else '"]'
|
||||||
|
if match not in rest[2:]:
|
||||||
|
raise SupersetParseError(
|
||||||
|
script,
|
||||||
|
"kustokql",
|
||||||
|
message="Unclosed quoted identifier",
|
||||||
|
)
|
||||||
|
|
||||||
|
token = rest[: rest.index(match, 2) + 2]
|
||||||
|
|
||||||
|
if any(char not in valid_quoted_identifier_chars for char in token[2:-2]):
|
||||||
|
raise SupersetParseError(
|
||||||
|
script,
|
||||||
|
"kustokql",
|
||||||
|
message="Invalid quoted identifier",
|
||||||
|
)
|
||||||
|
|
||||||
|
yield token
|
||||||
|
cursor += len(token)
|
||||||
|
|
||||||
|
# multi-line strings
|
||||||
|
elif rest[:3] == "```":
|
||||||
|
if "```" not in rest[3:]:
|
||||||
|
raise SupersetParseError(
|
||||||
|
script,
|
||||||
|
"kustokql",
|
||||||
|
message="Unclosed multi-line string",
|
||||||
|
)
|
||||||
|
|
||||||
|
token = rest[: rest.index("```", 3) + 3]
|
||||||
|
yield token
|
||||||
|
cursor += len(token)
|
||||||
|
|
||||||
|
# single-quoted strings
|
||||||
|
elif rest[0] in {'"', "'"}:
|
||||||
|
match = rest[0]
|
||||||
|
# find first unescaped quote
|
||||||
|
start = 1
|
||||||
|
while True:
|
||||||
|
if match not in rest[start:]:
|
||||||
|
raise SupersetParseError(
|
||||||
|
script,
|
||||||
|
"kustokql",
|
||||||
|
message="Unclosed string",
|
||||||
|
)
|
||||||
|
index = rest.index(match, start)
|
||||||
|
if rest[index - 1] != "\\":
|
||||||
|
break
|
||||||
|
|
||||||
|
start = index + 1
|
||||||
|
|
||||||
|
token = rest[: index + 1]
|
||||||
|
yield token
|
||||||
|
cursor += len(token)
|
||||||
|
|
||||||
|
# identifiers and keywords
|
||||||
|
else:
|
||||||
|
for i, char in enumerate(rest):
|
||||||
|
if char not in valid_identifier_chars:
|
||||||
|
if i > 0:
|
||||||
|
yield rest[:i]
|
||||||
|
yield char
|
||||||
|
cursor += i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def split_kql(kql: str) -> list[str]:
|
def split_kql(kql: str) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Custom function for splitting KQL statements.
|
Custom function for splitting KQL statements.
|
||||||
"""
|
"""
|
||||||
statements = []
|
statements: list[str] = []
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
statement: list[str] = []
|
||||||
statement_start = 0
|
for token in tokenize_kql(kql):
|
||||||
script = kql if kql.endswith(";") else kql + ";"
|
if token == ";":
|
||||||
for i, character in enumerate(script):
|
statements.append("".join(statement))
|
||||||
if state == KQLSplitState.OUTSIDE_STRING:
|
statement = []
|
||||||
if character == ";":
|
else:
|
||||||
statements.append(script[statement_start:i])
|
statement.append(token)
|
||||||
statement_start = i + 1
|
|
||||||
elif character == "'":
|
|
||||||
state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
|
|
||||||
elif character == '"':
|
|
||||||
state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
|
|
||||||
elif character == "`" and script[i - 2 : i] == "``":
|
|
||||||
state = KQLSplitState.INSIDE_MULTILINE_STRING
|
|
||||||
|
|
||||||
elif (
|
|
||||||
state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
|
|
||||||
and character == "'"
|
|
||||||
and script[i - 1] != "\\"
|
|
||||||
):
|
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
|
||||||
|
|
||||||
elif (
|
|
||||||
state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
|
|
||||||
and character == '"'
|
|
||||||
and script[i - 1] != "\\"
|
|
||||||
):
|
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
|
||||||
|
|
||||||
elif (
|
|
||||||
state == KQLSplitState.INSIDE_MULTILINE_STRING
|
|
||||||
and character == "`"
|
|
||||||
and script[i - 2 : i] == "``"
|
|
||||||
):
|
|
||||||
state = KQLSplitState.OUTSIDE_STRING
|
|
||||||
|
|
||||||
return statements
|
return statements
|
||||||
|
|
||||||
@@ -506,6 +574,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
|||||||
details about it.
|
details about it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
statement: str,
|
||||||
|
engine: str = "kustokql",
|
||||||
|
ast: str | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(statement, engine, ast)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def split_script(
|
def split_script(
|
||||||
cls,
|
cls,
|
||||||
@@ -588,6 +664,56 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
|||||||
"""
|
"""
|
||||||
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
|
||||||
|
|
||||||
|
def is_select(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the statement is a `SELECT` statement.
|
||||||
|
"""
|
||||||
|
if not self._parsed or self.is_mutating():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# strip comments
|
||||||
|
kql = "\n".join(
|
||||||
|
line
|
||||||
|
for line in self._parsed.split("\n")
|
||||||
|
if not line.strip().startswith("//")
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
first_token = next(tokenize_kql(kql), None)
|
||||||
|
if not first_token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return first_token == "|" or self._is_identifier(first_token)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_identifier(identifier: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validates if a given string is a valid KQL identifier.
|
||||||
|
|
||||||
|
From the documentation:
|
||||||
|
|
||||||
|
Identifiers are case-sensitive. Database names are case-insensitive, and
|
||||||
|
therefore an exception to this rule.
|
||||||
|
Identifiers must be between 1 and 1024 characters long.
|
||||||
|
Identifiers may contain letters, digits, and underscores (_).
|
||||||
|
Identifiers may contain certain special characters: spaces, dots (.), and
|
||||||
|
dashes (-). For information on how to reference identifiers with special
|
||||||
|
characters, see Reference identifiers in queries.
|
||||||
|
|
||||||
|
"""
|
||||||
|
valid_chars = set(string.ascii_letters + string.digits + "_")
|
||||||
|
|
||||||
|
# Identifiers names that (1) include special character, (2) are language
|
||||||
|
# keywords, or (3) are literals must be enclosed using [' and '] or [" and "].
|
||||||
|
if (identifier.startswith("['") and identifier.endswith("']")) or (
|
||||||
|
identifier.startswith('["') and identifier.endswith('"]')
|
||||||
|
):
|
||||||
|
identifier = identifier[2:-2]
|
||||||
|
valid_chars.update(" .-")
|
||||||
|
|
||||||
|
return 1 <= len(identifier) <= 1024 and all(
|
||||||
|
char in valid_chars for char in identifier
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SQLScript:
|
class SQLScript:
|
||||||
"""
|
"""
|
||||||
@@ -642,6 +768,24 @@ class SQLScript:
|
|||||||
"""
|
"""
|
||||||
return any(statement.is_mutating() for statement in self.statements)
|
return any(statement.is_mutating() for statement in self.statements)
|
||||||
|
|
||||||
|
def is_valid_ctas(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the script contains a valid CTAS statement.
|
||||||
|
|
||||||
|
CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last
|
||||||
|
statement is a `SELECT`.
|
||||||
|
"""
|
||||||
|
return self.statements[-1].is_select()
|
||||||
|
|
||||||
|
def is_valid_cvas(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the script contains a valid CVAS statement.
|
||||||
|
|
||||||
|
CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single
|
||||||
|
`SELECT` statement.
|
||||||
|
"""
|
||||||
|
return len(self.statements) == 1 and self.statements[0].is_select()
|
||||||
|
|
||||||
|
|
||||||
def extract_tables_from_statement(
|
def extract_tables_from_statement(
|
||||||
statement: exp.Expression,
|
statement: exp.Expression,
|
||||||
@@ -650,7 +794,7 @@ def extract_tables_from_statement(
|
|||||||
"""
|
"""
|
||||||
Extract all table references in a single statement.
|
Extract all table references in a single statement.
|
||||||
|
|
||||||
Please not that this is not trivial; consider the following queries:
|
Please note that this is not trivial; consider the following queries:
|
||||||
|
|
||||||
DESCRIBE some_table;
|
DESCRIBE some_table;
|
||||||
SHOW PARTITIONS FROM some_table;
|
SHOW PARTITIONS FROM some_table;
|
||||||
|
|||||||
@@ -20,11 +20,11 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from superset import app
|
from superset import app
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.sql_parse import ParsedQuery
|
from superset.sql.parse import SQLScript, SQLStatement
|
||||||
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
||||||
from superset.utils.core import QuerySource
|
from superset.utils.core import QuerySource
|
||||||
|
|
||||||
@@ -46,17 +46,15 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def validate_statement(
|
def validate_statement(
|
||||||
cls,
|
cls,
|
||||||
statement: str,
|
statement: SQLStatement,
|
||||||
database: Database,
|
database: Database,
|
||||||
cursor: Any,
|
cursor: Any,
|
||||||
) -> SQLValidationAnnotation | None:
|
) -> SQLValidationAnnotation | None:
|
||||||
# pylint: disable=too-many-locals
|
# pylint: disable=too-many-locals
|
||||||
db_engine_spec = database.db_engine_spec
|
db_engine_spec = database.db_engine_spec
|
||||||
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
|
|
||||||
sql = parsed_query.stripped()
|
|
||||||
|
|
||||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||||
sql = database.mutate_sql_based_on_config(sql)
|
sql = database.mutate_sql_based_on_config(str(statement))
|
||||||
|
|
||||||
# Transform the final statement to an explain call before sending it on
|
# Transform the final statement to an explain call before sending it on
|
||||||
# to presto to validate
|
# to presto to validate
|
||||||
@@ -155,10 +153,9 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||||||
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
||||||
VALIDATE) SELECT 1 FROM default.mytable.
|
VALIDATE) SELECT 1 FROM default.mytable.
|
||||||
"""
|
"""
|
||||||
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
|
parsed_script = SQLScript(sql, engine=database.db_engine_spec.engine)
|
||||||
statements = parsed_query.get_statements()
|
|
||||||
|
|
||||||
logger.info("Validating %i statement(s)", len(statements))
|
logger.info("Validating %i statement(s)", len(parsed_script.statements))
|
||||||
# todo(hughhh): update this to use new database.get_raw_connection()
|
# todo(hughhh): update this to use new database.get_raw_connection()
|
||||||
# this function keeps stalling CI
|
# this function keeps stalling CI
|
||||||
with database.get_sqla_engine(
|
with database.get_sqla_engine(
|
||||||
@@ -171,8 +168,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||||||
annotations: list[SQLValidationAnnotation] = []
|
annotations: list[SQLValidationAnnotation] = []
|
||||||
with closing(engine.raw_connection()) as conn:
|
with closing(engine.raw_connection()) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
for statement in parsed_query.get_statements():
|
for statement in parsed_script.statements:
|
||||||
annotation = cls.validate_statement(statement, database, cursor)
|
annotation = cls.validate_statement(
|
||||||
|
cast(SQLStatement, statement),
|
||||||
|
database,
|
||||||
|
cursor,
|
||||||
|
)
|
||||||
if annotation:
|
if annotation:
|
||||||
annotations.append(annotation)
|
annotations.append(annotation)
|
||||||
logger.debug("Validation found %i error(s)", len(annotations))
|
logger.debug("Validation found %i error(s)", len(annotations))
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from jinja2.meta import find_undeclared_variables
|
|||||||
from superset import is_feature_enabled
|
from superset import is_feature_enabled
|
||||||
from superset.commands.sql_lab.execute import SqlQueryRender
|
from superset.commands.sql_lab.execute import SqlQueryRender
|
||||||
from superset.errors import SupersetErrorType
|
from superset.errors import SupersetErrorType
|
||||||
from superset.sql_parse import ParsedQuery
|
|
||||||
from superset.sqllab.exceptions import SqlLabException
|
from superset.sqllab.exceptions import SqlLabException
|
||||||
from superset.utils import core as utils
|
from superset.utils import core as utils
|
||||||
|
|
||||||
@@ -58,12 +57,9 @@ class SqlQueryRenderImpl(SqlQueryRender):
|
|||||||
database=query_model.database, query=query_model
|
database=query_model.database, query=query_model
|
||||||
)
|
)
|
||||||
|
|
||||||
parsed_query = ParsedQuery(
|
|
||||||
query_model.sql,
|
|
||||||
engine=query_model.database.db_engine_spec.engine,
|
|
||||||
)
|
|
||||||
rendered_query = sql_template_processor.process_template(
|
rendered_query = sql_template_processor.process_template(
|
||||||
parsed_query.stripped(), **execution_context.template_params
|
query_model.sql.strip().strip(";"),
|
||||||
|
**execution_context.template_params,
|
||||||
)
|
)
|
||||||
self._validate(execution_context, rendered_query, sql_template_processor)
|
self._validate(execution_context, rendered_query, sql_template_processor)
|
||||||
return rendered_query
|
return rendered_query
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from superset.db_engine_specs.base import (
|
|||||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||||
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.sql_parse import ParsedQuery, Table
|
from superset.sql_parse import Table
|
||||||
from superset.utils.database import get_example_database
|
from superset.utils.database import get_example_database
|
||||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||||
from tests.integration_tests.test_app import app
|
from tests.integration_tests.test_app import app
|
||||||
@@ -310,20 +310,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_is_readonly():
|
|
||||||
def is_readonly(sql: str) -> bool:
|
|
||||||
return BaseEngineSpec.is_readonly_query(ParsedQuery(sql))
|
|
||||||
|
|
||||||
assert is_readonly("SHOW LOCKS test EXTENDED")
|
|
||||||
assert not is_readonly("SET hivevar:desc='Legislators'")
|
|
||||||
assert not is_readonly("UPDATE t1 SET col1 = NULL")
|
|
||||||
assert is_readonly("EXPLAIN SELECT 1")
|
|
||||||
assert is_readonly("SELECT 1")
|
|
||||||
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
|
|
||||||
assert is_readonly("SHOW CATALOGS")
|
|
||||||
assert is_readonly("SHOW TABLES")
|
|
||||||
|
|
||||||
|
|
||||||
def test_time_grain_denylist():
|
def test_time_grain_denylist():
|
||||||
config = app.config.copy()
|
config = app.config.copy()
|
||||||
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M", "SQLITE_NONEXISTENT_GRAIN"]
|
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M", "SQLITE_NONEXISTENT_GRAIN"]
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from sqlalchemy.sql import select
|
|||||||
|
|
||||||
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
|
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
|
||||||
from superset.exceptions import SupersetException
|
from superset.exceptions import SupersetException
|
||||||
from superset.sql_parse import ParsedQuery, Table
|
from superset.sql_parse import Table
|
||||||
from tests.integration_tests.test_app import app
|
from tests.integration_tests.test_app import app
|
||||||
|
|
||||||
|
|
||||||
@@ -222,19 +222,6 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
|
|||||||
app.config = config
|
app.config = config
|
||||||
|
|
||||||
|
|
||||||
def test_is_readonly():
|
|
||||||
def is_readonly(sql: str) -> bool:
|
|
||||||
return HiveEngineSpec.is_readonly_query(ParsedQuery(sql))
|
|
||||||
|
|
||||||
assert not is_readonly("UPDATE t1 SET col1 = NULL")
|
|
||||||
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
|
|
||||||
assert is_readonly("SHOW LOCKS test EXTENDED")
|
|
||||||
assert is_readonly("SET hivevar:desc='Legislators'")
|
|
||||||
assert is_readonly("EXPLAIN SELECT 1")
|
|
||||||
assert is_readonly("SELECT 1")
|
|
||||||
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"schema,upload_prefix",
|
"schema,upload_prefix",
|
||||||
[("foo", "EXTERNAL_HIVE_TABLES/1/foo/"), (None, "EXTERNAL_HIVE_TABLES/1/")],
|
[("foo", "EXTERNAL_HIVE_TABLES/1/foo/"), (None, "EXTERNAL_HIVE_TABLES/1/")],
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from sqlalchemy.sql import select
|
|||||||
|
|
||||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.sql_parse import ParsedQuery, Table
|
from superset.sql_parse import Table
|
||||||
from superset.utils.database import get_example_database
|
from superset.utils.database import get_example_database
|
||||||
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||||
|
|
||||||
@@ -1172,19 +1172,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_is_readonly():
|
|
||||||
def is_readonly(sql: str) -> bool:
|
|
||||||
return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))
|
|
||||||
|
|
||||||
assert not is_readonly("SET hivevar:desc='Legislators'")
|
|
||||||
assert not is_readonly("UPDATE t1 SET col1 = NULL")
|
|
||||||
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
|
|
||||||
assert is_readonly("SHOW LOCKS test EXTENDED")
|
|
||||||
assert is_readonly("EXPLAIN SELECT 1")
|
|
||||||
assert is_readonly("SELECT 1")
|
|
||||||
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_catalog_names(app_context: AppContext) -> None:
|
def test_get_catalog_names(app_context: AppContext) -> None:
|
||||||
"""
|
"""
|
||||||
Test the ``get_catalog_names`` method.
|
Test the ``get_catalog_names`` method.
|
||||||
|
|||||||
@@ -49,32 +49,6 @@ def test_get_text_clause_with_colon() -> None:
|
|||||||
assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')"
|
assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')"
|
||||||
|
|
||||||
|
|
||||||
def test_parse_sql_single_statement() -> None:
|
|
||||||
"""
|
|
||||||
`parse_sql` should properly strip leading and trailing spaces and semicolons
|
|
||||||
"""
|
|
||||||
|
|
||||||
from superset.db_engine_specs.base import BaseEngineSpec
|
|
||||||
|
|
||||||
queries = BaseEngineSpec.parse_sql(" SELECT foo FROM tbl ; ")
|
|
||||||
assert queries == ["SELECT foo FROM tbl"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_sql_multi_statement() -> None:
|
|
||||||
"""
|
|
||||||
For string with multiple SQL-statements `parse_sql` method should return list
|
|
||||||
where each element represents the single SQL-statement
|
|
||||||
"""
|
|
||||||
|
|
||||||
from superset.db_engine_specs.base import BaseEngineSpec
|
|
||||||
|
|
||||||
queries = BaseEngineSpec.parse_sql("SELECT foo FROM tbl1; SELECT bar FROM tbl2;")
|
|
||||||
assert queries == [
|
|
||||||
"SELECT foo FROM tbl1",
|
|
||||||
"SELECT bar FROM tbl2",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_db_uri(mocker: MockerFixture) -> None:
|
def test_validate_db_uri(mocker: MockerFixture) -> None:
|
||||||
"""
|
"""
|
||||||
Ensures that the `validate_database_uri` method invokes the validator correctly
|
Ensures that the `validate_database_uri` method invokes the validator correctly
|
||||||
|
|||||||
@@ -24,91 +24,6 @@ from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
|
|||||||
from tests.unit_tests.fixtures.common import dttm # noqa: F401
|
from tests.unit_tests.fixtures.common import dttm # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sql,expected",
|
|
||||||
[
|
|
||||||
("SELECT foo FROM tbl", True),
|
|
||||||
("SHOW TABLES", False),
|
|
||||||
("EXPLAIN SELECT foo FROM tbl", False),
|
|
||||||
("INSERT INTO tbl (foo) VALUES (1)", False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_sql_is_readonly_query(sql: str, expected: bool) -> None:
|
|
||||||
"""
|
|
||||||
Make sure that SQL dialect consider only SELECT statements as read-only
|
|
||||||
"""
|
|
||||||
|
|
||||||
from superset.db_engine_specs.kusto import KustoSqlEngineSpec
|
|
||||||
from superset.sql_parse import ParsedQuery
|
|
||||||
|
|
||||||
parsed_query = ParsedQuery(sql)
|
|
||||||
is_readonly = KustoSqlEngineSpec.is_readonly_query(parsed_query)
|
|
||||||
|
|
||||||
assert expected == is_readonly
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"kql,expected",
|
|
||||||
[
|
|
||||||
("tbl | limit 100", True),
|
|
||||||
("let foo = 1; tbl | where bar == foo", True),
|
|
||||||
(".show tables", False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_kql_is_select_query(kql: str, expected: bool) -> None:
|
|
||||||
"""
|
|
||||||
Make sure that KQL dialect consider only statements that do not start with "." (dot)
|
|
||||||
as a SELECT statements
|
|
||||||
"""
|
|
||||||
|
|
||||||
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
|
|
||||||
from superset.sql_parse import ParsedQuery
|
|
||||||
|
|
||||||
parsed_query = ParsedQuery(kql)
|
|
||||||
is_select = KustoKqlEngineSpec.is_select_query(parsed_query)
|
|
||||||
|
|
||||||
assert expected == is_select
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"kql,expected",
|
|
||||||
[
|
|
||||||
("tbl | limit 100", True),
|
|
||||||
("let foo = 1; tbl | where bar == foo", True),
|
|
||||||
(".show tables", True),
|
|
||||||
("print 1", True),
|
|
||||||
("set querytrace; Events | take 100", True),
|
|
||||||
(".drop table foo", False),
|
|
||||||
(".set-or-append table foo <| bar", False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_kql_is_readonly_query(kql: str, expected: bool) -> None:
|
|
||||||
"""
|
|
||||||
Make sure that KQL dialect consider only SELECT statements as read-only
|
|
||||||
"""
|
|
||||||
|
|
||||||
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
|
|
||||||
from superset.sql_parse import ParsedQuery
|
|
||||||
|
|
||||||
parsed_query = ParsedQuery(kql)
|
|
||||||
is_readonly = KustoKqlEngineSpec.is_readonly_query(parsed_query)
|
|
||||||
|
|
||||||
assert expected == is_readonly
|
|
||||||
|
|
||||||
|
|
||||||
def test_kql_parse_sql() -> None:
|
|
||||||
"""
|
|
||||||
parse_sql method should always return a list with a single element
|
|
||||||
which is an original query
|
|
||||||
"""
|
|
||||||
|
|
||||||
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
|
|
||||||
|
|
||||||
queries = KustoKqlEngineSpec.parse_sql("let foo = 1; tbl | where bar == foo")
|
|
||||||
|
|
||||||
assert queries == ["let foo = 1; tbl | where bar == foo"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"target_type,expected_result",
|
"target_type,expected_result",
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -945,6 +945,32 @@ on $left.Day1 == $right.Day
|
|||||||
("kustokql", "set querytrace; Events | take 100", False),
|
("kustokql", "set querytrace; Events | take 100", False),
|
||||||
("kustokql", ".drop table foo", True),
|
("kustokql", ".drop table foo", True),
|
||||||
("kustokql", ".set-or-append table foo <| bar", True),
|
("kustokql", ".set-or-append table foo <| bar", True),
|
||||||
|
("kustosql", "SELECT foo FROM tbl", False),
|
||||||
|
("kustosql", "SHOW TABLES", False),
|
||||||
|
("kustosql", "EXPLAIN SELECT foo FROM tbl", False),
|
||||||
|
("kustosql", "INSERT INTO tbl (foo) VALUES (1)", True),
|
||||||
|
("base", "SHOW LOCKS test EXTENDED", False),
|
||||||
|
("base", "SET hivevar:desc='Legislators'", False),
|
||||||
|
("base", "UPDATE t1 SET col1 = NULL", True),
|
||||||
|
("base", "EXPLAIN SELECT 1", False),
|
||||||
|
("base", "SELECT 1", False),
|
||||||
|
("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
||||||
|
("base", "SHOW CATALOGS", False),
|
||||||
|
("base", "SHOW TABLES", False),
|
||||||
|
("hive", "UPDATE t1 SET col1 = NULL", True),
|
||||||
|
("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
|
||||||
|
("hive", "SHOW LOCKS test EXTENDED", False),
|
||||||
|
("hive", "SET hivevar:desc='Legislators'", False),
|
||||||
|
("hive", "EXPLAIN SELECT 1", False),
|
||||||
|
("hive", "SELECT 1", False),
|
||||||
|
("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
||||||
|
("presto", "SET hivevar:desc='Legislators'", False),
|
||||||
|
("presto", "UPDATE t1 SET col1 = NULL", True),
|
||||||
|
("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
|
||||||
|
("presto", "SHOW LOCKS test EXTENDED", False),
|
||||||
|
("presto", "EXPLAIN SELECT 1", False),
|
||||||
|
("presto", "SELECT 1", False),
|
||||||
|
("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
|
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
|
||||||
@@ -1042,9 +1068,106 @@ def test_custom_dialect(app: None) -> None:
|
|||||||
)
|
)
|
||||||
def test_is_mutating(engine: str) -> None:
|
def test_is_mutating(engine: str) -> None:
|
||||||
"""
|
"""
|
||||||
Tests for `is_mutating`.
|
Global tests for `is_mutating`, covering all supported engines.
|
||||||
"""
|
"""
|
||||||
assert not SQLStatement(
|
assert not SQLStatement(
|
||||||
"with source as ( select 1 as one ) select * from source",
|
"with source as ( select 1 as one ) select * from source",
|
||||||
engine=engine,
|
engine=engine,
|
||||||
).is_mutating()
|
).is_mutating()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"identifier,expected",
|
||||||
|
[
|
||||||
|
# Rule: Identifiers are case-sensitive
|
||||||
|
("myTable", True),
|
||||||
|
("MYTABLE", True),
|
||||||
|
("MyTable", True),
|
||||||
|
# Rule: Identifiers must be between 1 and 1024 characters long
|
||||||
|
("a", True),
|
||||||
|
("a" * 1024, True),
|
||||||
|
("a" * 1025, False),
|
||||||
|
("", False),
|
||||||
|
# Rule: Identifiers may contain letters, digits, and underscores
|
||||||
|
("My_Table_123", True),
|
||||||
|
("123Table", True),
|
||||||
|
# Rule: Identifiers may contain special characters: spaces, dots, dashes (when quoted)
|
||||||
|
("['My Table']", True),
|
||||||
|
("['My-Table']", True),
|
||||||
|
("['My.Table']", True),
|
||||||
|
("['Table-']", True),
|
||||||
|
("['My Table Name']", True),
|
||||||
|
("['My!Table']", False),
|
||||||
|
("['MyTable ']", True),
|
||||||
|
(" MyTable", False),
|
||||||
|
("MyTable ", False),
|
||||||
|
# Rule: Non-special identifiers don't require quoting
|
||||||
|
("MyTable", True),
|
||||||
|
("My-Table", False),
|
||||||
|
# Invalid quoting
|
||||||
|
("['Invalid]", False),
|
||||||
|
("['Invalid'Name']", False),
|
||||||
|
("['']", False),
|
||||||
|
# Rule: Literal identifiers or language keywords
|
||||||
|
("['select']", True),
|
||||||
|
("select", True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_kql_identifier(identifier: str, expected: bool):
|
||||||
|
"""
|
||||||
|
Tests the _is_identifier method for various valid and invalid cases.
|
||||||
|
"""
|
||||||
|
assert KustoKQLStatement._is_identifier(identifier) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"kql,expected",
|
||||||
|
[
|
||||||
|
# Simple SELECT-like statements (non-mutating queries)
|
||||||
|
("MyTable | count", True),
|
||||||
|
("MyTable", True),
|
||||||
|
("| count", True),
|
||||||
|
("tbl | limit 100", True),
|
||||||
|
(".show tables", False),
|
||||||
|
# With comments (ensure comments are stripped out)
|
||||||
|
("// Comment only", False),
|
||||||
|
("MyTable // trailing comment", True),
|
||||||
|
("// leading comment\nMyTable", True),
|
||||||
|
("MyTable\n// intermediate comment\n| count", True),
|
||||||
|
# Mutating query (should return False)
|
||||||
|
(".drop MyTable", False),
|
||||||
|
(
|
||||||
|
".update MyTable set Column1 = 100",
|
||||||
|
False,
|
||||||
|
),
|
||||||
|
(".alter MyTable", False),
|
||||||
|
# Edge cases for first token
|
||||||
|
("", False),
|
||||||
|
(" ", False),
|
||||||
|
(".command", False),
|
||||||
|
("['My Table']", True),
|
||||||
|
# Complex multi-line queries
|
||||||
|
(
|
||||||
|
"""
|
||||||
|
// Initial comment
|
||||||
|
MyTable
|
||||||
|
| where Column1 > 100
|
||||||
|
""",
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"""
|
||||||
|
MyTable
|
||||||
|
| where Column1 > 100
|
||||||
|
| summarize by Column2
|
||||||
|
// Final comment
|
||||||
|
""",
|
||||||
|
True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_kql_is_select(kql: str, expected: bool):
|
||||||
|
"""
|
||||||
|
Tests the is_select method for various valid and invalid cases.
|
||||||
|
"""
|
||||||
|
assert KustoKQLStatement(kql).is_select() == expected
|
||||||
|
|||||||
Reference in New Issue
Block a user