[sqllab] Fix sqllab limit regex issue with sqlparse (#5295)

* include items after limit to the modified query

* use sqlparse
This commit is contained in:
timifasubaa
2018-07-16 15:27:30 -07:00
committed by GitHub
parent c445ef8c43
commit f8a6e09220
3 changed files with 115 additions and 49 deletions

View File

@@ -20,18 +20,24 @@ class SupersetQuery(object):
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
self._limit = None
# TODO: multistatement support
logging.info('Parsing with sqlparse statement {}'.format(self.sql))
self._parsed = sqlparse.parse(self.sql)
for statement in self._parsed:
self.__extract_from_token(statement)
self._limit = self._extract_limit_from_query(statement)
self._table_names = self._table_names - self._alias_names
@property
def tables(self):
return self._table_names
@property
def limit(self):
return self._limit
def is_select(self):
return self._parsed[0].get_type() == 'SELECT'
@@ -128,3 +134,41 @@ class SupersetQuery(object):
for token in item.tokens:
if self.__is_identifier(token):
self.__process_identifier(token)
def _get_limit_from_token(self, token):
if token.ttype == sqlparse.tokens.Literal.Number.Integer:
return int(token.value)
elif token.is_group:
return int(token.get_token_at_offset(1).value)
def _extract_limit_from_query(self, statement):
limit_token = None
for pos, item in enumerate(statement.tokens):
if item.ttype in Keyword and item.value.lower() == 'limit':
limit_token = statement.tokens[pos + 2]
return self._get_limit_from_token(limit_token)
def get_query_with_new_limit(self, new_limit):
"""returns the query with the specified limit"""
"""does not change the underlying query"""
if not self._limit:
return self.sql + ' LIMIT ' + str(new_limit)
limit_pos = None
tokens = self._parsed[0].tokens
# Add all items to before_str until there is a limit
for pos, item in enumerate(tokens):
if item.ttype in Keyword and item.value.lower() == 'limit':
limit_pos = pos
break
limit = tokens[limit_pos + 2]
if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
tokens[limit_pos + 2].value = new_limit
elif limit.is_group:
tokens[limit_pos + 2].value = (
'{}, {}'.format(next(limit.get_identifiers()), new_limit)
)
str_res = ''
for i in tokens:
str_res += str(i.value)
return str_res