mirror of
https://github.com/apache/superset.git
synced 2026-05-10 10:25:51 +00:00
feat(sqlparse): improve table parsing (#26476)
(cherry picked from commit c0b57bd1c3)
This commit is contained in:
committed by
Michael S. Molina
parent
6cdaf479f2
commit
1d9cfdabd1
@@ -14,15 +14,22 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional
|
||||
from urllib import parse
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import and_
|
||||
from sqlglot import exp, parse, parse_one
|
||||
from sqlglot.dialects import Dialects
|
||||
from sqlglot.errors import ParseError
|
||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||
from sqlparse import keywords
|
||||
from sqlparse.lexer import Lexer
|
||||
from sqlparse.sql import (
|
||||
@@ -53,7 +60,7 @@ from superset.utils.backports import StrEnum
|
||||
|
||||
try:
|
||||
from sqloxide import parse_sql as sqloxide_parse
|
||||
except: # pylint: disable=bare-except
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
sqloxide_parse = None
|
||||
|
||||
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
|
||||
@@ -72,6 +79,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
|
||||
lex.set_SQL_REGEX(sqlparser_sql_regex)
|
||||
|
||||
|
||||
# mapping between DB engine specs and sqlglot dialects
|
||||
SQLGLOT_DIALECTS = {
|
||||
"ascend": Dialects.HIVE,
|
||||
"awsathena": Dialects.PRESTO,
|
||||
"bigquery": Dialects.BIGQUERY,
|
||||
"clickhouse": Dialects.CLICKHOUSE,
|
||||
"clickhousedb": Dialects.CLICKHOUSE,
|
||||
"cockroachdb": Dialects.POSTGRES,
|
||||
# "crate": ???
|
||||
# "databend": ???
|
||||
"databricks": Dialects.DATABRICKS,
|
||||
# "db2": ???
|
||||
# "dremio": ???
|
||||
"drill": Dialects.DRILL,
|
||||
# "druid": ???
|
||||
"duckdb": Dialects.DUCKDB,
|
||||
# "dynamodb": ???
|
||||
# "elasticsearch": ???
|
||||
# "exa": ???
|
||||
# "firebird": ???
|
||||
# "firebolt": ???
|
||||
"gsheets": Dialects.SQLITE,
|
||||
"hana": Dialects.POSTGRES,
|
||||
"hive": Dialects.HIVE,
|
||||
# "ibmi": ???
|
||||
# "impala": ???
|
||||
# "kustokql": ???
|
||||
# "kylin": ???
|
||||
# "mssql": ???
|
||||
"mysql": Dialects.MYSQL,
|
||||
"netezza": Dialects.POSTGRES,
|
||||
# "ocient": ???
|
||||
# "odelasticsearch": ???
|
||||
"oracle": Dialects.ORACLE,
|
||||
# "pinot": ???
|
||||
"postgresql": Dialects.POSTGRES,
|
||||
"presto": Dialects.PRESTO,
|
||||
"pydoris": Dialects.DORIS,
|
||||
"redshift": Dialects.REDSHIFT,
|
||||
# "risingwave": ???
|
||||
# "rockset": ???
|
||||
"shillelagh": Dialects.SQLITE,
|
||||
"snowflake": Dialects.SNOWFLAKE,
|
||||
# "solr": ???
|
||||
"sqlite": Dialects.SQLITE,
|
||||
"starrocks": Dialects.STARROCKS,
|
||||
"superset": Dialects.SQLITE,
|
||||
"teradatasql": Dialects.TERADATA,
|
||||
"trino": Dialects.TRINO,
|
||||
"vertica": Dialects.POSTGRES,
|
||||
}
|
||||
|
||||
|
||||
class CtasMethod(StrEnum):
|
||||
TABLE = "TABLE"
|
||||
VIEW = "VIEW"
|
||||
@@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
||||
return cte, remainder
|
||||
|
||||
|
||||
def strip_comments_from_sql(statement: str) -> str:
|
||||
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
|
||||
"""
|
||||
Strips comments from a SQL statement, does a simple test first
|
||||
to avoid always instantiating the expensive ParsedQuery constructor
|
||||
@@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
|
||||
:param statement: A string with the SQL statement
|
||||
:return: SQL statement without comments
|
||||
"""
|
||||
return ParsedQuery(statement).strip_comments() if "--" in statement else statement
|
||||
return (
|
||||
ParsedQuery(statement, engine=engine).strip_comments()
|
||||
if "--" in statement
|
||||
else statement
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
@@ -179,7 +243,7 @@ class Table:
|
||||
"""
|
||||
|
||||
return ".".join(
|
||||
parse.quote(part, safe="").replace(".", "%2E")
|
||||
urllib.parse.quote(part, safe="").replace(".", "%2E")
|
||||
for part in [self.catalog, self.schema, self.table]
|
||||
if part
|
||||
)
|
||||
@@ -189,11 +253,17 @@ class Table:
|
||||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(self, sql_statement: str, strip_comments: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
sql_statement: str,
|
||||
strip_comments: bool = False,
|
||||
engine: Optional[str] = None,
|
||||
):
|
||||
if strip_comments:
|
||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||
|
||||
self.sql: str = sql_statement
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
self._tables: set[Table] = set()
|
||||
self._alias_names: set[str] = set()
|
||||
self._limit: Optional[int] = None
|
||||
@@ -206,14 +276,94 @@ class ParsedQuery:
|
||||
@property
|
||||
def tables(self) -> set[Table]:
|
||||
if not self._tables:
|
||||
for statement in self._parsed:
|
||||
self._extract_from_token(statement)
|
||||
|
||||
self._tables = {
|
||||
table for table in self._tables if str(table) not in self._alias_names
|
||||
}
|
||||
self._tables = self._extract_tables_from_sql()
|
||||
return self._tables
|
||||
|
||||
def _extract_tables_from_sql(self) -> set[Table]:
    """
    Extract all table references in a query.

    Note: this uses sqlglot, since it's better at catching more edge cases.

    :return: The tables referenced by every statement in ``self.sql``; an
        empty set when the SQL cannot be parsed with the configured dialect
    """
    try:
        statements = parse(self.sql, dialect=self._dialect)
    except ParseError:
        logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
        return set()

    return {
        table
        for statement in statements
        # The filter must come before the inner loop: empty statements are
        # falsy and must never reach `_extract_tables_from_statement`.
        if statement
        for table in self._extract_tables_from_statement(statement)
    }
|
||||
|
||||
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
    """
    Extract all table references in a single statement.

    Please note that this is not trivial; consider the following queries:

        DESCRIBE some_table;
        SHOW PARTITIONS FROM some_table;
        WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;

    See the unit tests for other tricky cases.

    :param statement: A single parsed sqlglot expression
    :return: The tables referenced by the statement, excluding CTE names; an
        empty set when a command carries no literal or its payload cannot be
        parsed
    """
    sources: Iterable[exp.Table]

    if isinstance(statement, exp.Describe):
        # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
        # query for all tables.
        sources = statement.find_all(exp.Table)
    elif isinstance(statement, exp.Command):
        # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
        # `SELECT` statement in order to extract tables.
        literal = statement.find(exp.Literal)
        if not literal:
            return set()

        pseudo_sql = f"SELECT {literal.this}"
        try:
            pseudo_query = parse_one(pseudo_sql, dialect=self._dialect)
        except ParseError:
            # The command payload is arbitrary text and may not form a valid
            # `SELECT`; mirror `_extract_tables_from_sql` and report no tables
            # instead of propagating the parser error.
            logger.warning("Unable to parse SQL (%s): %s", self._dialect, pseudo_sql)
            return set()
        sources = pseudo_query.find_all(exp.Table)
    else:
        sources = [
            source
            for scope in traverse_scope(statement)
            for source in scope.sources.values()
            if isinstance(source, exp.Table) and not self._is_cte(source, scope)
        ]

    return {
        # Missing schema/catalog parts come back as empty strings; normalize
        # them to `None`.
        Table(source.name, source.db or None, source.catalog or None)
        for source in sources
    }
|
||||
|
||||
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
    """
    Tell whether ``source`` refers to a CTE rather than a real table.

    A CTE defined in the parent scope is represented by an ``exp.Table``
    object, just like a real table, but it must not be reported as one;
    otherwise a user with access to table `foo` could access any table
    with a query like this:

        WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo

    :param source: The table-like source found in ``scope``
    :param scope: The scope in which ``source`` was found
    :return: ``True`` when the parent scope has a CTE named like ``source``
    """
    parent = scope.parent
    if not parent:
        # Without a parent scope there is nothing that could define a CTE.
        return False

    wanted = source.name
    return any(
        name == wanted
        for name, candidate in parent.sources.items()
        if isinstance(candidate, Scope) and candidate.scope_type == ScopeType.CTE
    )
|
||||
|
||||
@property
|
||||
def limit(self) -> Optional[int]:
|
||||
return self._limit
|
||||
@@ -393,28 +543,6 @@ class ParsedQuery:
|
||||
def _is_identifier(token: Token) -> bool:
|
||||
return isinstance(token, (IdentifierList, Identifier))
|
||||
|
||||
def _process_tokenlist(self, token_list: TokenList) -> None:
|
||||
"""
|
||||
Add table names to table set
|
||||
|
||||
:param token_list: TokenList to be processed
|
||||
"""
|
||||
# exclude subselects
|
||||
if "(" not in str(token_list):
|
||||
table = self.get_table(token_list)
|
||||
if table and not table.table.startswith(CTE_PREFIX):
|
||||
self._tables.add(table)
|
||||
return
|
||||
|
||||
# store aliases
|
||||
if token_list.has_alias():
|
||||
self._alias_names.add(token_list.get_alias())
|
||||
|
||||
# some aliases are not parsed properly
|
||||
if token_list.tokens[0].ttype == Name:
|
||||
self._alias_names.add(token_list.tokens[0].value)
|
||||
self._extract_from_token(token_list)
|
||||
|
||||
def as_create_table(
|
||||
self,
|
||||
table_name: str,
|
||||
@@ -441,50 +569,6 @@ class ParsedQuery:
|
||||
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
|
||||
return exec_sql
|
||||
|
||||
def _extract_from_token(self, token: Token) -> None:
|
||||
"""
|
||||
<Identifier> store a list of subtokens and <IdentifierList> store lists of
|
||||
subtoken list.
|
||||
|
||||
It extracts <IdentifierList> and <Identifier> from :param token: and loops
|
||||
through all subtokens recursively. It finds table_name_preceding_token and
|
||||
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
|
||||
self._tables.
|
||||
|
||||
:param token: instance of Token or child class, e.g. TokenList, to be processed
|
||||
"""
|
||||
if not hasattr(token, "tokens"):
|
||||
return
|
||||
|
||||
table_name_preceding_token = False
|
||||
|
||||
for item in token.tokens:
|
||||
if item.is_group and (
|
||||
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
|
||||
):
|
||||
self._extract_from_token(item)
|
||||
|
||||
if item.ttype in Keyword and (
|
||||
item.normalized in PRECEDES_TABLE_NAME
|
||||
or item.normalized.endswith(" JOIN")
|
||||
):
|
||||
table_name_preceding_token = True
|
||||
continue
|
||||
|
||||
if item.ttype in Keyword:
|
||||
table_name_preceding_token = False
|
||||
continue
|
||||
if table_name_preceding_token:
|
||||
if isinstance(item, Identifier):
|
||||
self._process_tokenlist(item)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token2 in item.get_identifiers():
|
||||
if isinstance(token2, TokenList):
|
||||
self._process_tokenlist(token2)
|
||||
elif isinstance(item, IdentifierList):
|
||||
if any(not self._is_identifier(token2) for token2 in item.tokens):
|
||||
self._extract_from_token(item)
|
||||
|
||||
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
|
||||
"""Returns the query with the specified limit.
|
||||
|
||||
@@ -881,7 +965,7 @@ def insert_rls_in_predicate(
|
||||
|
||||
|
||||
# mapping between sqloxide and SQLAlchemy dialects
|
||||
SQLOXITE_DIALECTS = {
|
||||
SQLOXIDE_DIALECTS = {
|
||||
"ansi": {"trino", "trinonative", "presto"},
|
||||
"hive": {"hive", "databricks"},
|
||||
"ms": {"mssql"},
|
||||
@@ -914,7 +998,7 @@ def extract_table_references(
|
||||
tree = None
|
||||
|
||||
if sqloxide_parse:
|
||||
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
|
||||
for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
|
||||
if sqla_dialect in sqla_dialects:
|
||||
break
|
||||
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
|
||||
|
||||
Reference in New Issue
Block a user