chore: improve SQL parsing (#26767)

This commit is contained in:
Beto Dealmeida
2024-03-13 18:27:01 -04:00
committed by GitHub
parent a75bb7685d
commit 26d8077e97
27 changed files with 393 additions and 195 deletions

View File

@@ -22,13 +22,14 @@ import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
from typing import Any, cast, Optional, Union
import sqlglot
import sqlparse
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
from sqlglot.errors import SqlglotError
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError, SqlglotError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
@@ -55,7 +56,7 @@ from sqlparse.tokens import (
)
from sqlparse.utils import imt
from superset.exceptions import QueryClauseValidationException
from superset.exceptions import QueryClauseValidationException, SupersetParseError
from superset.utils.backports import StrEnum
try:
@@ -252,6 +253,185 @@ class Table:
return str(self) == str(__o)
def extract_tables_from_statement(
statement: exp.Expression,
dialect: Optional[Dialects],
) -> set[Table]:
"""
Extract all table references in a single statement.
Please not that this is not trivial; consider the following queries:
DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
See the unit tests for other tricky cases.
"""
sources: Iterable[exp.Table]
if isinstance(statement, exp.Describe):
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
# query for all tables.
sources = statement.find_all(exp.Table)
elif isinstance(statement, exp.Command):
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
# `SELECT` statetement in order to extract tables.
literal = statement.find(exp.Literal)
if not literal:
return set()
try:
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
except ParseError:
return set()
sources = pseudo_query.find_all(exp.Table)
else:
sources = [
source
for scope in traverse_scope(statement)
for source in scope.sources.values()
if isinstance(source, exp.Table) and not is_cte(source, scope)
]
return {
Table(
source.name,
source.db if source.db != "" else None,
source.catalog if source.catalog != "" else None,
)
for source in sources
}
def is_cte(source: exp.Table, scope: Scope) -> bool:
"""
Is the source a CTE?
CTEs in the parent scope look like tables (and are represented by
exp.Table objects), but should not be considered as such;
otherwise a user with access to table `foo` could access any table
with a query like this:
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
"""
parent_sources = scope.parent.sources if scope.parent else {}
ctes_in_scope = {
name
for name, parent_scope in parent_sources.items()
if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE
}
return source.name in ctes_in_scope
class SQLScript:
"""
A SQL script, with 0+ statements.
"""
def __init__(
self,
query: str,
engine: Optional[str] = None,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self.statements = [
SQLStatement(statement, engine=engine)
for statement in parse(query, dialect=dialect)
if statement
]
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL query.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)
def get_settings(self) -> dict[str, str]:
"""
Return the settings for the SQL query.
>>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
>>> statement.get_settings()
{"foo": "'baz'"}
"""
settings: dict[str, str] = {}
for statement in self.statements:
settings.update(statement.get_settings())
return settings
class SQLStatement:
"""
A SQL statement.
This class provides helper methods to manipulate and introspect SQL.
"""
def __init__(
self,
statement: Union[str, exp.Expression],
engine: Optional[str] = None,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
if isinstance(statement, str):
try:
self._parsed = self._parse_statement(statement, dialect)
except ParseError as ex:
raise SupersetParseError(statement, engine) from ex
else:
self._parsed = statement
self._dialect = dialect
self.tables = extract_tables_from_statement(self._parsed, dialect)
@staticmethod
def _parse_statement(
sql_statement: str,
dialect: Optional[Dialects],
) -> exp.Expression:
"""
Parse a single SQL statement.
"""
statements = [
statement
for statement in sqlglot.parse(sql_statement, dialect=dialect)
if statement
]
if len(statements) != 1:
raise ValueError("SQLStatement should have exactly one statement")
return statements[0]
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
def get_settings(self) -> dict[str, str]:
"""
Return the settings for the SQL statement.
>>> statement = SQLStatement("SET foo = 'bar'")
>>> statement.get_settings()
{"foo": "'bar'"}
"""
return {
eq.this.sql(): eq.expression.sql()
for set_item in self._parsed.find_all(exp.SetItem)
for eq in set_item.find_all(exp.EQ)
}
class ParsedQuery:
def __init__(
self,
@@ -294,7 +474,7 @@ class ParsedQuery:
return {
table
for statement in statements
for table in self._extract_tables_from_statement(statement)
for table in extract_tables_from_statement(statement, self._dialect)
if statement
}