chore: organize SQL parsing files (#30258)

This commit is contained in:
Beto Dealmeida
2024-09-13 16:24:19 -04:00
committed by GitHub
parent 8cd18cac8c
commit bdf29cb7c2
13 changed files with 1650 additions and 886 deletions

View File

@@ -19,23 +19,16 @@
from __future__ import annotations
import enum
import logging
import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar
from collections.abc import Iterator
from typing import Any, cast, TYPE_CHECKING
import sqlglot
import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlglot.dialects.dialect import Dialects
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
@@ -68,6 +61,7 @@ from superset.exceptions import (
SupersetParseError,
SupersetSecurityException,
)
from superset.sql.parse import extract_tables_from_statement, SQLScript, Table
from superset.utils.backports import StrEnum
try:
@@ -226,7 +220,9 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
def check_sql_functions_exist(
sql: str, function_list: set[str], engine: str | None = None
sql: str,
function_list: set[str],
engine: str = "base",
) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
@@ -238,7 +234,7 @@ def check_sql_functions_exist(
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
def strip_comments_from_sql(statement: str, engine: str = "base") -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@@ -255,554 +251,18 @@ def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
)
@dataclass(eq=True, frozen=True)
class Table:
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    Equality is based on the string form of the table (see ``__str__``), so a
    ``Table`` also compares equal to a plain string holding its fully qualified
    name. Hashing uses the same string form so that equal objects always share
    a hash.
    """

    table: str
    schema: str | None = None
    catalog: str | None = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Each non-empty part is URL-quoted and any dot inside a part is encoded
        as ``%2E``, so the dots in the result only ever separate parts.
        """
        return ".".join(
            urllib.parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )

    def __eq__(self, __o: object) -> bool:
        """
        Compare by fully qualified name.
        """
        return str(self) == str(__o)

    def __hash__(self) -> int:
        """
        Hash by fully qualified name, consistent with ``__eq__``.

        Without this, the dataclass-generated hash is computed from the field
        tuple, and two tables that compare equal (e.g. ``Table("t", schema="s")``
        and ``Table("t", catalog="s")``) could end up with different hashes.
        """
        return hash(str(self))
def extract_tables_from_statement(
    statement: exp.Expression,
    dialect: Dialects | None,
) -> set[Table]:
    """
    Extract all table references in a single statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.

    :param statement: the parsed statement to inspect
    :param dialect: the sqlglot dialect used to re-parse command payloads
    :return: the set of tables referenced by the statement; empty when a command
        has no literal payload or the payload cannot be parsed
    """
    sources: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
        # query for all tables.
        sources = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
        # `SELECT` statement in order to extract tables.
        literal = statement.find(exp.Literal)
        if not literal:
            return set()

        try:
            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
        except SqlglotError:
            # Catch the base sqlglot error, not just `ParseError`: the payload is
            # re-tokenized here, and tokenizer failures are not `ParseError`s.
            return set()
        sources = pseudo_query.find_all(exp.Table)
    else:
        sources = [
            source
            for scope in traverse_scope(statement)
            for source in scope.sources.values()
            if isinstance(source, exp.Table) and not is_cte(source, scope)
        ]

    return {
        Table(
            source.name,
            source.db if source.db != "" else None,
            source.catalog if source.catalog != "" else None,
        )
        for source in sources
    }
def is_cte(source: exp.Table, scope: Scope) -> bool:
    """
    Tell whether ``source`` refers to a CTE defined in the parent scope.

    CTEs in the parent scope look like tables (and are represented by
    exp.Table objects), but should not be considered as such;
    otherwise a user with access to table `foo` could access any table
    with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo

    """
    source_name = source.name
    if not scope.parent:
        # No parent scope means there are no CTEs that could mask a table.
        return False

    return any(
        name == source_name
        for name, parent_scope in scope.parent.sources.items()
        if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE
    )
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
# store the "AST" as a string (the original query), and manipulate it with regular
# expressions.
InternalRepresentation = TypeVar("InternalRepresentation")

# The base type. It annotates `cls` in the `split_query` class method so that the method
# type checks correctly, since each derived class has a more specific return type (the
# class itself). This will no longer be needed once Python 3.11 is the lowest version
# supported. See PEP 673 for more information: https://peps.python.org/pep-0673/
TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name
class BaseSQLStatement(Generic[InternalRepresentation]):
    """
    Common interface for a single statement.

    An instance can be built either from the statement text or, for efficiency,
    from an AST that has already been parsed. The latter is useful with
    `sqlglot.parse`, which splits a query into multiple already parsed statements.

    The `engine` parameter comes from the `engine` attribute in a Superset DB engine
    spec.
    """

    def __init__(
        self,
        statement: str | InternalRepresentation,
        engine: str,
    ):
        # Only text needs parsing; anything else is taken to be a ready-made AST.
        if isinstance(statement, str):
            parsed = self._parse_statement(statement, engine)
        else:
            parsed = statement

        self._parsed: InternalRepresentation = parsed
        self.engine = engine
        self.tables = self._extract_tables_from_statement(self._parsed, self.engine)

    @classmethod
    def split_query(
        cls: type[TBaseSQLStatement],
        query: str,
        engine: str,
    ) -> list[TBaseSQLStatement]:
        """
        Turn a full query into a list of instantiated statements.

        This helper is used by `SQLScript` to build the statements that make up a
        query. Derived classes must provide the implementation.
        """
        raise NotImplementedError()

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> InternalRepresentation:
        """
        Parse the text of one statement and return its AST.

        Implementations should not assume that `statement` holds a single statement,
        and MUST validate that explicitly. The validation depends on the parser, so
        it is the responsibility of each derived class.
        """
        raise NotImplementedError()

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: InternalRepresentation,
        engine: str,
    ) -> set[Table]:
        """
        Return every table referenced by the parsed statement.
        """
        raise NotImplementedError()

    def format(self, comments: bool = True) -> str:
        """
        Return the formatted statement, optionally omitting comments.
        """
        raise NotImplementedError()

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings set by the statement, if any.

        For example, for this statement:

            sql> SET foo = 'bar';

        The method should return `{"foo": "'bar'"}`. Note the single quotes.
        """
        raise NotImplementedError()

    def is_mutating(self) -> bool:
        """
        Tell whether the statement mutates data (DDL/DML).

        :return: True if the statement mutates data.
        """
        raise NotImplementedError()

    def __str__(self) -> str:
        # The string form of a statement is its formatted text.
        return self.format()
class SQLStatement(BaseSQLStatement[exp.Expression]):
    """
    A SQL statement.

    This class is used for all engines with dialects that can be parsed using sqlglot.
    """

    def __init__(
        self,
        statement: str | exp.Expression,
        engine: str,
    ):
        self._dialect = SQLGLOT_DIALECTS.get(engine)
        super().__init__(statement, engine)

    @classmethod
    def split_query(
        cls,
        query: str,
        engine: str,
    ) -> list[SQLStatement]:
        """
        Split a query into one `SQLStatement` per non-empty statement.

        :raises SupersetParseError: if sqlglot cannot tokenize or parse the query
        """
        dialect = SQLGLOT_DIALECTS.get(engine)

        try:
            statements = sqlglot.parse(query, dialect=dialect)
        except sqlglot.errors.SqlglotError as ex:
            # Catch the base class: tokenizer failures are not `ParseError`s, and
            # callers only expect `SupersetParseError` from this method.
            raise SupersetParseError("Unable to split query") from ex

        return [cls(statement, engine) for statement in statements if statement]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> exp.Expression:
        """
        Parse a single SQL statement.

        :raises SupersetParseError: if sqlglot cannot tokenize or parse the text, or
            if the text does not contain exactly one statement
        """
        dialect = SQLGLOT_DIALECTS.get(engine)

        # We could parse with `sqlglot.parse_one` to get a single statement, but we need
        # to verify that the string contains exactly one statement.
        try:
            candidates = sqlglot.parse(statement, dialect=dialect)
        except sqlglot.errors.SqlglotError as ex:
            raise SupersetParseError("Unable to split query") from ex

        # Empty entries in the parse result are ignored.
        parsed = [candidate for candidate in candidates if candidate]
        if len(parsed) != 1:
            raise SupersetParseError("SQLStatement should have exactly one statement")

        return parsed[0]

    @classmethod
    def _extract_tables_from_statement(
        cls,
        parsed: exp.Expression,
        engine: str,
    ) -> set[Table]:
        """
        Find all referenced tables.
        """
        dialect = SQLGLOT_DIALECTS.get(engine)
        return extract_tables_from_statement(parsed, dialect)

    def is_mutating(self) -> bool:
        """
        Check if the statement mutates data (DDL/DML).

        :return: True if the statement mutates data.
        """
        for node in self._parsed.walk():
            if isinstance(
                node,
                (
                    exp.Insert,
                    exp.Update,
                    exp.Delete,
                    exp.Merge,
                    exp.Create,
                    exp.Drop,
                    exp.TruncateTable,
                ),
            ):
                return True

            if isinstance(node, exp.Command) and node.name == "ALTER":
                return True

        # Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see
        # https://www.postgresql.org/docs/current/sql-explain.html
        if (
            self._dialect == Dialects.POSTGRES
            and isinstance(self._parsed, exp.Command)
            and self._parsed.name == "EXPLAIN"
            and self._parsed.expression is not None
        ):
            # Split on any whitespace so that `ANALYZE` followed by a newline or a
            # tab is recognized just like `ANALYZE` followed by a single space.
            parts = self._parsed.expression.name.split(None, 1)
            if len(parts) == 2 and parts[0].upper() == "ANALYZE":
                return SQLStatement(parts[1], self.engine).is_mutating()

        return False

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL statement.
        """
        write = Dialect.get_or_raise(self._dialect)
        return write.generate(self._parsed, copy=False, comments=comments, pretty=True)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = SQLStatement("SET foo = 'bar'", "base")
            >>> statement.get_settings()
            {"foo": "'bar'"}

        """
        return {
            eq.this.sql(): eq.expression.sql()
            for set_item in self._parsed.find_all(exp.SetItem)
            for eq in set_item.find_all(exp.EQ)
        }
class KQLSplitState(enum.Enum):
    """
    State machine for splitting a KQL query.

    The state machine keeps track of whether we're inside a string or not, so we
    don't split the query in a semi-colon that's part of a string.
    """

    OUTSIDE_STRING = enum.auto()
    INSIDE_SINGLE_QUOTED_STRING = enum.auto()
    INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
    INSIDE_MULTILINE_STRING = enum.auto()


def split_kql(kql: str) -> list[str]:
    """
    Custom function for splitting KQL statements.

    The query is split at every semi-colon that is not inside a string literal.
    Statements are returned verbatim (no stripping). A single trailing semi-colon
    does not produce an extra empty statement.

    Inside quoted strings a backslash escapes the next character, except in verbatim
    strings (``@"..."`` / ``@'...'``), where the backslash stands for itself.
    Multi-line strings are delimited by triple backticks and have no escapes.

    If the query ends inside an unterminated string, the remaining text is still
    returned as the last statement instead of being silently dropped.

    :param kql: the KQL query to split
    :return: the list of statements, in order
    """
    closing_quote = {
        KQLSplitState.INSIDE_SINGLE_QUOTED_STRING: "'",
        KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING: '"',
    }

    statements: list[str] = []
    state = KQLSplitState.OUTSIDE_STRING
    statement_start = 0
    escaped = False  # the previous character was an unescaped backslash
    verbatim = False  # the current quoted string was opened with an `@` prefix
    fence_end = -1  # index of the last backtick of the most recent ``` fence

    for i, character in enumerate(kql):
        if state == KQLSplitState.OUTSIDE_STRING:
            if character == ";":
                statements.append(kql[statement_start:i])
                statement_start = i + 1
            elif character == "'" or character == '"':
                state = (
                    KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
                    if character == "'"
                    else KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
                )
                verbatim = i > 0 and kql[i - 1] == "@"
                escaped = False
            elif (
                character == "`"
                # all three backticks must come after the previous fence
                and i - 2 > fence_end
                and kql[i - 2 : i] == "``"
            ):
                state = KQLSplitState.INSIDE_MULTILINE_STRING
                fence_end = i
        elif state == KQLSplitState.INSIDE_MULTILINE_STRING:
            if character == "`" and i - 2 > fence_end and kql[i - 2 : i] == "``":
                state = KQLSplitState.OUTSIDE_STRING
                fence_end = i
        elif escaped:
            # This character is escaped, whatever it is (quote, backslash, ...).
            escaped = False
        elif character == "\\" and not verbatim:
            escaped = True
        elif character == closing_quote[state]:
            state = KQLSplitState.OUTSIDE_STRING

    # Emit whatever follows the last split point. Nothing follows when the query
    # ends with a splitting semi-colon; an empty query still yields one statement.
    if statement_start < len(kql) or not kql.endswith(";"):
        statements.append(kql[statement_start:])

    return statements
class KustoKQLStatement(BaseSQLStatement[str]):
    """
    Special class for Kusto KQL.

    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
    like this:

        StormEvents
        | summarize PropertyDamage = sum(DamageProperty) by State
        | join kind=innerunique PopulationData on State
        | project State, PropertyDamagePerCapita = PropertyDamage / Population
        | sort by PropertyDamagePerCapita

    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
    details about it.
    """

    @classmethod
    def split_query(
        cls,
        query: str,
        engine: str,
    ) -> list[KustoKQLStatement]:
        """
        Split a query at semi-colons.

        Since we don't have a parser, we use a simple state machine based function. See
        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
        for more information.

        Blank segments (e.g. from whitespace after a trailing semi-colon) are
        skipped, mirroring how `SQLStatement.split_query` drops empty statements.
        """
        return [
            cls(statement, engine)
            for statement in split_kql(query)
            if statement.strip()
        ]

    @classmethod
    def _parse_statement(
        cls,
        statement: str,
        engine: str,
    ) -> str:
        """
        Validate a single KQL statement and return it stripped of whitespace.

        :raises SupersetParseError: if the engine is not `kustokql`, or if the text
            does not contain exactly one statement
        """
        if engine != "kustokql":
            raise SupersetParseError(f"Invalid engine: {engine}")

        statements = split_kql(statement)
        if len(statements) != 1:
            raise SupersetParseError("SQLStatement should have exactly one statement")

        return statements[0].strip()

    @classmethod
    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
        """
        Extract all tables referenced in the statement.

            StormEvents
            | where InjuriesDirect + InjuriesIndirect > 50
            | join (PopulationData) on State
            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect

        Table extraction is not implemented for KQL: a warning is logged and an
        empty set is always returned.
        """
        logger.warning(
            "Kusto KQL doesn't support table extraction. This means that data access "
            "roles will not be enforced by Superset in the database."
        )
        return set()

    def format(self, comments: bool = True) -> str:
        """
        Return the statement text.

        No formatting is applied, and the `comments` argument is ignored.
        """
        return self._parsed

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL statement.

            >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
            >>> statement.get_settings()
            {"querytrace": True}

        """
        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
            # A `set` without an explicit value is reported as `True`.
            return {match.group("name"): match.group("value") or True}

        return {}

    def is_mutating(self) -> bool:
        """
        Check if the statement mutates data (DDL/DML).

        Any statement starting with a dot is considered mutating, unless it starts
        with `.show`.

        :return: True if the statement mutates data.
        """
        return self._parsed.startswith(".") and not self._parsed.startswith(".show")
class SQLScript:
    """
    A SQL script made of zero or more statements.
    """

    # Engines that sqlglot cannot parse, mapped to the statement class that handles
    # them. Supporting non-SQL engines adds a lot of complexity to Superset, so we
    # should avoid adding new engines to this data structure.
    special_engines = {
        "kustokql": KustoKQLStatement,
    }

    def __init__(
        self,
        query: str,
        engine: str,
    ):
        # Engines without a special entry are handled by `SQLStatement`.
        statement_class = self.special_engines.get(engine, SQLStatement)
        self.statements = statement_class.split_query(query, engine)

    def format(self, comments: bool = True) -> str:
        """
        Format every statement and join them with a semi-colon and a newline.
        """
        formatted = [statement.format(comments) for statement in self.statements]
        return ";\n".join(formatted)

    def get_settings(self) -> dict[str, str | bool]:
        """
        Return the settings for the SQL query; later statements override earlier ones.

            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
            >>> statement.get_settings()
            {"foo": "'baz'"}

        """
        return {
            name: value
            for statement in self.statements
            for name, value in statement.get_settings().items()
        }

    def has_mutation(self) -> bool:
        """
        Tell whether any statement in the script is mutating.

        :return: True if the script contains mutating statements
        """
        for statement in self.statements:
            if statement.is_mutating():
                return True
        return False
class ParsedQuery:
def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
engine: str | None = None,
engine: str = "base",
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
self._engine = engine
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
@@ -854,24 +314,18 @@ class ParsedQuery:
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
statements = parse(self.stripped(), dialect=self._dialect)
except SqlglotError as ex:
statements = [
statement._parsed # pylint: disable=protected-access
for statement in SQLScript(self.stripped(), self._engine).statements
]
except SupersetParseError as ex:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
message = (
"Error parsing near '{highlight}' at line {line}:{col}".format( # pylint: disable=consider-using-f-string
**ex.errors[0]
)
if isinstance(ex, ParseError)
else str(ex)
)
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
message=__(
"You may have an error in your SQL statement. {message}"
).format(message=message),
).format(message=ex.error.message),
level=ErrorLevel.ERROR,
)
) from ex
@@ -883,77 +337,6 @@ class ParsedQuery:
if statement
}
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
    """
    Collect every table referenced by a single statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.
    """
    sources: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # sqlglot reports no sources for a `DESCRIBE` query, so look up every
        # table node explicitly.
        sources = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # Commands, like `SHOW COLUMNS FROM foo`, are rewritten as a `SELECT`
        # statement so that tables can be extracted from them.
        literal = statement.find(exp.Literal)
        if not literal:
            return set()

        try:
            pseudo_query = parse_one(
                f"SELECT {literal.this}",
                dialect=self._dialect,
            )
            sources = pseudo_query.find_all(exp.Table)
        except SqlglotError:
            return set()
    else:
        sources = [
            candidate
            for scope in traverse_scope(statement)
            for candidate in scope.sources.values()
            if isinstance(candidate, exp.Table) and not self._is_cte(candidate, scope)
        ]

    tables: set[Table] = set()
    for source in sources:
        # Empty schema/catalog names are normalized to `None`.
        schema = source.db if source.db != "" else None
        catalog = source.catalog if source.catalog != "" else None
        tables.add(Table(source.name, schema, catalog))

    return tables
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
    """
    Tell whether ``source`` refers to a CTE defined in the parent scope.

    CTEs in the parent scope look like tables (and are represented by
    exp.Table objects), but should not be considered as such;
    otherwise a user with access to table `foo` could access any table
    with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo

    """
    source_name = source.name
    if not scope.parent:
        # Without a parent scope there are no CTEs that could mask a table.
        return False

    return any(
        name == source_name
        for name, parent_scope in scope.parent.sources.items()
        if isinstance(parent_scope, Scope)
        and parent_scope.scope_type == ScopeType.CTE
    )
@property
def limit(self) -> int | None:
return self._limit