Compare commits

...

5 Commits

Author SHA1 Message Date
Beto Dealmeida
754807b87d WIP 2024-11-25 16:43:20 -05:00
Beto Dealmeida
782f94fe8d Another fix 2024-11-24 18:53:14 -05:00
Beto Dealmeida
8e0c00a82e Move integration tests to unit tests 2024-11-24 18:44:51 -05:00
Beto Dealmeida
c0c8802de9 Small fix 2024-11-24 18:03:20 -05:00
Beto Dealmeida
cd3209a600 chore (SIP-117): remove more sqlparse 2024-11-22 12:04:19 -05:00
18 changed files with 359 additions and 315 deletions

View File

@@ -36,7 +36,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException
from superset.extensions import db
from superset.models.core import Database
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -70,10 +70,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
table.sql = self._base_model.sql.strip().strip(";")
db.session.add(table)
cols = []
for config_ in self._base_model.columns:

View File

@@ -1778,7 +1778,7 @@ GUEST_TOKEN_VALIDATOR_HOOK = None
# def DATASET_HEALTH_CHECK(datasource: SqlaTable) -> Optional[str]:
# if (
# datasource.sql and
# len(sql_parse.ParsedQuery(datasource.sql, strip_comments=True).tables) == 1
# len(SQLScript(datasource.sql).tables) == 1
# ):
# return (
# "This virtual dataset queries only one table and therefore could be "

View File

@@ -67,7 +67,7 @@ from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.expression import Label
from sqlalchemy.sql.selectable import Alias, TableClause
from superset import app, db, is_feature_enabled, security_manager
@@ -104,7 +104,7 @@ from superset.models.helpers import (
QueryResult,
)
from superset.models.slice import Slice
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.superset_typing import (
AdhocColumn,
AdhocMetric,
@@ -1469,34 +1469,13 @@ class SqlaTable(
return tbl
def get_from_clause(
self, template_processor: BaseTemplateProcessor | None = None
self,
template_processor: BaseTemplateProcessor | None = None,
) -> tuple[TableClause | Alias, str | None]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
if not self.is_virtual:
return self.get_sqla_table(), None
from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
):
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
cte = self.db_engine_spec.get_cte_query(from_sql)
from_clause = (
table(self.db_engine_spec.cte_alias)
if cte
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
)
return from_clause, cte
return super().get_from_clause(template_processor)
def adhoc_metric_to_sqla(
self,

View File

@@ -63,7 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql.parse import SQLScript, Table
from superset.sql.parse import BaseSQLStatement, SQLScript, Table
from superset.sql_parse import ParsedQuery
from superset.superset_typing import (
OAuth2ClientConfig,
@@ -1737,18 +1737,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
)
@classmethod
def process_statement(cls, statement: str, database: Database) -> str:
def process_statement(
cls,
statement: BaseSQLStatement[Any],
database: Database,
) -> str:
"""
Process a SQL statement by stripping and mutating it.
Process a SQL statement by mutating it.
:param statement: A single SQL statement
:param database: Database instance
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()
return database.mutate_sql_based_on_config(sql, is_split=True)
return database.mutate_sql_based_on_config(str(statement), is_split=True)
@classmethod
def estimate_query_cost( # pylint: disable=too-many-arguments
@@ -1773,8 +1774,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"Database does not support cost estimation"
)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=cls.engine)
with database.get_raw_connection(
catalog=catalog,
@@ -1788,7 +1788,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls.process_statement(statement, database),
cursor,
)
for statement in statements
for statement in parsed_script.statements
]
@classmethod
@@ -2056,15 +2056,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
logger.error(ex, exc_info=True)
raise
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return (
parsed_query.is_select()
or parsed_query.is_explain()
or parsed_query.is_show()
)
@classmethod
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
"""
@@ -2178,10 +2169,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return False
@classmethod
def parse_sql(cls, sql: str) -> list[str]:
return [str(s).strip(" ;") for s in sqlparse.parse(sql)]
@classmethod
def get_impersonation_key(cls, user: User | None) -> Any:
"""

View File

@@ -36,7 +36,6 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes
from superset import sql_parse
from superset.constants import TimeGrain
from superset.databases.schemas import encrypted_field_properties, EncryptedString
from superset.databases.utils import make_url_safe
@@ -44,6 +43,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.errors import SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils, json
@@ -449,8 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=cls.engine)
with cls.get_engine(
database,
@@ -463,7 +462,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
cls.process_statement(statement, database),
client,
)
for statement in statements
for statement in parsed_script.statements
]
@classmethod

View File

@@ -45,7 +45,7 @@ from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.exceptions import SupersetException
from superset.extensions import cache_manager
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType
if TYPE_CHECKING:
@@ -605,15 +605,6 @@ class HiveEngineSpec(PrestoEngineSpec):
# otherwise, return no function names to prevent errors
return []
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return (
super().is_readonly_query(parsed_query)
or parsed_query.is_set()
or parsed_query.is_show()
)
@classmethod
def has_implicit_cancel(cls) -> bool:
"""

View File

@@ -104,11 +104,6 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
return f"""CONVERT(DATETIME, '{datetime_formatted}', 126)"""
return None
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return parsed_query.sql.lower().startswith("select")
class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
limit_method = LimitMethod.WRAP_SQL
@@ -158,23 +153,6 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
return None
@classmethod
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""
Pessimistic readonly, 100% sure statement won't mutate anything.
"""
return KustoKqlEngineSpec.is_select_query(
parsed_query
) or parsed_query.sql.startswith(".show")
@classmethod
def is_select_query(cls, parsed_query: ParsedQuery) -> bool:
return not parsed_query.sql.startswith(".")
@classmethod
def parse_sql(cls, sql: str) -> list[str]:
"""
Kusto supports a single query statement, but it could include sub queries
and variables declared via let keyword.
"""
return [sql]

View File

@@ -74,6 +74,7 @@ from superset.extensions import (
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.result_set import SupersetResultSet
from superset.sql.parse import SQLScript
from superset.sql_parse import Table
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
from superset.utils import cache as cache_util, core as utils, json
@@ -674,7 +675,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
schema: str | None = None,
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
sqls = self.db_engine_spec.parse_sql(sql)
parsed_script = SQLScript(sql, engine=self.db_engine_spec.engine)
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
engine_url = engine.url
@@ -691,8 +692,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
with self.get_raw_connection(catalog=catalog, schema=schema) as conn:
cursor = conn.cursor()
df = None
for i, sql_ in enumerate(sqls):
sql_ = self.mutate_sql_based_on_config(sql_, is_split=True)
for i, statement in enumerate(parsed_script.statements):
# pylint: disable=protected-access
sql_ = self.mutate_sql_based_on_config(statement._sql, is_split=True)
_log_query(sql_)
with event_logger.log_context(
action="execute_sql",
@@ -700,7 +702,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
object_ref=__name__,
):
self.db_engine_spec.execute(cursor, sql_, self)
if i < len(sqls) - 1:
if i < len(parsed_script.statements) - 1:
# If it's not the last, we don't keep the results
cursor.fetchall()
else:

View File

@@ -72,7 +72,6 @@ from superset.sql.parse import SQLScript
from superset.sql_parse import (
has_table_query,
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
)
from superset.superset_typing import (
@@ -1039,6 +1038,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"""
Render sql with template engine (Jinja).
"""
if not self.sql:
return ""
sql = self.sql.strip("\t\r\n; ")
if template_processor:
try:
@@ -1072,13 +1074,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
):
parsed_script = SQLScript(from_sql, engine=self.db_engine_spec.engine)
if parsed_script.has_mutation():
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)

View File

@@ -20,10 +20,11 @@ from __future__ import annotations
import enum
import logging
import re
import string
import urllib.parse
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Iterator, TypeVar
import sqlglot
import sqlparse
@@ -226,6 +227,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def is_select(self) -> bool:
"""
Check if the statement is a `SELECT` statement.
"""
raise NotImplementedError()
def __str__(self) -> str:
return self.format()
@@ -382,6 +389,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
return False
def is_select(self) -> bool:
"""
Check if the statement is a `SELECT` statement.
"""
return isinstance(self._parsed, exp.Select)
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
@@ -431,60 +444,115 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
}
class KQLSplitState(enum.Enum):
class KQLTokenizeState(enum.Enum):
"""
State machine for splitting a KQL script.
State machine for tokenizing a KQL script.
The state machine keeps track of whether we're inside a string or not, so we
don't split the script in a semi-colon that's part of a string.
"""
OUTSIDE_STRING = enum.auto()
OUTSIDE = enum.auto()
INSIDE_SINGLE_QUOTED_STRING = enum.auto()
INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
INSIDE_MULTILINE_STRING = enum.auto()
INSIDE_SINGLE_QUOTED_IDENTIFIER = enum.auto()
INSIDE_DOUBLE_QUOTED_IDENTIFIER = enum.auto()
def tokenize_kql(kql: str) -> Iterator[str]:
"""
Tokenize a KQL script.
"""
valid_identifier_chars = set(string.ascii_letters + string.digits + "_")
valid_quoted_identifier_chars = valid_identifier_chars | set(" .-")
script = kql if kql.endswith(";") else kql + ";"
cursor = 0
while cursor < len(script):
rest = script[cursor:]
# quoted identifiers
if rest[:2] in {"['", '["'}:
match = "']" if rest[:2] == "['" else '"]'
if match not in rest[2:]:
raise SupersetParseError(
script,
"kustokql",
message="Unclosed quoted identifier",
)
token = rest[: rest.index(match, 2) + 2]
if any(char not in valid_quoted_identifier_chars for char in token[2:-2]):
raise SupersetParseError(
script,
"kustokql",
message="Invalid quoted identifier",
)
yield token
cursor += len(token)
# multi-line strings
elif rest[:3] == "```":
if "```" not in rest[3:]:
raise SupersetParseError(
script,
"kustokql",
message="Unclosed multi-line string",
)
token = rest[: rest.index("```", 3) + 3]
yield token
cursor += len(token)
# single-quoted strings
elif rest[0] in {'"', "'"}:
match = rest[0]
# find first unescaped quote
start = 1
while True:
if match not in rest[start:]:
raise SupersetParseError(
script,
"kustokql",
message="Unclosed string",
)
index = rest.index(match, start)
if rest[index - 1] != "\\":
break
start = index + 1
token = rest[: index + 1]
yield token
cursor += len(token)
# identifiers and keywords
else:
for i, char in enumerate(rest):
if char not in valid_identifier_chars:
if i > 0:
yield rest[:i]
yield char
cursor += i + 1
break
def split_kql(kql: str) -> list[str]:
"""
Custom function for splitting KQL statements.
"""
statements = []
state = KQLSplitState.OUTSIDE_STRING
statement_start = 0
script = kql if kql.endswith(";") else kql + ";"
for i, character in enumerate(script):
if state == KQLSplitState.OUTSIDE_STRING:
if character == ";":
statements.append(script[statement_start:i])
statement_start = i + 1
elif character == "'":
state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
elif character == '"':
state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
elif character == "`" and script[i - 2 : i] == "``":
state = KQLSplitState.INSIDE_MULTILINE_STRING
elif (
state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
and character == "'"
and script[i - 1] != "\\"
):
state = KQLSplitState.OUTSIDE_STRING
elif (
state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
and character == '"'
and script[i - 1] != "\\"
):
state = KQLSplitState.OUTSIDE_STRING
elif (
state == KQLSplitState.INSIDE_MULTILINE_STRING
and character == "`"
and script[i - 2 : i] == "``"
):
state = KQLSplitState.OUTSIDE_STRING
statements: list[str] = []
statement: list[str] = []
for token in tokenize_kql(kql):
if token == ";":
statements.append("".join(statement))
statement = []
else:
statement.append(token)
return statements
@@ -506,6 +574,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
details about it.
"""
def __init__(
self,
statement: str,
engine: str = "kustokql",
ast: str | None = None,
):
super().__init__(statement, engine, ast)
@classmethod
def split_script(
cls,
@@ -588,6 +664,56 @@ class KustoKQLStatement(BaseSQLStatement[str]):
"""
return self._parsed.startswith(".") and not self._parsed.startswith(".show")
def is_select(self) -> bool:
"""
Check if the statement is a `SELECT` statement.
"""
if not self._parsed or self.is_mutating():
return False
# strip comments
kql = "\n".join(
line
for line in self._parsed.split("\n")
if not line.strip().startswith("//")
).strip()
first_token = next(tokenize_kql(kql), None)
if not first_token:
return False
return first_token == "|" or self._is_identifier(first_token)
@staticmethod
def _is_identifier(identifier: str) -> bool:
"""
Validates if a given string is a valid KQL identifier.
From the documentation:
Identifiers are case-sensitive. Database names are case-insensitive, and
therefore an exception to this rule.
Identifiers must be between 1 and 1024 characters long.
Identifiers may contain letters, digits, and underscores (_).
Identifiers may contain certain special characters: spaces, dots (.), and
dashes (-). For information on how to reference identifiers with special
characters, see Reference identifiers in queries.
"""
valid_chars = set(string.ascii_letters + string.digits + "_")
# Identifiers names that (1) include special character, (2) are language
# keywords, or (3) are literals must be enclosed using [' and '] or [" and "].
if (identifier.startswith("['") and identifier.endswith("']")) or (
identifier.startswith('["') and identifier.endswith('"]')
):
identifier = identifier[2:-2]
valid_chars.update(" .-")
return 1 <= len(identifier) <= 1024 and all(
char in valid_chars for char in identifier
)
class SQLScript:
"""
@@ -642,6 +768,24 @@ class SQLScript:
"""
return any(statement.is_mutating() for statement in self.statements)
def is_valid_ctas(self) -> bool:
"""
Check if the script contains a valid CTAS statement.
CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last
statement is a `SELECT`.
"""
return self.statements[-1].is_select()
def is_valid_cvas(self) -> bool:
"""
Check if the script contains a valid CVAS statement.
CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single
`SELECT` statement.
"""
return len(self.statements) == 1 and self.statements[0].is_select()
def extract_tables_from_statement(
statement: exp.Expression,
@@ -650,7 +794,7 @@ def extract_tables_from_statement(
"""
Extract all table references in a single statement.
Please not that this is not trivial; consider the following queries:
Please note that this is not trivial; consider the following queries:
DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;

View File

@@ -20,11 +20,11 @@ from __future__ import annotations
import logging
import time
from contextlib import closing
from typing import Any
from typing import Any, cast
from superset import app
from superset.models.core import Database
from superset.sql_parse import ParsedQuery
from superset.sql.parse import SQLScript, SQLStatement
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
from superset.utils.core import QuerySource
@@ -46,17 +46,15 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate_statement(
cls,
statement: str,
statement: SQLStatement,
database: Database,
cursor: Any,
) -> SQLValidationAnnotation | None:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
sql = parsed_query.stripped()
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql = database.mutate_sql_based_on_config(sql)
sql = database.mutate_sql_based_on_config(str(statement))
# Transform the final statement to an explain call before sending it on
# to presto to validate
@@ -155,10 +153,9 @@ class PrestoDBSQLValidator(BaseSQLValidator):
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
statements = parsed_query.get_statements()
parsed_script = SQLScript(sql, engine=database.db_engine_spec.engine)
logger.info("Validating %i statement(s)", len(statements))
logger.info("Validating %i statement(s)", len(parsed_script.statements))
# todo(hughhh): update this to use new database.get_raw_connection()
# this function keeps stalling CI
with database.get_sqla_engine(
@@ -171,8 +168,12 @@ class PrestoDBSQLValidator(BaseSQLValidator):
annotations: list[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
annotation = cls.validate_statement(statement, database, cursor)
for statement in parsed_script.statements:
annotation = cls.validate_statement(
cast(SQLStatement, statement),
database,
cursor,
)
if annotation:
annotations.append(annotation)
logger.debug("Validation found %i error(s)", len(annotations))

View File

@@ -26,7 +26,6 @@ from jinja2.meta import find_undeclared_variables
from superset import is_feature_enabled
from superset.commands.sql_lab.execute import SqlQueryRender
from superset.errors import SupersetErrorType
from superset.sql_parse import ParsedQuery
from superset.sqllab.exceptions import SqlLabException
from superset.utils import core as utils
@@ -58,12 +57,9 @@ class SqlQueryRenderImpl(SqlQueryRender):
database=query_model.database, query=query_model
)
parsed_query = ParsedQuery(
query_model.sql,
engine=query_model.database.db_engine_spec.engine,
)
rendered_query = sql_template_processor.process_template(
parsed_query.stripped(), **execution_context.template_params
query_model.sql.strip().strip(";"),
**execution_context.template_params,
)
self._validate(execution_context, rendered_query, sql_template_processor)
return rendered_query

View File

@@ -30,7 +30,7 @@ from superset.db_engine_specs.base import (
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app
@@ -310,20 +310,6 @@ class TestDbEngineSpecs(TestDbEngineSpec):
)
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return BaseEngineSpec.is_readonly_query(ParsedQuery(sql))
assert is_readonly("SHOW LOCKS test EXTENDED")
assert not is_readonly("SET hivevar:desc='Legislators'")
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
assert is_readonly("SHOW CATALOGS")
assert is_readonly("SHOW TABLES")
def test_time_grain_denylist():
config = app.config.copy()
app.config["TIME_GRAIN_DENYLIST"] = ["PT1M", "SQLITE_NONEXISTENT_GRAIN"]

View File

@@ -23,7 +23,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.hive import HiveEngineSpec, upload_to_s3
from superset.exceptions import SupersetException
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from tests.integration_tests.test_app import app
@@ -222,19 +222,6 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
app.config = config
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return HiveEngineSpec.is_readonly_query(ParsedQuery(sql))
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
assert is_readonly("SHOW LOCKS test EXTENDED")
assert is_readonly("SET hivevar:desc='Legislators'")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
@pytest.mark.parametrize(
"schema,upload_prefix",
[("foo", "EXTERNAL_HIVE_TABLES/1/foo/"), (None, "EXTERNAL_HIVE_TABLES/1/")],

View File

@@ -25,7 +25,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import Table
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
@@ -1172,19 +1172,6 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
]
def test_is_readonly():
def is_readonly(sql: str) -> bool:
return PrestoEngineSpec.is_readonly_query(ParsedQuery(sql))
assert not is_readonly("SET hivevar:desc='Legislators'")
assert not is_readonly("UPDATE t1 SET col1 = NULL")
assert not is_readonly("INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA")
assert is_readonly("SHOW LOCKS test EXTENDED")
assert is_readonly("EXPLAIN SELECT 1")
assert is_readonly("SELECT 1")
assert is_readonly("WITH (SELECT 1) bla SELECT * from bla")
def test_get_catalog_names(app_context: AppContext) -> None:
"""
Test the ``get_catalog_names`` method.

View File

@@ -49,32 +49,6 @@ def test_get_text_clause_with_colon() -> None:
assert text_clause.text == "SELECT foo FROM tbl WHERE foo = '123\\:456')"
def test_parse_sql_single_statement() -> None:
"""
`parse_sql` should properly strip leading and trailing spaces and semicolons
"""
from superset.db_engine_specs.base import BaseEngineSpec
queries = BaseEngineSpec.parse_sql(" SELECT foo FROM tbl ; ")
assert queries == ["SELECT foo FROM tbl"]
def test_parse_sql_multi_statement() -> None:
"""
For string with multiple SQL-statements `parse_sql` method should return list
where each element represents the single SQL-statement
"""
from superset.db_engine_specs.base import BaseEngineSpec
queries = BaseEngineSpec.parse_sql("SELECT foo FROM tbl1; SELECT bar FROM tbl2;")
assert queries == [
"SELECT foo FROM tbl1",
"SELECT bar FROM tbl2",
]
def test_validate_db_uri(mocker: MockerFixture) -> None:
"""
Ensures that the `validate_database_uri` method invokes the validator correctly

View File

@@ -24,91 +24,6 @@ from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm # noqa: F401
@pytest.mark.parametrize(
"sql,expected",
[
("SELECT foo FROM tbl", True),
("SHOW TABLES", False),
("EXPLAIN SELECT foo FROM tbl", False),
("INSERT INTO tbl (foo) VALUES (1)", False),
],
)
def test_sql_is_readonly_query(sql: str, expected: bool) -> None:
"""
Make sure that SQL dialect consider only SELECT statements as read-only
"""
from superset.db_engine_specs.kusto import KustoSqlEngineSpec
from superset.sql_parse import ParsedQuery
parsed_query = ParsedQuery(sql)
is_readonly = KustoSqlEngineSpec.is_readonly_query(parsed_query)
assert expected == is_readonly
@pytest.mark.parametrize(
"kql,expected",
[
("tbl | limit 100", True),
("let foo = 1; tbl | where bar == foo", True),
(".show tables", False),
],
)
def test_kql_is_select_query(kql: str, expected: bool) -> None:
"""
Make sure that KQL dialect consider only statements that do not start with "." (dot)
as a SELECT statements
"""
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
from superset.sql_parse import ParsedQuery
parsed_query = ParsedQuery(kql)
is_select = KustoKqlEngineSpec.is_select_query(parsed_query)
assert expected == is_select
@pytest.mark.parametrize(
"kql,expected",
[
("tbl | limit 100", True),
("let foo = 1; tbl | where bar == foo", True),
(".show tables", True),
("print 1", True),
("set querytrace; Events | take 100", True),
(".drop table foo", False),
(".set-or-append table foo <| bar", False),
],
)
def test_kql_is_readonly_query(kql: str, expected: bool) -> None:
"""
Make sure that KQL dialect consider only SELECT statements as read-only
"""
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
from superset.sql_parse import ParsedQuery
parsed_query = ParsedQuery(kql)
is_readonly = KustoKqlEngineSpec.is_readonly_query(parsed_query)
assert expected == is_readonly
def test_kql_parse_sql() -> None:
"""
parse_sql method should always return a list with a single element
which is an original query
"""
from superset.db_engine_specs.kusto import KustoKqlEngineSpec
queries = KustoKqlEngineSpec.parse_sql("let foo = 1; tbl | where bar == foo")
assert queries == ["let foo = 1; tbl | where bar == foo"]
@pytest.mark.parametrize(
"target_type,expected_result",
[

View File

@@ -945,6 +945,32 @@ on $left.Day1 == $right.Day
("kustokql", "set querytrace; Events | take 100", False),
("kustokql", ".drop table foo", True),
("kustokql", ".set-or-append table foo <| bar", True),
("kustosql", "SELECT foo FROM tbl", False),
("kustosql", "SHOW TABLES", False),
("kustosql", "EXPLAIN SELECT foo FROM tbl", False),
("kustosql", "INSERT INTO tbl (foo) VALUES (1)", True),
("base", "SHOW LOCKS test EXTENDED", False),
("base", "SET hivevar:desc='Legislators'", False),
("base", "UPDATE t1 SET col1 = NULL", True),
("base", "EXPLAIN SELECT 1", False),
("base", "SELECT 1", False),
("base", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
("base", "SHOW CATALOGS", False),
("base", "SHOW TABLES", False),
("hive", "UPDATE t1 SET col1 = NULL", True),
("hive", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
("hive", "SHOW LOCKS test EXTENDED", False),
("hive", "SET hivevar:desc='Legislators'", False),
("hive", "EXPLAIN SELECT 1", False),
("hive", "SELECT 1", False),
("hive", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
("presto", "SET hivevar:desc='Legislators'", False),
("presto", "UPDATE t1 SET col1 = NULL", True),
("presto", "INSERT OVERWRITE TABLE tabB SELECT a.Age FROM TableA", True),
("presto", "SHOW LOCKS test EXTENDED", False),
("presto", "EXPLAIN SELECT 1", False),
("presto", "SELECT 1", False),
("presto", "WITH bla AS (SELECT 1) SELECT * FROM bla", False),
],
)
def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
@@ -1042,9 +1068,106 @@ def test_custom_dialect(app: None) -> None:
)
def test_is_mutating(engine: str) -> None:
"""
Tests for `is_mutating`.
Global tests for `is_mutating`, covering all supported engines.
"""
assert not SQLStatement(
"with source as ( select 1 as one ) select * from source",
engine=engine,
).is_mutating()
@pytest.mark.parametrize(
"identifier,expected",
[
# Rule: Identifiers are case-sensitive
("myTable", True),
("MYTABLE", True),
("MyTable", True),
# Rule: Identifiers must be between 1 and 1024 characters long
("a", True),
("a" * 1024, True),
("a" * 1025, False),
("", False),
# Rule: Identifiers may contain letters, digits, and underscores
("My_Table_123", True),
("123Table", True),
# Rule: Identifiers may contain special characters: spaces, dots, dashes (when quoted)
("['My Table']", True),
("['My-Table']", True),
("['My.Table']", True),
("['Table-']", True),
("['My Table Name']", True),
("['My!Table']", False),
("['MyTable ']", True),
(" MyTable", False),
("MyTable ", False),
# Rule: Non-special identifiers don't require quoting
("MyTable", True),
("My-Table", False),
# Invalid quoting
("['Invalid]", False),
("['Invalid'Name']", False),
("['']", False),
# Rule: Literal identifiers or language keywords
("['select']", True),
("select", True),
],
)
def test_is_kql_identifier(identifier: str, expected: bool):
"""
Tests the _is_identifier method for various valid and invalid cases.
"""
assert KustoKQLStatement._is_identifier(identifier) == expected
@pytest.mark.parametrize(
"kql,expected",
[
# Simple SELECT-like statements (non-mutating queries)
("MyTable | count", True),
("MyTable", True),
("| count", True),
("tbl | limit 100", True),
(".show tables", False),
# With comments (ensure comments are stripped out)
("// Comment only", False),
("MyTable // trailing comment", True),
("// leading comment\nMyTable", True),
("MyTable\n// intermediate comment\n| count", True),
# Mutating query (should return False)
(".drop MyTable", False),
(
".update MyTable set Column1 = 100",
False,
),
(".alter MyTable", False),
# Edge cases for first token
("", False),
(" ", False),
(".command", False),
("['My Table']", True),
# Complex multi-line queries
(
"""
// Initial comment
MyTable
| where Column1 > 100
""",
True,
),
(
"""
MyTable
| where Column1 > 100
| summarize by Column2
// Final comment
""",
True,
),
],
)
def test_kql_is_select(kql: str, expected: bool):
"""
Tests the is_select method for various valid and invalid cases.
"""
assert KustoKQLStatement(kql).is_select() == expected