[sql lab] improve table name detection in free form SQL (#6793)

* [sql lab] improve table name detection in free form SQL

* flake

* Addressing comments

(cherry picked from commit 5a40f71710)
This commit is contained in:
Maxime Beauchemin
2019-02-04 16:03:23 -08:00
committed by Grace Guo
parent 2357c4aabf
commit b64a452a6d
2 changed files with 70 additions and 39 deletions

View File

@@ -23,7 +23,10 @@ from sqlparse.tokens import Keyword, Name
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
ON_KEYWORD = 'ON'
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
PRECEDES_TABLE_NAME = {
'FROM', 'JOIN', 'DESCRIBE', 'WITH', 'LEFT JOIN', 'RIGHT JOIN',
}
CTE_PREFIX = 'CTE__'
class ParsedQuery(object):
@@ -71,13 +74,6 @@ class ParsedQuery(object):
statements.append(sql)
return statements
@staticmethod
def __precedes_table_name(token_value):
for keyword in PRECEDES_TABLE_NAME:
if keyword in token_value:
return True
return False
@staticmethod
def __get_full_name(identifier):
if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
@@ -85,21 +81,16 @@ class ParsedQuery(object):
identifier.tokens[2].value)
return identifier.get_real_name()
@staticmethod
def __is_result_operation(keyword):
for operation in RESULT_OPERATIONS:
if operation in keyword.upper():
return True
return False
@staticmethod
def __is_identifier(token):
return isinstance(token, (IdentifierList, Identifier))
def __process_identifier(self, identifier):
# exclude subselects
if '(' not in '{}'.format(identifier):
self._table_names.add(self.__get_full_name(identifier))
if '(' not in str(identifier):
table_name = self.__get_full_name(identifier)
if not table_name.startswith(CTE_PREFIX):
self._table_names.add(self.__get_full_name(identifier))
return
# store aliases
@@ -129,39 +120,39 @@ class ParsedQuery(object):
exec_sql += f'CREATE TABLE {table_name} AS \n{sql}'
return exec_sql
def __extract_from_token(self, token):
def __extract_from_token(self, token, depth=0):
if not hasattr(token, 'tokens'):
return
table_name_preceding_token = False
for item in token.tokens:
logging.debug((' ' * depth) + str(item.ttype) + str(item.value))
if item.is_group and not self.__is_identifier(item):
self.__extract_from_token(item)
self.__extract_from_token(item, depth=depth + 1)
if item.ttype in Keyword:
if self.__precedes_table_name(item.value.upper()):
table_name_preceding_token = True
continue
if not table_name_preceding_token:
if (
item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME or
item.normalized.endswith(' JOIN')
)):
table_name_preceding_token = True
continue
if item.ttype in Keyword or item.value == ',':
if (self.__is_result_operation(item.value) or
item.value.upper() == ON_KEYWORD):
table_name_preceding_token = False
continue
# FROM clause is over
break
if item.ttype in Keyword:
table_name_preceding_token = False
continue
if isinstance(item, Identifier):
self.__process_identifier(item)
if isinstance(item, IdentifierList):
for token in item.tokens:
if self.__is_identifier(token):
if table_name_preceding_token:
if isinstance(item, Identifier):
self.__process_identifier(item)
elif isinstance(item, IdentifierList):
for token in item.get_identifiers():
self.__process_identifier(token)
elif isinstance(item, IdentifierList):
for token in item.tokens:
if not self.__is_identifier(token):
self.__extract_from_token(item, depth=depth + 1)
def _extract_limit_from_query(self, statement):
idx, _ = statement.token_next_by(m=(Keyword, 'LIMIT'))