diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 714196de3d7..b653c88a2a0 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -23,7 +23,10 @@ from sqlparse.tokens import Keyword, Name
 
 RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
 ON_KEYWORD = 'ON'
-PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
+PRECEDES_TABLE_NAME = {
+    'FROM', 'JOIN', 'DESCRIBE', 'WITH', 'LEFT JOIN', 'RIGHT JOIN',
+}
+CTE_PREFIX = 'CTE__'
 
 
 class ParsedQuery(object):
@@ -71,13 +74,6 @@ class ParsedQuery(object):
                     statements.append(sql)
         return statements
 
-    @staticmethod
-    def __precedes_table_name(token_value):
-        for keyword in PRECEDES_TABLE_NAME:
-            if keyword in token_value:
-                return True
-        return False
-
     @staticmethod
     def __get_full_name(identifier):
         if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
@@ -85,21 +81,16 @@ class ParsedQuery(object):
                                   identifier.tokens[2].value)
         return identifier.get_real_name()
 
-    @staticmethod
-    def __is_result_operation(keyword):
-        for operation in RESULT_OPERATIONS:
-            if operation in keyword.upper():
-                return True
-        return False
-
     @staticmethod
     def __is_identifier(token):
         return isinstance(token, (IdentifierList, Identifier))
 
     def __process_identifier(self, identifier):
         # exclude subselects
-        if '(' not in '{}'.format(identifier):
-            self._table_names.add(self.__get_full_name(identifier))
+        if '(' not in str(identifier):
+            table_name = self.__get_full_name(identifier)
+            if table_name and not table_name.startswith(CTE_PREFIX):
+                self._table_names.add(table_name)
             return
 
         # store aliases
@@ -129,39 +120,39 @@ class ParsedQuery(object):
         exec_sql += f'CREATE TABLE {table_name} AS \n{sql}'
         return exec_sql
 
-    def __extract_from_token(self, token):
+    def __extract_from_token(self, token, depth=0):
         if not hasattr(token, 'tokens'):
             return
 
         table_name_preceding_token = False
 
         for item in token.tokens:
+            logging.debug('%s%s%s', ' ' * depth, item.ttype, item.value)
             if item.is_group and not self.__is_identifier(item):
-                self.__extract_from_token(item)
+                self.__extract_from_token(item, depth=depth + 1)
 
-            if item.ttype in Keyword:
-                if self.__precedes_table_name(item.value.upper()):
-                    table_name_preceding_token = True
-                    continue
-
-            if not table_name_preceding_token:
+            if (
+                    item.ttype in Keyword and (
+                        item.normalized in PRECEDES_TABLE_NAME or
+                        item.normalized.endswith(' JOIN')
+                    )):
+                table_name_preceding_token = True
                 continue
 
-            if item.ttype in Keyword or item.value == ',':
-                if (self.__is_result_operation(item.value) or
-                        item.value.upper() == ON_KEYWORD):
-                    table_name_preceding_token = False
-                    continue
-                # FROM clause is over
-                break
+            if item.ttype in Keyword:
+                table_name_preceding_token = False
+                continue
 
-            if isinstance(item, Identifier):
-                self.__process_identifier(item)
-
-            if isinstance(item, IdentifierList):
-                for token in item.tokens:
-                    if self.__is_identifier(token):
+            if table_name_preceding_token:
+                if isinstance(item, Identifier):
+                    self.__process_identifier(item)
+                elif isinstance(item, IdentifierList):
+                    for token in item.get_identifiers():
                         self.__process_identifier(token)
+            elif isinstance(item, IdentifierList):
+                # recurse once, however many non-identifier tokens there are
+                if any(not self.__is_identifier(t) for t in item.tokens):
+                    self.__extract_from_token(item, depth=depth + 1)
 
     def _extract_limit_from_query(self, statement):
         idx, _ = statement.token_next_by(m=(Keyword, 'LIMIT'))
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index 9247780a274..e821fce2545 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -167,7 +167,6 @@ class SupersetTestCase(unittest.TestCase):
     # DESCRIBE | DESC qualifiedName
     def test_describe(self):
         self.assertEquals({'t1'}, self.extract_tables('DESCRIBE t1'))
-        self.assertEquals({'t1'}, self.extract_tables('DESC t1'))
 
     # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
     # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
@@ -349,6 +348,32 @@ class SupersetTestCase(unittest.TestCase):
             {'table_a', 'table_b', 'table_c'},
             self.extract_tables(query))
 
+    def test_mixed_from_clause(self):
+        query = """SELECT *
+            FROM table_a AS a, (select * from table_b) AS b, table_c as c
+            WHERE a.id = b.id and b.id = c.id"""
+        self.assertEquals(
+            {'table_a', 'table_b', 'table_c'},
+            self.extract_tables(query))
+
+    def test_nested_selects(self):
+        query = """
+            select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
+            from INFORMATION_SCHEMA.COLUMNS
+            WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
+        """
+        self.assertEquals(
+            {'INFORMATION_SCHEMA.COLUMNS'},
+            self.extract_tables(query))
+        query = """
+            select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
+            from INFORMATION_SCHEMA.COLUMNS
+            WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
+        """
+        self.assertEquals(
+            {'INFORMATION_SCHEMA.COLUMNS'},
+            self.extract_tables(query))
+
     def test_complex_extract_tables3(self):
         query = """SELECT somecol AS somecol
             FROM
@@ -386,6 +411,21 @@ class SupersetTestCase(unittest.TestCase):
             {'a', 'b', 'c', 'd', 'e', 'f'},
             self.extract_tables(query))
 
+    def test_complex_cte_with_prefix(self):
+        query = """
+        WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
+        AS (
+            SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
+            FROM SalesOrderHeader
+            WHERE SalesPersonID IS NOT NULL
+        )
+        SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
+        FROM CTE__test
+        GROUP BY SalesYear, SalesPersonID
+        ORDER BY SalesPersonID, SalesYear;
+        """
+        self.assertEquals({'SalesOrderHeader'}, self.extract_tables(query))
+
     def test_basic_breakdown_statements(self):
         multi_sql = """
         SELECT * FROM ab_user;