# -*- coding: utf-8 -*-
# pylint: disable=C,R,W
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging

import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword, Name

# Keywords that, while table names are being collected, suspend the
# collection without ending the scan of the current token list.
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
ON_KEYWORD = 'ON'
# Keywords after which the following identifiers are treated as table names.
# Matching is by substring (see ``__precedes_table_name``).
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}


class SupersetQuery(object):
    """Wrapper around a SQL string parsed with ``sqlparse``.

    On construction the statement is parsed once and two facts are
    extracted: the set of referenced table names (``tables``) and the value
    of a top-level ``LIMIT`` clause (``limit``).
    """

    def __init__(self, sql_statement):
        """Parse ``sql_statement`` and collect its tables and limit.

        :param sql_statement: string, the SQL text to analyse
        """
        self.sql = sql_statement
        self._table_names = set()
        self._alias_names = set()
        self._limit = None
        # TODO: multistatement support

        logging.info('Parsing with sqlparse statement {}'.format(self.sql))
        self._parsed = sqlparse.parse(self.sql)
        for statement in self._parsed:
            self.__extract_from_token(statement)
            # With several statements the limit of the last one wins.
            self._limit = self._extract_limit_from_query(statement)
        # Names recorded as aliases of subselects are not real tables.
        self._table_names = self._table_names - self._alias_names

    @property
    def tables(self):
        """Set of table names found in the query, aliases excluded."""
        return self._table_names

    @property
    def limit(self):
        """Integer value of the top-level LIMIT clause, or None."""
        return self._limit

    def is_select(self):
        """Return True if the first parsed statement is a SELECT.

        Returns False (instead of raising IndexError) when the SQL text
        contains no statement at all.
        """
        return bool(self._parsed) and self._parsed[0].get_type() == 'SELECT'

    def is_explain(self):
        """Return True if the SQL text starts with EXPLAIN."""
        return self.sql.strip().upper().startswith('EXPLAIN')

    def is_readonly(self):
        """Pessimistic readonly, 100% sure statement won't mutate anything"""
        return self.is_select() or self.is_explain()

    def stripped(self):
        """Return the SQL text without surrounding whitespace/semicolons."""
        return self.sql.strip(' \t\n;')

    @staticmethod
    def __precedes_table_name(token_value):
        """Return True if ``token_value`` contains a table-preceding keyword.

        Substring matching is used, so composite keywords such as
        ``LEFT JOIN`` match through ``JOIN``.
        """
        return any(keyword in token_value for keyword in PRECEDES_TABLE_NAME)

    @staticmethod
    def __get_full_name(identifier):
        """Return ``schema.table`` for dotted identifiers, else the real name."""
        if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
            return '{}.{}'.format(identifier.tokens[0].value,
                                  identifier.tokens[2].value)
        return identifier.get_real_name()

    @staticmethod
    def __is_result_operation(keyword):
        """Return True if ``keyword`` contains one of RESULT_OPERATIONS."""
        upper_keyword = keyword.upper()
        return any(operation in upper_keyword
                   for operation in RESULT_OPERATIONS)

    @staticmethod
    def __is_identifier(token):
        """Return True for ``Identifier`` and ``IdentifierList`` tokens."""
        return isinstance(token, (IdentifierList, Identifier))

    def __process_identifier(self, identifier):
        """Record ``identifier`` as a table, or descend into a subselect."""
        # exclude subselects
        if '(' not in '{}'.format(identifier):
            self._table_names.add(self.__get_full_name(identifier))
            return

        # store aliases
        if hasattr(identifier, 'get_alias'):
            self._alias_names.add(identifier.get_alias())
        if hasattr(identifier, 'tokens'):
            # some aliases are not parsed properly
            if identifier.tokens[0].ttype == Name:
                self._alias_names.add(identifier.tokens[0].value)
        self.__extract_from_token(identifier)

    def as_create_table(self, table_name, overwrite=False):
        """Reformats the query into the create table as query.

        The stripped SQL text is used verbatim as the body of the
        ``CREATE TABLE ... AS`` statement.

        :param table_name: string, will contain the results of the
            query execution
        :param overwrite: boolean, table table_name will be dropped if true
        :return: string, create table as query
        """
        template = 'CREATE TABLE {table_name} AS \n{sql}'
        if overwrite:
            template = 'DROP TABLE IF EXISTS {table_name};\n' + template
        return template.format(table_name=table_name, sql=self.stripped())

    def __extract_from_token(self, token):
        """Walk ``token`` recursively and collect table names and aliases."""
        if not hasattr(token, 'tokens'):
            return

        table_name_preceding_token = False

        for item in token.tokens:
            # Non-identifier groups (parentheses, WHERE, ...) may contain
            # subqueries, so they are always scanned.
            if item.is_group and not self.__is_identifier(item):
                self.__extract_from_token(item)

            if item.ttype in Keyword:
                if self.__precedes_table_name(item.value.upper()):
                    table_name_preceding_token = True
                    continue

            if not table_name_preceding_token:
                continue

            if item.ttype in Keyword or item.value == ',':
                if (self.__is_result_operation(item.value) or
                        item.value.upper() == ON_KEYWORD):
                    table_name_preceding_token = False
                    continue
                # FROM clause is over
                break

            if isinstance(item, Identifier):
                self.__process_identifier(item)

            if isinstance(item, IdentifierList):
                for sub_item in item.tokens:
                    if self.__is_identifier(sub_item):
                        self.__process_identifier(sub_item)

    @staticmethod
    def _find_limit_token(statement):
        """Locate the value that follows the first top-level LIMIT keyword.

        The value is looked up with ``token_next`` rather than by assuming
        it sits exactly two tokens after the keyword, so extra whitespace
        or a missing value cannot select the wrong token or raise
        IndexError.

        :param statement: parsed statement whose direct tokens are scanned
        :return: ``(index, token)`` of the limit value within
            ``statement.tokens``, or ``(None, None)`` when there is no
            LIMIT keyword or nothing follows it
        """
        for pos, item in enumerate(statement.tokens):
            if item.ttype in Keyword and item.value.lower() == 'limit':
                return statement.token_next(pos)
        return None, None

    def _get_limit_from_token(self, token):
        """Return the integer limit carried by ``token``, or None.

        :param token: the token following a LIMIT keyword; may be None
        :return: int, or None when no integer limit can be read
        """
        if token is None:
            return None
        if isinstance(token, IdentifierList):
            # "LIMIT <offset>, <count>": the limit is the last value. This
            # matches get_query_with_new_limit, which keeps the first value
            # and replaces the second.
            identifiers = list(token.get_identifiers())
            token = identifiers[-1] if identifiers else None
        elif token.is_group:
            # Other groups: keep the historical lookup of the token found at
            # character offset 1 of the group.
            token = token.get_token_at_offset(1)
        if (token is not None and
                token.ttype == sqlparse.tokens.Literal.Number.Integer):
            return int(token.value)
        return None

    def _extract_limit_from_query(self, statement):
        """Return the top-level LIMIT value of ``statement``, or None."""
        _, limit_token = self._find_limit_token(statement)
        return self._get_limit_from_token(limit_token)

    def get_query_with_new_limit(self, new_limit):
        """Return the query text with its limit set to ``new_limit``.

        Neither ``self.sql`` nor the parsed tokens are modified. When the
        query has no limit, a ``LIMIT`` clause is appended to the SQL text.

        :param new_limit: the limit value to place in the query
        :return: string, the SQL with the new limit
        """
        # ``is None`` rather than truthiness: "LIMIT 0" is an existing limit
        # that must be replaced, not a reason to append a second clause.
        if self._limit is None:
            return self.sql + ' LIMIT ' + str(new_limit)

        statement = self._parsed[0]
        limit_pos, limit = self._find_limit_token(statement)
        if limit is None:
            # The recorded limit came from a later statement. Multi-statement
            # input is unsupported (see TODO in __init__); fall back to
            # appending instead of failing on a missing position.
            return self.sql + ' LIMIT ' + str(new_limit)

        if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
            replacement = str(new_limit)
        elif limit.is_group:
            # Keep the first value (the offset) and swap in the new limit.
            replacement = '{}, {}'.format(
                next(limit.get_identifiers()), new_limit)
        else:
            replacement = limit.value

        return ''.join(
            replacement if pos == limit_pos else item.value
            for pos, item in enumerate(statement.tokens)
        )