feat(sqlparse): improve table parsing (#26476)

(cherry picked from commit c0b57bd1c3)
This commit is contained in:
Beto Dealmeida
2024-01-22 11:16:50 -05:00
committed by Michael S. Molina
parent 6cdaf479f2
commit 1d9cfdabd1
17 changed files with 265 additions and 120 deletions

View File

@@ -14,15 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
import logging
import re
from collections.abc import Iterator
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
from urllib import parse
import sqlparse
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
from sqlglot.errors import ParseError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
@@ -53,7 +60,7 @@ from superset.utils.backports import StrEnum
try:
from sqloxide import parse_sql as sqloxide_parse
except: # pylint: disable=bare-except
except (ImportError, ModuleNotFoundError):
sqloxide_parse = None
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@@ -72,6 +79,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
lex.set_SQL_REGEX(sqlparser_sql_regex)
# mapping between DB engine specs and sqlglot dialects
SQLGLOT_DIALECTS = {
"ascend": Dialects.HIVE,
"awsathena": Dialects.PRESTO,
"bigquery": Dialects.BIGQUERY,
"clickhouse": Dialects.CLICKHOUSE,
"clickhousedb": Dialects.CLICKHOUSE,
"cockroachdb": Dialects.POSTGRES,
# "crate": ???
# "databend": ???
"databricks": Dialects.DATABRICKS,
# "db2": ???
# "dremio": ???
"drill": Dialects.DRILL,
# "druid": ???
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
# "exa": ???
# "firebird": ???
# "firebolt": ???
"gsheets": Dialects.SQLITE,
"hana": Dialects.POSTGRES,
"hive": Dialects.HIVE,
# "ibmi": ???
# "impala": ???
# "kustokql": ???
# "kylin": ???
# "mssql": ???
"mysql": Dialects.MYSQL,
"netezza": Dialects.POSTGRES,
# "ocient": ???
# "odelasticsearch": ???
"oracle": Dialects.ORACLE,
# "pinot": ???
"postgresql": Dialects.POSTGRES,
"presto": Dialects.PRESTO,
"pydoris": Dialects.DORIS,
"redshift": Dialects.REDSHIFT,
# "risingwave": ???
# "rockset": ???
"shillelagh": Dialects.SQLITE,
"snowflake": Dialects.SNOWFLAKE,
# "solr": ???
"sqlite": Dialects.SQLITE,
"starrocks": Dialects.STARROCKS,
"superset": Dialects.SQLITE,
"teradatasql": Dialects.TERADATA,
"trino": Dialects.TRINO,
"vertica": Dialects.POSTGRES,
}
class CtasMethod(StrEnum):
TABLE = "TABLE"
VIEW = "VIEW"
@@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder
def strip_comments_from_sql(statement: str) -> str:
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
:param statement: A string with the SQL statement
:return: SQL statement without comments
"""
return ParsedQuery(statement).strip_comments() if "--" in statement else statement
return (
ParsedQuery(statement, engine=engine).strip_comments()
if "--" in statement
else statement
)
@dataclass(eq=True, frozen=True)
@@ -179,7 +243,7 @@ class Table:
"""
return ".".join(
parse.quote(part, safe="").replace(".", "%2E")
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
@@ -189,11 +253,17 @@ class Table:
class ParsedQuery:
def __init__(self, sql_statement: str, strip_comments: bool = False):
def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
engine: Optional[str] = None,
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
@@ -206,14 +276,94 @@ class ParsedQuery:
@property
def tables(self) -> set[Table]:
if not self._tables:
for statement in self._parsed:
self._extract_from_token(statement)
self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
self._tables = self._extract_tables_from_sql()
return self._tables
def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
statements = parse(self.sql, dialect=self._dialect)
except ParseError:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
return set()
return {
table
for statement in statements
for table in self._extract_tables_from_statement(statement)
if statement
}
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
"""
Extract all table references in a single statement.
Please not that this is not trivial; consider the following queries:
DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
See the unit tests for other tricky cases.
"""
sources: Iterable[exp.Table]
if isinstance(statement, exp.Describe):
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
# query for all tables.
sources = statement.find_all(exp.Table)
elif isinstance(statement, exp.Command):
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
# `SELECT` statetement in order to extract tables.
literal = statement.find(exp.Literal)
if not literal:
return set()
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
sources = pseudo_query.find_all(exp.Table)
else:
sources = [
source
for scope in traverse_scope(statement)
for source in scope.sources.values()
if isinstance(source, exp.Table) and not self._is_cte(source, scope)
]
return {
Table(
source.name,
source.db if source.db != "" else None,
source.catalog if source.catalog != "" else None,
)
for source in sources
}
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
"""
Is the source a CTE?
CTEs in the parent scope look like tables (and are represented by
exp.Table objects), but should not be considered as such;
otherwise a user with access to table `foo` could access any table
with a query like this:
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
"""
parent_sources = scope.parent.sources if scope.parent else {}
ctes_in_scope = {
name
for name, parent_scope in parent_sources.items()
if isinstance(parent_scope, Scope)
and parent_scope.scope_type == ScopeType.CTE
}
return source.name in ctes_in_scope
@property
def limit(self) -> Optional[int]:
return self._limit
@@ -393,28 +543,6 @@ class ParsedQuery:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
def _process_tokenlist(self, token_list: TokenList) -> None:
"""
Add table names to table set
:param token_list: TokenList to be processed
"""
# exclude subselects
if "(" not in str(token_list):
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
# store aliases
if token_list.has_alias():
self._alias_names.add(token_list.get_alias())
# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self._extract_from_token(token_list)
def as_create_table(
self,
table_name: str,
@@ -441,50 +569,6 @@ class ParsedQuery:
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql
def _extract_from_token(self, token: Token) -> None:
"""
<Identifier> store a list of subtokens and <IdentifierList> store lists of
subtoken list.
It extracts <IdentifierList> and <Identifier> from :param token: and loops
through all subtokens recursively. It finds table_name_preceding_token and
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
self._tables.
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
if not hasattr(token, "tokens"):
return
table_name_preceding_token = False
for item in token.tokens:
if item.is_group and (
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
):
self._extract_from_token(item)
if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
or item.normalized.endswith(" JOIN")
):
table_name_preceding_token = True
continue
if item.ttype in Keyword:
table_name_preceding_token = False
continue
if table_name_preceding_token:
if isinstance(item, Identifier):
self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
self._process_tokenlist(token2)
elif isinstance(item, IdentifierList):
if any(not self._is_identifier(token2) for token2 in item.tokens):
self._extract_from_token(item)
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
"""Returns the query with the specified limit.
@@ -881,7 +965,7 @@ def insert_rls_in_predicate(
# mapping between sqloxide and SQLAlchemy dialects
SQLOXITE_DIALECTS = {
SQLOXIDE_DIALECTS = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
@@ -914,7 +998,7 @@ def extract_table_references(
tree = None
if sqloxide_parse:
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
if sqla_dialect in sqla_dialects:
break
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)