Files
superset2/superset/sql_parse.py
Will Barrett 562aeab1aa Fix a bunch of files with pylint disabled (#8743)
* Re-enable pylint for superset/jinja_context.py

* Re-enable pylint for superset/sql_lab.py

* Re-enable pylint for superset/sql_parse.py

* Re-enable pylint for superset/exceptions.py

* Re-enable lint for superset/translations/utils.py

* Re-enable pylint for superset/views/schedules.py

* Re-enable pylint for superset/views/base.py

* Re-enable pylint for superset/views/log/views.py

* Re-enable pylint for superset/views/annotations.py

* black

* PR feedback, pylint, isort fixes

* Black, one more time...

* Move ungrouped-imports to a global disable
2019-12-11 10:14:24 -08:00

234 lines
8.4 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional, Set
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
# SELECT plus the set operators. Not used within this module
# (NOTE(review): presumably consumed by callers elsewhere — confirm).
RESULT_OPERATIONS = {"SELECT", "UNION", "INTERSECT", "EXCEPT"}

# Not used within this module (NOTE(review): confirm external usage).
ON_KEYWORD = "ON"

# Keywords after which ParsedQuery.__extract_from_token treats the following
# identifier(s) as table names. That method also accepts any keyword whose
# normalized form ends in " JOIN", so the LEFT/RIGHT JOIN entries are
# redundant there but harmless.
PRECEDES_TABLE_NAME = {
    "DESCRIBE",
    "FROM",
    "JOIN",
    "LEFT JOIN",
    "RIGHT JOIN",
    "WITH",
}

# Table names starting with this prefix are not recorded as tables
# (see ParsedQuery.__process_tokenlist).
CTE_PREFIX = "CTE__"
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
    """
    Extract the row limit from the top-level LIMIT clause of a statement.

    Supports both ``LIMIT <limit>`` and ``LIMIT <offset>, <limit>``; in the
    latter form the value after the comma is returned.

    :param statement: SQL statement
    :return: Limit extracted from query, None if no limit present in statement
             (or if the value after LIMIT is not an integer literal)
    """
    limit_kw_idx, _ = statement.token_next_by(m=(Keyword, "LIMIT"))
    if limit_kw_idx is None:
        return None

    _, candidate = statement.token_next(idx=limit_kw_idx)
    if not candidate:
        return None

    if isinstance(candidate, IdentifierList):
        # "LIMIT <offset>, <limit>": the limit is the first non-whitespace
        # token following the comma.
        comma_idx, _ = candidate.token_next_by(m=(Punctuation, ","))
        _, candidate = candidate.token_next(idx=comma_idx)

    if candidate and candidate.ttype == sqlparse.tokens.Literal.Number.Integer:
        return int(candidate.value)
    return None
class ParsedQuery:
    """
    Parse a (possibly multi-statement) SQL string with ``sqlparse`` and expose
    the referenced tables, the LIMIT value and some statement-type checks.
    """

    def __init__(self, sql_statement):
        self.sql: str = sql_statement
        self._table_names: Set[str] = set()
        self._alias_names: Set[str] = set()
        self._limit: Optional[int] = None

        logging.info("Parsing with sqlparse statement %s", self.sql)
        self._parsed = sqlparse.parse(self.stripped())
        for statement in self._parsed:
            self.__extract_from_token(statement)
            # Overwritten for every statement, so the LIMIT of the *last*
            # statement is the one kept.
            self._limit = _extract_limit_from_query(statement)
        # Names that were recorded as aliases are not real tables.
        self._table_names = self._table_names - self._alias_names

    @property
    def tables(self) -> Set[str]:
        """Table names found in the query, minus names seen as aliases."""
        return self._table_names

    @property
    def limit(self) -> Optional[int]:
        """Integer LIMIT of the last parsed statement, or None if absent."""
        return self._limit

    def is_select(self) -> bool:
        """Return True if the first parsed statement is a SELECT."""
        # sqlparse yields no statements for empty SQL; report "not a select"
        # instead of raising IndexError.
        if not self._parsed:
            return False
        return self._parsed[0].get_type() == "SELECT"

    def is_explain(self) -> bool:
        """Return True if the stripped SQL starts with EXPLAIN (case-insensitive)."""
        return self.stripped().upper().startswith("EXPLAIN")

    def is_readonly(self) -> bool:
        """Pessimistic readonly, 100% sure statement won't mutate anything"""
        return self.is_select() or self.is_explain()

    def stripped(self) -> str:
        """Return the SQL without leading/trailing spaces, tabs, newlines and ';'."""
        return self.sql.strip(" \t\n;")

    def get_statements(self) -> List[str]:
        """Returns a list of SQL statements as strings, stripped"""
        statements = []
        for statement in self._parsed:
            if statement:
                sql = str(statement).strip(" \n;\t")
                if sql:
                    statements.append(sql)
        return statements

    @staticmethod
    def __get_full_name(tlist: TokenList) -> Optional[str]:
        """
        Return the full unquoted table name if valid, i.e., conforms to the following
        [[cluster.]schema.]table construct.

        :param tlist: The SQL tokens
        :returns: The valid full table name, or None if tlist is not of that shape
        """
        # Strip the alias if present: keep only the tokens before the first
        # whitespace token.
        idx = len(tlist.tokens)
        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)
            # token_next_by() reports "not found" as None (never -1).
            if ws_idx is not None:
                idx = ws_idx

        tokens = tlist.tokens[:idx]
        # Name/String parts at even positions, "." separators at odd positions.
        if (
            len(tokens) in (1, 3, 5)
            and all(imt(token, t=[Name, String]) for token in tokens[0::2])
            and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
        ):
            return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
        return None

    @staticmethod
    def __is_identifier(token: Token) -> bool:
        """Return True for sqlparse Identifier or IdentifierList groups."""
        return isinstance(token, (IdentifierList, Identifier))

    def __process_tokenlist(self, token_list: TokenList):
        """
        Add table names to table set

        If the token text contains no "(", it is treated as a (possibly
        qualified) table name. Otherwise (e.g. a parenthesized subquery), its
        alias is recorded and it is walked recursively.

        :param token_list: TokenList to be processed
        """
        # exclude subselects
        if "(" not in str(token_list):
            table_name = self.__get_full_name(token_list)
            if table_name and not table_name.startswith(CTE_PREFIX):
                self._table_names.add(table_name)
            return

        # store aliases
        if token_list.has_alias():
            self._alias_names.add(token_list.get_alias())
        # some aliases are not parsed properly
        if token_list.tokens[0].ttype == Name:
            self._alias_names.add(token_list.tokens[0].value)
        self.__extract_from_token(token_list)

    def as_create_table(self, table_name: str, overwrite: bool = False) -> str:
        """Reformats the query into the create table as query.

        Works only for the single select SQL statements, in all other cases
        the sql query is not modified.

        NOTE: ``table_name`` is interpolated verbatim (no quoting/escaping);
        callers must only pass trusted identifiers.

        :param table_name: Table that will contain the results of the query execution
        :param overwrite: table_name will be dropped if true
        :return: Create table as query
        """
        exec_sql = ""
        sql = self.stripped()
        if overwrite:
            exec_sql = f"DROP TABLE IF EXISTS {table_name};\n"
        exec_sql += f"CREATE TABLE {table_name} AS \n{sql}"
        return exec_sql

    def __extract_from_token(self, token: Token):  # pylint: disable=too-many-branches
        """
        Populate self._table_names from token

        :param token: instance of Token or child class, e.g. TokenList, to be processed
        """
        if not hasattr(token, "tokens"):
            return

        table_name_preceding_token = False

        for item in token.tokens:
            if item.is_group and not self.__is_identifier(item):
                self.__extract_from_token(item)

            if item.ttype in Keyword and (
                item.normalized in PRECEDES_TABLE_NAME
                or item.normalized.endswith(" JOIN")
            ):
                table_name_preceding_token = True
                continue

            if item.ttype in Keyword:
                table_name_preceding_token = False
                continue

            if table_name_preceding_token:
                if isinstance(item, Identifier):
                    self.__process_tokenlist(item)
                elif isinstance(item, IdentifierList):
                    for token2 in item.get_identifiers():
                        if isinstance(token2, TokenList):
                            self.__process_tokenlist(token2)
            elif isinstance(item, IdentifierList):
                # Walk the list once if it contains anything that is not an
                # identifier. The walk only adds to sets, so repeating it per
                # such child would only multiply the work at each nesting level.
                # NOTE(review): recursing into the child rather than the whole
                # list may have been the original intent — confirm.
                if any(not self.__is_identifier(child) for child in item.tokens):
                    self.__extract_from_token(item)

    def get_query_with_new_limit(self, new_limit: int) -> str:
        """
        returns the query with the specified limit.

        Does not change the underlying query: the parsed tokens are left
        untouched and a new string is built.

        Without an existing LIMIT, ``LIMIT <new_limit>`` is appended to the
        stripped query. Otherwise the limit value in the first statement's
        LIMIT clause is replaced (the offset of ``LIMIT <offset>, <limit>`` is
        kept) and only that first statement is returned.

        :param new_limit: Limit to be incorporated into returned query
        :return: The original query with new limit
        """
        append_limit = f"{self.stripped()}\nLIMIT {new_limit}"
        # "is None", not falsiness: "LIMIT 0" is a real limit to replace, not
        # one to follow with a second LIMIT clause.
        if self._limit is None:
            return append_limit

        statement = self._parsed[0]
        limit_pos = None
        for pos, item in enumerate(statement.tokens):
            if item.ttype in Keyword and item.value.lower() == "limit":
                limit_pos = pos
                break

        # token_next() returns (None, None) for limit_pos=None, which happens
        # when the LIMIT recorded in __init__ belongs to a later statement.
        _, limit = statement.token_next(idx=limit_pos)
        if limit is None:
            return append_limit

        replacement: Optional[str] = None
        if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
            replacement = str(new_limit)
        elif limit.is_group:
            # "LIMIT <offset>, <limit>": keep the offset, swap the limit.
            replacement = f"{next(limit.get_identifiers())}, {new_limit}"

        parts = []
        for tok in statement.tokens:
            if tok is limit and replacement is not None:
                parts.append(replacement)
            else:
                parts.append(str(tok.value))
        return "".join(parts)