mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
[sql lab] improve table name detection in free form SQL (#6793)
* [sql lab] improve table name detection in free form SQL * flake * Addressing comments
This commit is contained in:
committed by
GitHub
parent
fc4042a28b
commit
5a40f71710
@@ -23,7 +23,10 @@ from sqlparse.tokens import Keyword, Name
|
||||
|
||||
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
|
||||
ON_KEYWORD = 'ON'
|
||||
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
|
||||
PRECEDES_TABLE_NAME = {
|
||||
'FROM', 'JOIN', 'DESCRIBE', 'WITH', 'LEFT JOIN', 'RIGHT JOIN',
|
||||
}
|
||||
CTE_PREFIX = 'CTE__'
|
||||
|
||||
|
||||
class ParsedQuery(object):
|
||||
@@ -71,13 +74,6 @@ class ParsedQuery(object):
|
||||
statements.append(sql)
|
||||
return statements
|
||||
|
||||
@staticmethod
|
||||
def __precedes_table_name(token_value):
|
||||
for keyword in PRECEDES_TABLE_NAME:
|
||||
if keyword in token_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __get_full_name(identifier):
|
||||
if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
|
||||
@@ -85,21 +81,16 @@ class ParsedQuery(object):
|
||||
identifier.tokens[2].value)
|
||||
return identifier.get_real_name()
|
||||
|
||||
@staticmethod
|
||||
def __is_result_operation(keyword):
|
||||
for operation in RESULT_OPERATIONS:
|
||||
if operation in keyword.upper():
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __is_identifier(token):
|
||||
return isinstance(token, (IdentifierList, Identifier))
|
||||
|
||||
def __process_identifier(self, identifier):
|
||||
# exclude subselects
|
||||
if '(' not in '{}'.format(identifier):
|
||||
self._table_names.add(self.__get_full_name(identifier))
|
||||
if '(' not in str(identifier):
|
||||
table_name = self.__get_full_name(identifier)
|
||||
if not table_name.startswith(CTE_PREFIX):
|
||||
self._table_names.add(self.__get_full_name(identifier))
|
||||
return
|
||||
|
||||
# store aliases
|
||||
@@ -129,39 +120,39 @@ class ParsedQuery(object):
|
||||
exec_sql += f'CREATE TABLE {table_name} AS \n{sql}'
|
||||
return exec_sql
|
||||
|
||||
def __extract_from_token(self, token):
|
||||
def __extract_from_token(self, token, depth=0):
|
||||
if not hasattr(token, 'tokens'):
|
||||
return
|
||||
|
||||
table_name_preceding_token = False
|
||||
|
||||
for item in token.tokens:
|
||||
logging.debug((' ' * depth) + str(item.ttype) + str(item.value))
|
||||
if item.is_group and not self.__is_identifier(item):
|
||||
self.__extract_from_token(item)
|
||||
self.__extract_from_token(item, depth=depth + 1)
|
||||
|
||||
if item.ttype in Keyword:
|
||||
if self.__precedes_table_name(item.value.upper()):
|
||||
table_name_preceding_token = True
|
||||
continue
|
||||
|
||||
if not table_name_preceding_token:
|
||||
if (
|
||||
item.ttype in Keyword and (
|
||||
item.normalized in PRECEDES_TABLE_NAME or
|
||||
item.normalized.endswith(' JOIN')
|
||||
)):
|
||||
table_name_preceding_token = True
|
||||
continue
|
||||
|
||||
if item.ttype in Keyword or item.value == ',':
|
||||
if (self.__is_result_operation(item.value) or
|
||||
item.value.upper() == ON_KEYWORD):
|
||||
table_name_preceding_token = False
|
||||
continue
|
||||
# FROM clause is over
|
||||
break
|
||||
if item.ttype in Keyword:
|
||||
table_name_preceding_token = False
|
||||
continue
|
||||
|
||||
if isinstance(item, Identifier):
|
||||
self.__process_identifier(item)
|
||||
|
||||
if isinstance(item, IdentifierList):
|
||||
for token in item.tokens:
|
||||
if self.__is_identifier(token):
|
||||
if table_name_preceding_token:
|
||||
if isinstance(item, Identifier):
|
||||
self.__process_identifier(item)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token in item.get_identifiers():
|
||||
self.__process_identifier(token)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token in item.tokens:
|
||||
if not self.__is_identifier(token):
|
||||
self.__extract_from_token(item, depth=depth + 1)
|
||||
|
||||
def _get_limit_from_token(self, token):
|
||||
if token.ttype == sqlparse.tokens.Literal.Number.Integer:
|
||||
|
||||
@@ -167,7 +167,6 @@ class SupersetTestCase(unittest.TestCase):
|
||||
# DESCRIBE | DESC qualifiedName
|
||||
def test_describe(self):
|
||||
self.assertEquals({'t1'}, self.extract_tables('DESCRIBE t1'))
|
||||
self.assertEquals({'t1'}, self.extract_tables('DESC t1'))
|
||||
|
||||
# SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
|
||||
# (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
|
||||
@@ -349,6 +348,32 @@ class SupersetTestCase(unittest.TestCase):
|
||||
{'table_a', 'table_b', 'table_c'},
|
||||
self.extract_tables(query))
|
||||
|
||||
def test_mixed_from_clause(self):
|
||||
query = """SELECT *
|
||||
FROM table_a AS a, (select * from table_b) AS b, table_c as c
|
||||
WHERE a.id = b.id and b.id = c.id"""
|
||||
self.assertEquals(
|
||||
{'table_a', 'table_b', 'table_c'},
|
||||
self.extract_tables(query))
|
||||
|
||||
def test_nested_selects(self):
|
||||
query = """
|
||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
||||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
||||
"""
|
||||
self.assertEquals(
|
||||
{'INFORMATION_SCHEMA.COLUMNS'},
|
||||
self.extract_tables(query))
|
||||
query = """
|
||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
||||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
|
||||
"""
|
||||
self.assertEquals(
|
||||
{'INFORMATION_SCHEMA.COLUMNS'},
|
||||
self.extract_tables(query))
|
||||
|
||||
def test_complex_extract_tables3(self):
|
||||
query = """SELECT somecol AS somecol
|
||||
FROM
|
||||
@@ -386,6 +411,21 @@ class SupersetTestCase(unittest.TestCase):
|
||||
{'a', 'b', 'c', 'd', 'e', 'f'},
|
||||
self.extract_tables(query))
|
||||
|
||||
def test_complex_cte_with_prefix(self):
|
||||
query = """
|
||||
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
|
||||
AS (
|
||||
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
|
||||
FROM SalesOrderHeader
|
||||
WHERE SalesPersonID IS NOT NULL
|
||||
)
|
||||
SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
|
||||
FROM CTE__test
|
||||
GROUP BY SalesYear, SalesPersonID
|
||||
ORDER BY SalesPersonID, SalesYear;
|
||||
"""
|
||||
self.assertEquals({'SalesOrderHeader'}, self.extract_tables(query))
|
||||
|
||||
def test_basic_breakdown_statements(self):
|
||||
multi_sql = """
|
||||
SELECT * FROM ab_user;
|
||||
|
||||
Reference in New Issue
Block a user