mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
chore: improve SQL parsing (#26767)
This commit is contained in:
@@ -22,13 +22,14 @@ import re
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional
|
||||
from typing import Any, cast, Optional, Union
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from sqlalchemy import and_
|
||||
from sqlglot import exp, parse, parse_one
|
||||
from sqlglot.dialects import Dialects
|
||||
from sqlglot.errors import SqlglotError
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.errors import ParseError, SqlglotError
|
||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||
from sqlparse import keywords
|
||||
from sqlparse.lexer import Lexer
|
||||
@@ -55,7 +56,7 @@ from sqlparse.tokens import (
|
||||
)
|
||||
from sqlparse.utils import imt
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException
|
||||
from superset.exceptions import QueryClauseValidationException, SupersetParseError
|
||||
from superset.utils.backports import StrEnum
|
||||
|
||||
try:
|
||||
@@ -252,6 +253,185 @@ class Table:
|
||||
return str(self) == str(__o)
|
||||
|
||||
|
||||
def extract_tables_from_statement(
    statement: exp.Expression,
    dialect: Optional[Dialects],
) -> set[Table]:
    """
    Collect every table referenced by a single parsed statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.

    :param statement: the sqlglot expression to inspect
    :param dialect: dialect used when a command payload has to be re-parsed
    :returns: the set of ``Table`` objects found; empty when a command has no
        literal payload or its payload cannot be parsed
    """
    tables: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # sqlglot exposes no sources for a `DESCRIBE` query, so every table
        # node in the tree is looked up explicitly.
        tables = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # Commands, like `SHOW COLUMNS FROM foo`, carry their payload as a
        # literal; it is wrapped in a `SELECT` statement so that the tables
        # can be extracted from it.
        payload = statement.find(exp.Literal)
        if not payload:
            return set()

        try:
            pseudo_query = parse_one(f"SELECT {payload.this}", dialect=dialect)
        except ParseError:
            return set()

        tables = pseudo_query.find_all(exp.Table)
    else:
        # Walk every scope and keep the real tables, dropping references
        # that actually point to a CTE.
        found: list[exp.Table] = []
        for scope in traverse_scope(statement):
            for candidate in scope.sources.values():
                if isinstance(candidate, exp.Table) and not is_cte(candidate, scope):
                    found.append(candidate)
        tables = found

    result: set[Table] = set()
    for table in tables:
        # Empty schema/catalog strings are normalized to ``None``.
        schema = None if table.db == "" else table.db
        catalog = None if table.catalog == "" else table.catalog
        result.add(Table(table.name, schema, catalog))

    return result
|
||||
|
||||
|
||||
def is_cte(source: exp.Table, scope: Scope) -> bool:
    """
    Tell whether ``source`` refers to a CTE rather than to a real table.

    CTEs in the parent scope look like tables (and are represented by
    exp.Table objects), but should not be considered as such;
    otherwise a user with access to table `foo` could access any table
    with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo

    :param source: the table expression being checked
    :param scope: the scope in which ``source`` was found
    :returns: ``True`` when the parent scope has a CTE-typed source with the
        same name as ``source``; ``False`` otherwise, including when the
        scope has no parent
    """
    if not scope.parent:
        return False

    table_name = source.name
    return any(
        name == table_name
        and isinstance(candidate, Scope)
        and candidate.scope_type == ScopeType.CTE
        for name, candidate in scope.parent.sources.items()
    )
|
||||
|
||||
|
||||
class SQLScript:
    """
    A SQL script, made of zero or more statements.
    """

    def __init__(
        self,
        query: str,
        engine: Optional[str] = None,
    ):
        """
        Parse ``query`` and wrap each non-empty statement in a ``SQLStatement``.

        :param query: the raw SQL text, possibly holding several statements
        :param engine: optional engine name used to look up the sqlglot dialect
        """
        dialect = None
        if engine:
            dialect = SQLGLOT_DIALECTS.get(engine)

        wrapped = []
        for parsed in parse(query, dialect=dialect):
            # Falsy entries produced by the parser are skipped.
            if parsed:
                wrapped.append(SQLStatement(parsed, engine=engine))

        self.statements = wrapped

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL script.

        Each statement is formatted on its own and the results are joined
        with a semicolon followed by a newline.

        :param comments: forwarded to the ``format`` method of each statement
        """
        formatted = [statement.format(comments) for statement in self.statements]
        return ";\n".join(formatted)

    def get_settings(self) -> dict[str, str]:
        """
        Return the settings for the SQL script.

        Statements are visited in order, so a later assignment overrides an
        earlier one for the same key:

            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
            >>> statement.get_settings()
            {'foo': "'baz'"}

        """
        return {
            key: value
            for statement in self.statements
            for key, value in statement.get_settings().items()
        }
|
||||
|
||||
|
||||
class SQLStatement:
    """
    A SQL statement.

    This class provides helper methods to manipulate and introspect SQL.
    """

    def __init__(
        self,
        statement: Union[str, exp.Expression],
        engine: Optional[str] = None,
    ):
        """
        Build a statement from raw SQL or from an already parsed expression.

        :param statement: SQL text holding exactly one statement, or a
            sqlglot expression that is used as-is
        :param engine: optional engine name used to look up the sqlglot dialect
        :raises SupersetParseError: when sqlglot raises ``ParseError`` while
            parsing string input
        """
        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None

        if not isinstance(statement, str):
            # Already parsed (e.g. handed over by ``SQLScript``).
            self._parsed = statement
        else:
            try:
                self._parsed = self._parse_statement(statement, dialect)
            except ParseError as ex:
                raise SupersetParseError(statement, engine) from ex

        self._dialect = dialect
        self.tables = extract_tables_from_statement(self._parsed, dialect)

    @staticmethod
    def _parse_statement(
        sql_statement: str,
        dialect: Optional[Dialects],
    ) -> exp.Expression:
        """
        Parse a single SQL statement.

        :raises ValueError: when the text does not hold exactly one
            non-empty statement
        """
        # ``filter(None, ...)`` drops the falsy entries returned by the parser.
        parsed = list(filter(None, sqlglot.parse(sql_statement, dialect=dialect)))
        if len(parsed) == 1:
            return parsed[0]

        raise ValueError("SQLStatement should have exactly one statement")

    def format(self, comments: bool = True) -> str:
        """
        Pretty-format the SQL statement.

        :param comments: forwarded to the sqlglot generator
        """
        target_dialect = Dialect.get_or_raise(self._dialect)
        return target_dialect.generate(
            self._parsed,
            pretty=True,
            comments=comments,
            copy=False,
        )

    def get_settings(self) -> dict[str, str]:
        """
        Return the settings for the SQL statement.

            >>> statement = SQLStatement("SET foo = 'bar'")
            >>> statement.get_settings()
            {'foo': "'bar'"}

        """
        settings: dict[str, str] = {}
        for set_item in self._parsed.find_all(exp.SetItem):
            for assignment in set_item.find_all(exp.EQ):
                # Both sides of the assignment are rendered back to SQL text.
                key = assignment.this.sql()
                value = assignment.expression.sql()
                settings[key] = value

        return settings
|
||||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -294,7 +474,7 @@ class ParsedQuery:
|
||||
return {
|
||||
table
|
||||
for statement in statements
|
||||
for table in self._extract_tables_from_statement(statement)
|
||||
for table in extract_tables_from_statement(statement, self._dialect)
|
||||
if statement
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user