mirror of
https://github.com/apache/superset.git
synced 2026-04-26 11:34:27 +00:00
[sql] Adding lightweight Table class (#9649)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
@@ -16,8 +16,10 @@
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import List, Optional, Set
|
||||
from urllib import parse
|
||||
|
||||
import sqlparse
|
||||
from dataclasses import dataclass
|
||||
from sqlparse.sql import (
|
||||
Function,
|
||||
Identifier,
|
||||
@@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
class Table:  # pylint: disable=too-few-public-methods
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    The dataclass is declared with ``eq=True`` and ``frozen=True``, so instances
    compare by field value, are immutable, and are hashable (which allows them
    to be stored in a ``set``).
    """

    # The (unqualified) table name; the only required component.
    table: str
    # The optional schema the table belongs to.
    schema: Optional[str] = None
    # The optional catalog the schema belongs to.
    catalog: Optional[str] = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Each non-empty component is percent-encoded (with no "safe" characters)
        and any literal period inside a component is additionally encoded as
        ``%2E``, so the only unencoded periods in the result are the separators
        between catalog, schema and table. Components that are ``None`` or empty
        are omitted.

        :returns: The encoded ``[[catalog.]schema.]table`` string
        """

        return ".".join(
            # `parse.quote` never encodes ".", hence the explicit replacement.
            parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )
|
||||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(self, sql_statement: str):
|
||||
self.sql: str = sql_statement
|
||||
self._table_names: Set[str] = set()
|
||||
self._tables: Set[Table] = set()
|
||||
self._alias_names: Set[str] = set()
|
||||
self._limit: Optional[int] = None
|
||||
|
||||
@@ -70,12 +94,15 @@ class ParsedQuery:
|
||||
self._limit = _extract_limit_from_query(statement)
|
||||
|
||||
@property
|
||||
def tables(self) -> Set[str]:
|
||||
if not self._table_names:
|
||||
def tables(self) -> Set[Table]:
|
||||
if not self._tables:
|
||||
for statement in self._parsed:
|
||||
self.__extract_from_token(statement)
|
||||
self._table_names = self._table_names - self._alias_names
|
||||
return self._table_names
|
||||
self._extract_from_token(statement)
|
||||
|
||||
self._tables = {
|
||||
table for table in self._tables if str(table) not in self._alias_names
|
||||
}
|
||||
return self._tables
|
||||
|
||||
@property
|
||||
def limit(self) -> Optional[int]:
|
||||
@@ -105,13 +132,13 @@ class ParsedQuery:
|
||||
return statements
|
||||
|
||||
@staticmethod
|
||||
def __get_full_name(tlist: TokenList) -> Optional[str]:
|
||||
def _get_table(tlist: TokenList) -> Optional[Table]:
|
||||
"""
|
||||
Return the full unquoted table name if valid, i.e., conforms to the following
|
||||
[[cluster.]schema.]table construct.
|
||||
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
||||
construct.
|
||||
|
||||
:param tlist: The SQL tokens
|
||||
:returns: The valid full table name
|
||||
:returns: The table if the name conforms
|
||||
"""
|
||||
|
||||
# Strip the alias if present.
|
||||
@@ -127,18 +154,18 @@ class ParsedQuery:
|
||||
|
||||
if (
|
||||
len(tokens) in (1, 3, 5)
|
||||
and all(imt(token, t=[Name, String]) for token in tokens[0::2])
|
||||
and all(imt(token, t=[Name, String]) for token in tokens[::2])
|
||||
and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
|
||||
):
|
||||
return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
|
||||
return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __is_identifier(token: Token) -> bool:
|
||||
def _is_identifier(token: Token) -> bool:
|
||||
return isinstance(token, (IdentifierList, Identifier))
|
||||
|
||||
def __process_tokenlist(self, token_list: TokenList):
|
||||
def _process_tokenlist(self, token_list: TokenList):
|
||||
"""
|
||||
Add table names to table set
|
||||
|
||||
@@ -146,9 +173,9 @@ class ParsedQuery:
|
||||
"""
|
||||
# exclude subselects
|
||||
if "(" not in str(token_list):
|
||||
table_name = self.__get_full_name(token_list)
|
||||
if table_name and not table_name.startswith(CTE_PREFIX):
|
||||
self._table_names.add(table_name)
|
||||
table = self._get_table(token_list)
|
||||
if table and not table.table.startswith(CTE_PREFIX):
|
||||
self._tables.add(table)
|
||||
return
|
||||
|
||||
# store aliases
|
||||
@@ -158,7 +185,7 @@ class ParsedQuery:
|
||||
# some aliases are not parsed properly
|
||||
if token_list.tokens[0].ttype == Name:
|
||||
self._alias_names.add(token_list.tokens[0].value)
|
||||
self.__extract_from_token(token_list)
|
||||
self._extract_from_token(token_list)
|
||||
|
||||
def as_create_table(
|
||||
self,
|
||||
@@ -184,9 +211,9 @@ class ParsedQuery:
|
||||
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
|
||||
return exec_sql
|
||||
|
||||
def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
|
||||
def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
|
||||
"""
|
||||
Populate self._table_names from token
|
||||
Populate self._tables from token
|
||||
|
||||
:param token: instance of Token or child class, e.g. TokenList, to be processed
|
||||
"""
|
||||
@@ -196,8 +223,8 @@ class ParsedQuery:
|
||||
table_name_preceding_token = False
|
||||
|
||||
for item in token.tokens:
|
||||
if item.is_group and not self.__is_identifier(item):
|
||||
self.__extract_from_token(item)
|
||||
if item.is_group and not self._is_identifier(item):
|
||||
self._extract_from_token(item)
|
||||
|
||||
if item.ttype in Keyword and (
|
||||
item.normalized in PRECEDES_TABLE_NAME
|
||||
@@ -212,15 +239,15 @@ class ParsedQuery:
|
||||
|
||||
if table_name_preceding_token:
|
||||
if isinstance(item, Identifier):
|
||||
self.__process_tokenlist(item)
|
||||
self._process_tokenlist(item)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token2 in item.get_identifiers():
|
||||
if isinstance(token2, TokenList):
|
||||
self.__process_tokenlist(token2)
|
||||
self._process_tokenlist(token2)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token2 in item.tokens:
|
||||
if not self.__is_identifier(token2):
|
||||
self.__extract_from_token(item)
|
||||
if not self._is_identifier(token2):
|
||||
self._extract_from_token(item)
|
||||
|
||||
def set_or_update_query_limit(self, new_limit: int) -> str:
|
||||
"""Returns the query with the specified limit.
|
||||
|
||||
Reference in New Issue
Block a user