From 7a3ed6e1bc88c02c935b6d356fe40aaa3d97673c Mon Sep 17 00:00:00 2001 From: Bogdan Date: Wed, 17 Aug 2016 20:05:18 -0700 Subject: [PATCH] Async support for the queries in the SQL Lab. (#974) * Refactor the query runner to enable async mode. * Refactor the sql calling functions into the QueryRunner class. * Clean up the celery tests. --- caravel/config.py | 10 +- caravel/extract_table_names.py | 60 +++ .../versions/ad82a75afd82_add_query_model.py | 16 +- caravel/models.py | 30 +- caravel/sql_lab.py | 153 +++++++ caravel/sql_lab_utils.py | 125 ++++++ caravel/tasks.py | 219 ---------- caravel/utils.py | 9 +- caravel/views.py | 181 +++++++- setup.py | 3 + tests/caravel_test_config.py | 5 +- tests/celery_tests.py | 397 ++++++++++-------- tests/core_tests.py | 8 +- 13 files changed, 774 insertions(+), 442 deletions(-) create mode 100644 caravel/extract_table_names.py create mode 100644 caravel/sql_lab.py create mode 100644 caravel/sql_lab_utils.py delete mode 100644 caravel/tasks.py diff --git a/caravel/config.py b/caravel/config.py index 1930ffa0521..869dd4c35f9 100644 --- a/caravel/config.py +++ b/caravel/config.py @@ -178,6 +178,7 @@ BACKUP_COUNT = 30 # Set this API key to enable Mapbox visualizations MAPBOX_API_KEY = "" + # Maximum number of rows returned in the SQL editor SQL_MAX_ROW = 1000 @@ -192,10 +193,10 @@ WARNING_MSG = None """ # Example: class CeleryConfig(object): - BROKER_URL = 'sqla+sqlite:///celerydb.sqlite' - CELERY_IMPORTS = ('caravel.tasks', ) - CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite' - CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} + BROKER_URL = 'sqla+sqlite:///celerydb.sqlite' + CELERY_IMPORTS = ('caravel.tasks', ) + CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite' + CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} CELERY_CONFIG = CeleryConfig """ CELERY_CONFIG = None @@ -207,4 +208,3 @@ except ImportError: if not CACHE_DEFAULT_TIMEOUT: CACHE_DEFAULT_TIMEOUT = 
CACHE_CONFIG.get('CACHE_DEFAULT_TIMEOUT') - diff --git a/caravel/extract_table_names.py b/caravel/extract_table_names.py new file mode 100644 index 00000000000..4bc57074290 --- /dev/null +++ b/caravel/extract_table_names.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com +# +# This example is part of python-sqlparse and is released under +# the BSD License: http://www.opensource.org/licenses/bsd-license.php +# +# This example illustrates how to extract table names from nested +# SELECT statements. +# +# See: +# http://groups.google.com/group/sqlparse/browse_thread/thread/b0bd9a022e9d4895 + +import sqlparse +from sqlparse.sql import IdentifierList, Identifier +from sqlparse.tokens import Keyword, DML + + +def is_subselect(parsed): + if not parsed.is_group(): + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() == 'SELECT': + return True + return False + + +def extract_from_part(parsed): + from_seen = False + for item in parsed.tokens: + if from_seen: + if is_subselect(item): + for x in extract_from_part(item): + yield x + elif item.ttype is Keyword: + raise StopIteration + else: + yield item + elif item.ttype is Keyword and item.value.upper() == 'FROM': + from_seen = True + + +def extract_table_identifiers(token_stream): + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + yield identifier.get_name() + elif isinstance(item, Identifier): + yield item.get_name() + # It's a bug to check for Keyword here, but in the example + # above some tables names are identified as keywords... + elif item.ttype is Keyword: + yield item.value + + +# TODO(bkyryliuk): add logic to support joins and unions. 
+def extract_tables(sql): + stream = extract_from_part(sqlparse.parse(sql)[0]) + return list(extract_table_identifiers(stream)) diff --git a/caravel/migrations/versions/ad82a75afd82_add_query_model.py b/caravel/migrations/versions/ad82a75afd82_add_query_model.py index 4794f416de0..31e41afbada 100644 --- a/caravel/migrations/versions/ad82a75afd82_add_query_model.py +++ b/caravel/migrations/versions/ad82a75afd82_add_query_model.py @@ -13,17 +13,27 @@ down_revision = 'f162a1dea4c4' from alembic import op import sqlalchemy as sa + def upgrade(): op.create_table('query', sa.Column('id', sa.Integer(), nullable=False), sa.Column('database_id', sa.Integer(), nullable=False), - sa.Column('tmp_table_name', sa.String(length=64), nullable=True), + sa.Column('tmp_table_name', sa.String(length=256), nullable=True), + sa.Column('tab_name', sa.String(length=256),nullable=True), sa.Column('user_id', sa.Integer(), nullable=True), sa.Column('status', sa.String(length=16), nullable=True), - sa.Column('name', sa.String(length=64), nullable=True), - sa.Column('sql', sa.Text, nullable=True), + sa.Column('name', sa.String(length=256), nullable=True), + sa.Column('schema', sa.String(length=256), nullable=True), + sa.Column('sql', sa.Text(), nullable=True), + sa.Column('select_sql', sa.Text(), nullable=True), + sa.Column('executed_sql', sa.Text(), nullable=True), sa.Column('limit', sa.Integer(), nullable=True), + sa.Column('limit_used', sa.Boolean(), nullable=True), + sa.Column('select_as_cta', sa.Boolean(), nullable=True), + sa.Column('select_as_cta_used', sa.Boolean(), nullable=True), sa.Column('progress', sa.Integer(), nullable=True), + sa.Column('rows', sa.Integer(), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), sa.Column('start_time', sa.DateTime(), nullable=True), sa.Column('end_time', sa.DateTime(), nullable=True), sa.ForeignKeyConstraint(['database_id'], [u'dbs.id'], ), diff --git a/caravel/models.py b/caravel/models.py index ca57a1f6668..404386ecbc8 100644 
--- a/caravel/models.py +++ b/caravel/models.py @@ -379,7 +379,7 @@ class Database(Model, AuditMixinNullable): sqlalchemy_uri = Column(String(1024)) password = Column(EncryptedType(String(1024), config.get('SECRET_KEY'))) cache_timeout = Column(Integer) - select_as_create_table_as = Column(Boolean, default=True) + select_as_create_table_as = Column(Boolean, default=False) extra = Column(Text, default=textwrap.dedent("""\ { "metadata_params": {}, @@ -1734,6 +1734,16 @@ class FavStar(Model): class QueryStatus: + def from_presto_states(self, presto_status): + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + if presto_status.lower() == 'running': + return QueryStatus.IN_PROGRESS + SCHEDULED = 'SCHEDULED' CANCELLED = 'CANCELLED' IN_PROGRESS = 'IN_PROGRESS' @@ -1752,18 +1762,30 @@ class Query(Model): database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False) # Store the tmp table into the DB only if the user asks for it. - tmp_table_name = Column(String(64)) + tmp_table_name = Column(String(256)) user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True) # models.QueryStatus status = Column(String(16)) - name = Column(String(64)) + name = Column(String(256)) + tab_name = Column(String(256)) + schema = Column(String(256)) sql = Column(Text) - # Could be configured in the caravel config + # Query to retrieve the results, + # used only in case of select_as_cta_used is true. + select_sql = Column(Text) + executed_sql = Column(Text) + # Could be configured in the caravel config. limit = Column(Integer) + limit_used = Column(Boolean) + select_as_cta = Column(Boolean) + select_as_cta_used = Column(Boolean) # 1..100 progress = Column(Integer) + # # of rows in the result set or rows modified. 
+ rows = Column(Integer) + error_message = Column(Text) start_time = Column(DateTime) end_time = Column(DateTime) diff --git a/caravel/sql_lab.py b/caravel/sql_lab.py new file mode 100644 index 00000000000..b7b5c3152ab --- /dev/null +++ b/caravel/sql_lab.py @@ -0,0 +1,153 @@ +import celery +from caravel import models, app, utils, sql_lab_utils +from datetime import datetime + +celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG')) + + +@celery_app.task +def get_sql_results(query_id): + """Executes the sql query returns the results.""" + sql_manager = QueryRunner(query_id) + sql_manager.run_sql() + # Return the result for the sync call. + # if self.request.called_directly: + if sql_manager.query().status == models.QueryStatus.FINISHED: + return { + 'query_id': sql_manager.query().id, + 'status': sql_manager.query().status, + 'data': sql_manager.data(), + 'columns': sql_manager.columns(), + } + else: + return { + 'query_id': sql_manager.query().id, + 'status': sql_manager.query().status, + 'error': sql_manager.query().error_message, + } + + +class QueryRunner: + def __init__(self, query_id): + self._query_id = query_id + # Creates a separate session, reusing the db.session leads to the + # concurrency issues. + self._session = sql_lab_utils.create_scoped_session() + self._query = self._session.query(models.Query).filter_by( + id=query_id).first() + self._db_to_query = self._session.query(models.Database).filter_by( + id=self._query.database_id).first() + # Query result. 
+ self._data = None + self._columns = None + + def _sanity_check(self): + if not self._query: + self._query.error_message = "Query with id {0} not found.".format( + self._query_id) + if not self._db_to_query: + self._query.error_message = ( + "Database with id {0} is missing.".format( + self._query.database_id) + ) + + if self._query.error_message: + self._query.status = models.QueryStatus.FAILED + self._session.flush() + return False + return True + + def query(self): + return self._query + + def data(self): + return self._data + + def columns(self): + return self._columns + + def run_sql(self): + if not self._sanity_check(): + return self._query.status + + # TODO(bkyryliuk): dump results somewhere for the webserver. + engine = self._db_to_query.get_sqla_engine(schema=self._query.schema) + self._query.executed_sql = self._query.sql.strip().strip(';') + + # Limit enforced only for retrieving the data, not for the CTA queries. + self._query.select_as_cta_used = False + self._query.limit_used = False + if sql_lab_utils.is_query_select(self._query.sql): + if self._query.select_as_cta: + if not self._query.tmp_table_name: + self._query.tmp_table_name = 'tmp_{}_table_{}'.format( + self._query.user_id, + self._query.start_time.strftime('%Y_%m_%d_%H_%M_%S')) + self._query.executed_sql = sql_lab_utils.create_table_as( + self._query.executed_sql, self._query.tmp_table_name) + self._query.select_as_cta_used = True + elif self._query.limit: + self._query.executed_sql = sql_lab_utils.add_limit_to_the_sql( + self._query.executed_sql, self._query.limit, engine) + self._query.limit_used = True + + # TODO(bkyryliuk): ensure that tmp table was created. + # Do not set tmp table name if table wasn't created. 
+ if not self._query.select_as_cta_used: + self._query.tmp_table_name = None + self._get_sql_results(engine) + + self._query.end_time = datetime.now() + self._session.flush() + return self._query.status + + def _get_sql_results(self, engine): + try: + result_proxy = engine.execute( + self._query.executed_sql, schema=self._query.schema) + except Exception as e: + self._query.error_message = utils.error_msg_from_exception(e) + self._query.status = models.QueryStatus.FAILED + return + + cursor = result_proxy.cursor + if hasattr(cursor, "poll"): + query_stats = cursor.poll() + self._query.status = models.QueryStatus.IN_PROGRESS + self._session.flush() + # poll returns dict -- JSON status information or ``None`` + # if the query is done + # https://github.com/dropbox/PyHive/blob/ + # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178 + while query_stats: + # Update the object and wait for the kill signal. + self._session.refresh(self._query) + completed_splits = int(query_stats['stats']['completedSplits']) + total_splits = int(query_stats['stats']['totalSplits']) + progress = 100 * completed_splits / total_splits + if progress > self._query.progress: + self._query.progress = progress + + self._session.flush() + query_stats = cursor.poll() + # TODO(b.kyryliuk): check for the kill signal. + + sql_results = sql_lab_utils.fetch_response_from_cursor( + result_proxy) + self._columns = sql_results['columns'] + self._data = sql_results['data'] + self._query.rows = result_proxy.rowcount + self._query.progress = 100 + self._query.status = models.QueryStatus.FINISHED + if self._query.rows == -1 and self._data: + # Presto doesn't provide result_proxy.row_count + self._query.rows = len(self._data) + + # CTAs queries result in 1 cell having the # of the added rows. 
+ if self._query.select_as_cta_used: + self._query.select_sql = sql_lab_utils.select_star( + engine, self._query.tmp_table_name, self._query.limit) + else: + self._query.tmp_table = None + + diff --git a/caravel/sql_lab_utils.py b/caravel/sql_lab_utils.py new file mode 100644 index 00000000000..085a953667a --- /dev/null +++ b/caravel/sql_lab_utils.py @@ -0,0 +1,125 @@ +# SQL Lab Utils +import pandas as pd + +import sqlparse +from caravel import models, app +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy import select, text +from sqlalchemy.sql.expression import TextAsFrom + + +def create_scoped_session(): + """Creates new SQLAlchemy scoped_session.""" + engine = create_engine( + app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True) + return scoped_session(sessionmaker( + autocommit=True, autoflush=False, bind=engine)) + + +def fetch_response_from_cursor(result_proxy): + columns = None + data = None + if result_proxy.cursor: + cols = [col[0] for col in result_proxy.cursor.description] + data = result_proxy.fetchall() + df = pd.DataFrame(data, columns=cols) + df = df.fillna(0) + columns = [c for c in df.columns] + data = df.to_dict(orient='records') + return { + 'columns': columns, + 'data': data, + } + + +def is_query_select(sql): + try: + return sqlparse.parse(sql)[0].get_type() == 'SELECT' + # Capture sqlparse exceptions, worker shouldn't fail here. + except Exception: + # TODO(bkyryliuk): add logging here. 
+ return False + + +# if sqlparse provides the stream of tokens but don't provide the API +# to access the table names, more on it: +# https://groups.google.com/forum/#!topic/sqlparse/sL2aAi6dSJU +# https://github.com/andialbrecht/sqlparse/blob/master/examples/ +# extract_table_names.py +# +# Another approach would be to run the EXPLAIN on the sql statement: +# https://prestodb.io/docs/current/sql/explain.html +# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Explain +def get_tables(): + """Retrieves the query names from the query.""" + # TODO(bkyryliuk): implement parsing the sql statement. + pass + + +def select_star(engine, table_name, limit): + if limit: + select_star_sql = select('*').select_from(table_name).limit(limit) + else: + select_star_sql = select('*').select_from(table_name) + + # SQL code to preview the results + return '{}'.format(select_star_sql.compile( + engine, compile_kwargs={"literal_binds": True})) + + +def add_limit_to_the_sql(sql, limit, eng): + # Treat as single sql statement in case of failure. + try: + sql_statements = [s for s in sqlparse.split(sql) if s] + except Exception as e: + app.logger.info( + "Statement " + sql + "failed to be transformed to have the limit " + + + "with the exception" + e.message) + return sql + if len(sql_statements) == 1 and is_query_select(sql): + qry = select('*').select_from( + TextAsFrom(text(sql_statements[0]), ['*']).alias( + 'inner_qry')).limit(limit) + sql_statement = str(qry.compile( + eng, compile_kwargs={"literal_binds": True})) + return sql_statement + return sql + + +# create table works only for the single statement. +# TODO(bkyryliuk): enforce that all the columns have names. Presto requires it +# for the CTA operation. +def create_table_as(sql, table_name, override=False): + """Reformats the query into the create table as query. + + Works only for the single select SQL statements, in all other cases + the sql query is not modified. 
+ :param sql: string, sql query that will be executed + :param table_name: string, will contain the results of the query execution + :param override, boolean, table table_name will be dropped if true + :return: string, create table as query + """ + # TODO(bkyryliuk): drop table if allowed, check the namespace and + # the permissions. + # Treat as single sql statement in case of failure. + try: + # Filter out empty statements. + sql_statements = [s for s in sqlparse.split(sql) if s] + except Exception as e: + app.logger.info( + "Statement " + sql + "failed to be transformed as create table as " + "with the exception" + e.message) + return sql + if len(sql_statements) == 1 and is_query_select(sql): + updated_sql = '' + # TODO(bkyryliuk): use sqlalchemy statements for the + # the drop and create operations. + if override: + updated_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name) + updated_sql += "CREATE TABLE %s AS %s" % ( + table_name, sql_statements[0]) + return updated_sql + return sql diff --git a/caravel/tasks.py b/caravel/tasks.py deleted file mode 100644 index c48e6699745..00000000000 --- a/caravel/tasks.py +++ /dev/null @@ -1,219 +0,0 @@ -import celery -from caravel import models, app, utils -from datetime import datetime -import logging -from sqlalchemy import create_engine, select, text -from sqlalchemy.orm import scoped_session, sessionmaker -from sqlalchemy.sql.expression import TextAsFrom -import sqlparse -import pandas as pd - -celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG')) - - -def is_query_select(sql): - try: - return sqlparse.parse(sql)[0].get_type() == 'SELECT' - # Capture sqlparse exceptions, worker shouldn't fail here. - except Exception: - # TODO(bkyryliuk): add logging here. 
- return False - - -# if sqlparse provides the stream of tokens but don't provide the API -# to access the table names, more on it: -# https://groups.google.com/forum/#!topic/sqlparse/sL2aAi6dSJU -# https://github.com/andialbrecht/sqlparse/blob/master/examples/ -# extract_table_names.py -# -# Another approach would be to run the EXPLAIN on the sql statement: -# https://prestodb.io/docs/current/sql/explain.html -# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Explain -def get_tables(): - """Retrieves the query names from the query.""" - # TODO(bkyryliuk): implement parsing the sql statement. - pass - - -def add_limit_to_the_query(sql, limit, eng): - # Treat as single sql statement in case of failure. - sql_statements = [sql] - try: - sql_statements = [s for s in sqlparse.split(sql) if s] - except Exception as e: - logging.info( - "Statement " + sql + "failed to be transformed to have the limit " - "with the exception" + e.message) - return sql - if len(sql_statements) == 1 and is_query_select(sql): - qry = select('*').select_from( - TextAsFrom(text(sql_statements[0]), ['*']).alias( - 'inner_qry')).limit(limit) - sql_statement = str(qry.compile( - eng, compile_kwargs={"literal_binds": True})) - return sql_statement - return sql - - -# create table works only for the single statement. -def create_table_as(sql, table_name, override=False): - """Reformats the query into the create table as query. - - Works only for the single select SQL statements, in all other cases - the sql query is not modified. - :param sql: string, sql query that will be executed - :param table_name: string, will contain the results of the query execution - :param override, boolean, table table_name will be dropped if true - :return: string, create table as query - """ - # TODO(bkyryliuk): drop table if allowed, check the namespace and - # the permissions. - # Treat as single sql statement in case of failure. - sql_statements = [sql] - try: - # Filter out empty statements. 
- sql_statements = [s for s in sqlparse.split(sql) if s] - except Exception as e: - logging.info( - "Statement " + sql + "failed to be transformed as create table as " - "with the exception" + e.message) - return sql - if len(sql_statements) == 1 and is_query_select(sql): - updated_sql = '' - # TODO(bkyryliuk): use sqlalchemy statements for the - # the drop and create operations. - if override: - updated_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name) - updated_sql += "CREATE TABLE %s AS %s" % ( - table_name, sql_statements[0]) - return updated_sql - return sql - - -def get_session(): - """Creates new SQLAlchemy scoped_session.""" - engine = create_engine( - app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True) - return scoped_session(sessionmaker( - autocommit=False, autoflush=False, bind=engine)) - - -@celery_app.task -def get_sql_results(database_id, sql, user_id, tmp_table_name="", schema=None): - """Executes the sql query returns the results. - - :param database_id: integer - :param sql: string, query that will be executed - :param user_id: integer - :param tmp_table_name: name of the table for CTA - :param schema: string, name of the schema (used in presto) - :return: dataframe, query result - """ - # Create a separate session, reusing the db.session leads to the - # concurrency issues. - session = get_session() - try: - db_to_query = ( - session.query(models.Database).filter_by(id=database_id).first() - ) - except Exception as e: - return { - 'error': utils.error_msg_from_exception(e), - 'success': False, - } - if not db_to_query: - return { - 'error': "Database with id {0} is missing.".format(database_id), - 'success': False, - } - - # TODO(bkyryliuk): provide a way for the user to name the query. 
- # TODO(bkyryliuk): run explain query to derive the tables and fill in the - # table_ids - # TODO(bkyryliuk): check the user permissions - # TODO(bkyryliuk): store the tab name in the query model - limit = app.config.get('SQL_MAX_ROW', None) - start_time = datetime.now() - if not tmp_table_name: - tmp_table_name = 'tmp.{}_table_{}'.format(user_id, start_time) - query = models.Query( - user_id=user_id, - database_id=database_id, - limit=limit, - name='{}'.format(start_time), - sql=sql, - start_time=start_time, - tmp_table_name=tmp_table_name, - status=models.QueryStatus.IN_PROGRESS, - ) - session.add(query) - session.commit() - query_result = get_sql_results_as_dict( - db_to_query, sql, query.tmp_table_name, schema=schema) - query.end_time = datetime.now() - if query_result['success']: - query.status = models.QueryStatus.FINISHED - else: - query.status = models.QueryStatus.FAILED - session.commit() - # TODO(bkyryliuk): return the tmp table / query_id - return query_result - - -# TODO(bkyryliuk): merge the changes made in the carapal first -# before merging this PR. -def get_sql_results_as_dict(db_to_query, sql, tmp_table_name, schema=None): - """Get the SQL query results from the give session and db connection. - - :param sql: string, query that will be executed - :param db_to_query: models.Database to query, cannot be None - :param tmp_table_name: name of the table for CTA - :param schema: string, name of the schema (used in presto) - :return: (dataframe, boolean), results and the status - """ - eng = db_to_query.get_sqla_engine(schema=schema) - sql = sql.strip().strip(';') - # TODO(bkyryliuk): fix this case for multiple statements - if app.config.get('SQL_MAX_ROW'): - sql = add_limit_to_the_query( - sql, app.config.get("SQL_MAX_ROW"), eng) - - cta_used = False - if (app.config.get('SQL_SELECT_AS_CTA') and - db_to_query.select_as_create_table_as and is_query_select(sql)): - # TODO(bkyryliuk): figure out if the query is select query. 
- sql = create_table_as(sql, tmp_table_name) - cta_used = True - - if cta_used: - try: - eng.execute(sql) - return { - 'tmp_table': tmp_table_name, - 'success': True, - } - except Exception as e: - return { - 'error': utils.error_msg_from_exception(e), - 'success': False, - } - - # otherwise run regular SQL query. - # TODO(bkyryliuk): rewrite into eng.execute as queries different from - # select should be permitted too. - try: - df = db_to_query.get_df(sql, schema) - df = df.fillna(0) - return { - 'columns': [c for c in df.columns], - 'data': df.to_dict(orient='records'), - 'success': True, - } - - except Exception as e: - return { - 'error': utils.error_msg_from_exception(e), - 'success': False, - } - - diff --git a/caravel/utils.py b/caravel/utils.py index 9b784517c57..668c80f4936 100644 --- a/caravel/utils.py +++ b/caravel/utils.py @@ -339,10 +339,17 @@ def error_msg_from_exception(e): Database have different ways to handle exception. This function attempts to make sense of the exception object and construct a human readable sentence. + + TODO(bkyryliuk): parse the Presto error message from the connection + created via create_engine. + engine = create_engine('presto://localhost:3506/silver') - + gives an e.message as the str(dict) + presto.connect("localhost", port=3506, catalog='silver') - as a dict. + The latter version is parsed correctly by this function. 
""" msg = '' if hasattr(e, 'message'): - if (type(e.message) is dict): + if type(e.message) is dict: msg = e.message.get('message') elif e.message: msg = "{}".format(e.message) diff --git a/caravel/views.py b/caravel/views.py index 5bd50c554c9..2fdc418946e 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -33,7 +33,7 @@ from wtforms.validators import ValidationError import caravel from caravel import ( - appbuilder, db, models, viz, utils, app, sm, ascii_art, tasks + appbuilder, db, models, viz, utils, app, sm, ascii_art, sql_lab ) config = app.config @@ -459,11 +459,13 @@ class DatabaseAsync(DatabaseView): appbuilder.add_view_no_menu(DatabaseAsync) + class DatabaseTablesAsync(DatabaseView): list_columns = ['id', 'all_table_names', 'all_schema_names'] appbuilder.add_view_no_menu(DatabaseTablesAsync) + class TableModelView(CaravelModelView, DeleteMixin): # noqa datamodel = SQLAInterface(models.SqlaTable) list_columns = [ @@ -622,7 +624,8 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa url = "/druiddatasourcemodelview/list/" msg = _( "Click on a datasource link to create a Slice, " - "or click on a table link here " + "or click on a table link " + "here " "to create a Slice for a table" ) else: @@ -904,7 +907,8 @@ class Caravel(BaseCaravelView): datasource_access = self.can_access( 'datasource_access', datasource.perm) if not (all_datasource_access or datasource_access): - flash(__("You don't seem to have access to this datasource"), "danger") + flash(__("You don't seem to have access to this datasource"), + "danger") return redirect(error_redirect) action = request.args.get('action') @@ -981,7 +985,8 @@ class Caravel(BaseCaravelView): del d['action'] del d['previous_viz_type'] - as_list = ('metrics', 'groupby', 'columns', 'all_columns', 'mapbox_label', 'order_by_cols') + as_list = ('metrics', 'groupby', 'columns', 'all_columns', + 'mapbox_label', 'order_by_cols') for k in d: v = d.get(k) if k in as_list and not isinstance(v, list): @@ -1092,7 
+1097,8 @@ class Caravel(BaseCaravelView): .group_by(Log.dt) .all() ) - payload = {str(time.mktime(dt.timetuple())): ccount for dt, ccount in qry if dt} + payload = {str(time.mktime(dt.timetuple())): + ccount for dt, ccount in qry if dt} return Response(json.dumps(payload), mimetype="application/json") @api @@ -1148,9 +1154,11 @@ class Caravel(BaseCaravelView): data = json.loads(request.form.get('data')) session = db.session() Slice = models.Slice # noqa - dash = session.query(models.Dashboard).filter_by(id=dashboard_id).first() + dash = ( + session.query(models.Dashboard).filter_by(id=dashboard_id).first()) check_ownership(dash, raise_if_false=True) - new_slices = session.query(Slice).filter(Slice.id.in_(data['slice_ids'])) + new_slices = session.query(Slice).filter( + Slice.id.in_(data['slice_ids'])) dash.slices += new_slices session.merge(dash) session.commit() @@ -1184,13 +1192,18 @@ class Caravel(BaseCaravelView): FavStar = models.FavStar # noqa count = 0 favs = session.query(FavStar).filter_by( - class_name=class_name, obj_id=obj_id, user_id=g.user.get_id()).all() + class_name=class_name, obj_id=obj_id, + user_id=g.user.get_id()).all() if action == 'select': if not favs: session.add( FavStar( - class_name=class_name, obj_id=obj_id, user_id=g.user.get_id(), - dttm=datetime.now())) + class_name=class_name, + obj_id=obj_id, + user_id=g.user.get_id(), + dttm=datetime.now() + ) + ) count = 1 elif action == 'unselect': for fav in favs: @@ -1396,9 +1409,24 @@ class Caravel(BaseCaravelView): sql = request.form.get('sql') database_id = request.form.get('database_id') schema = request.form.get('schema') + tab_name = request.form.get('tab_name') + + async = request.form.get('async') == 'True' + tmp_table_name = request.form.get('tmp_table_name', None) + select_as_cta = request.form.get('select_as_cta') == 'True' + session = db.session() mydb = session.query(models.Database).filter_by(id=database_id).first() + if not mydb: + return Response( + json.dumps({ + 'error': 
'Database with id 0 is missing.', + 'status': models.QueryStatus.FAILED, + }), + status=500, + mimetype="application/json") + if not (self.can_access( 'all_datasource_access', 'all_datasource_access') or self.can_access('database_access', mydb.perm)): @@ -1406,19 +1434,132 @@ class Caravel(BaseCaravelView): "SQL Lab requires the `all_datasource_access` or " "specific DB permission")) - data = tasks.get_sql_results(database_id, sql, g.user.get_id(), - schema=schema) - if 'error' in data: + start_time = datetime.now() + query_name = '{}_{}_{}'.format( + g.user.get_id(), tab_name, start_time.strftime('%M:%S:%f')) + + query = models.Query( + database_id=database_id, + limit=app.config.get('SQL_MAX_ROW', None), + name=query_name, + sql=sql, + schema=schema, + # TODO(bkyryliuk): consider it being DB property. + select_as_cta=select_as_cta, + start_time=start_time, + status=models.QueryStatus.SCHEDULED, + tab_name=tab_name, + tmp_table_name=tmp_table_name, + user_id=g.user.get_id(), + ) + session.add(query) + session.commit() + + # Async request. + if async: + # Ignore the celery future object and the request may time out. + sql_lab.get_sql_results.delay(query.id) + return Response(json.dumps( + { + 'query_id': query.id, + 'status': query.status, + }, + default=utils.json_int_dttm_ser, allow_nan=False), + status=202, # Accepted + mimetype="application/json") + + # Sync request. + data = sql_lab.get_sql_results(query.id) + if data['status'] == models.QueryStatus.FAILED: return Response( - json.dumps(data), + json.dumps( + data, default=utils.json_int_dttm_ser, allow_nan=False), status=500, mimetype="application/json") - if 'tmp_table' in data: - # TODO(bkyryliuk): add query id to the response and implement the - # endpoint to poll the status and results. 
- return None - return json.dumps( - data, default=utils.json_int_dttm_ser, allow_nan=False) + return Response( + json.dumps( + data, default=utils.json_int_dttm_ser, allow_nan=False), + status=200, + mimetype="application/json") + + @has_access + @expose("/queries/", methods=['GET']) + @log_this + def queries(self): + """Runs arbitrary sql and returns and json""" + last_updated = request.form.get('timestamp') + s = db.session() + query = s.query(models.Query).filter_by(id=query_id).first() + mydb = s.query(models.Database).filter_by(id=query.database_id).first() + + if not (self.can_access( + 'all_datasource_access', 'all_datasource_access') or + self.can_access('database_access', mydb.perm)): + raise utils.CaravelSecurityException(_( + "SQL Lab requires the `all_datasource_access` or " + "specific DB permission")) + + if query: + return Response( + json.dumps({ + 'status': query.status, + 'progress': query.progress + }), + status=200, + mimetype="application/json") + + return Response( + json.dumps({ + 'error': "Query with id {} wasn't found".format(query_id), + }), + status=404, + mimetype="application/json") + + @has_access + @expose("/cta_query_results/", methods=['GET']) + @log_this + def cta_query_results(self): + """Runs arbitrary sql and returns and json""" + query_id = request.form.get('query_id') + s = db.session() + query = s.query(models.Query).filter_by(id=query_id).first() + mydb = s.query(models.Database).filter_by(id=query.database_id).first() + + if not (self.can_access( + 'all_datasource_access', 'all_datasource_access') or + self.can_access('database_access', mydb.perm)): + raise utils.CaravelSecurityException(_( + "SQL Lab requires the `all_datasource_access` or " + "specific DB permission")) + + if not query: + return Response( + json.dumps({ + 'error': "Query with id {} wasn't found".format(query_id), + }), + status=404, + mimetype="application/json") + + if query.status != models.QueryStatus.FINISHED: + return Response( + json.dumps({ + 
'error': "Query with id {} not finished yet".format( + query_id), + }), + status=400, + mimetype="application/json") + try: + data = mydb.get_df(query.select_sql, query.schema) + return Response( + json.dumps( + data, default=utils.json_int_dttm_ser, allow_nan=False), + status=200, + mimetype="application/json") + except Exception as e: + return Response( + json.dumps('error', utils.error_msg_from_exception(e)), + status=500, + mimetype="application/json") @has_access @expose("/refresh_datasources/") diff --git a/setup.py b/setup.py index b90bc599afd..fd3f6a987eb 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ setup( 'pandas==0.18.1', 'parsedatetime==2.0.0', 'pydruid==0.3.0', + 'PyHive>=0.2.1', 'python-dateutil==2.5.3', 'requests==2.10.0', 'simplejson==3.8.2', @@ -37,6 +38,8 @@ setup( 'sqlalchemy==1.0.13', 'sqlalchemy-utils==0.32.7', 'sqlparse==0.1.19', + 'thrift>=0.9.3', + 'thrift-sasl>=0.2.1', 'werkzeug==0.11.10', ], extras_require={ diff --git a/tests/caravel_test_config.py b/tests/caravel_test_config.py index a28240a87f7..1c24ee640bb 100644 --- a/tests/caravel_test_config.py +++ b/tests/caravel_test_config.py @@ -13,12 +13,13 @@ if 'CARAVEL__SQLALCHEMY_DATABASE_URI' in os.environ: SQL_CELERY_DB_FILE_PATH = os.path.join(DATA_DIR, 'celerydb.sqlite') SQL_CELERY_RESULTS_DB_FILE_PATH = os.path.join(DATA_DIR, 'celery_results.sqlite') SQL_SELECT_AS_CTA = True +SQL_MAX_ROW = 666 class CeleryConfig(object): BROKER_URL = 'sqla+sqlite:///' + SQL_CELERY_DB_FILE_PATH - CELERY_IMPORTS = ('caravel.tasks', ) + CELERY_IMPORTS = ('caravel.sql_lab', ) CELERY_RESULT_BACKEND = 'db+sqlite:///' + SQL_CELERY_RESULTS_DB_FILE_PATH - CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} + CELERY_ANNOTATIONS = {'sql_lab.add': {'rate_limit': '10/s'}} CONCURRENCY = 1 CELERY_CONFIG = CeleryConfig diff --git a/tests/celery_tests.py b/tests/celery_tests.py index e88ae0fca1c..48a2c004cfe 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -1,36 +1,47 @@ """Unit tests 
for Caravel Celery worker""" -import datetime import imp +import json import subprocess -import os -import pandas as pd + import time + +import os + +import pandas as pd import unittest import caravel -from caravel import app, appbuilder, db, models, tasks, utils +from caravel import app, appbuilder, db, models, sql_lab, sql_lab_utils, utils -class CeleryConfig(object): - BROKER_URL = 'sqla+sqlite:////tmp/celerydb.sqlite' - CELERY_IMPORTS = ('caravel.tasks',) - CELERY_RESULT_BACKEND = 'db+sqlite:////tmp/celery_results.sqlite' - CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}} -app.config['CELERY_CONFIG'] = CeleryConfig - BASE_DIR = app.config.get('BASE_DIR') cli = imp.load_source('cli', BASE_DIR + '/bin/caravel') +SQL_CELERY_DB_FILE_PATH = '/tmp/celerydb.sqlite' +SQL_CELERY_RESULTS_DB_FILE_PATH = '/tmp/celery_results.sqlite' + + +class CeleryConfig(object): + BROKER_URL = 'sqla+sqlite:///' + SQL_CELERY_DB_FILE_PATH + CELERY_IMPORTS = ('caravel.sql_lab', ) + CELERY_RESULT_BACKEND = 'db+sqlite:///' + SQL_CELERY_RESULTS_DB_FILE_PATH + CELERY_ANNOTATIONS = {'sql_lab.add': {'rate_limit': '10/s'}} + CONCURRENCY = 1 +app.config['CELERY_CONFIG'] = CeleryConfig + +# TODO(bkyryliuk): add ability to run this test separately. 
+ class UtilityFunctionTests(unittest.TestCase): def test_create_table_as(self): select_query = "SELECT * FROM outer_space;" - updated_select_query = tasks.create_table_as(select_query, "tmp") + updated_select_query = sql_lab_utils.create_table_as( + select_query, "tmp") self.assertEqual( "CREATE TABLE tmp AS SELECT * FROM outer_space;", updated_select_query) - updated_select_query_with_drop = tasks.create_table_as( + updated_select_query_with_drop = sql_lab_utils.create_table_as( select_query, "tmp", override=True) self.assertEqual( "DROP TABLE IF EXISTS tmp;\n" @@ -38,24 +49,26 @@ class UtilityFunctionTests(unittest.TestCase): updated_select_query_with_drop) select_query_no_semicolon = "SELECT * FROM outer_space" - updated_select_query_no_semicolon = tasks.create_table_as( + updated_select_query_no_semicolon = sql_lab_utils.create_table_as( select_query_no_semicolon, "tmp") self.assertEqual( "CREATE TABLE tmp AS SELECT * FROM outer_space", updated_select_query_no_semicolon) incorrect_query = "SMTH WRONG SELECT * FROM outer_space" - updated_incorrect_query = tasks.create_table_as(incorrect_query, "tmp") + updated_incorrect_query = sql_lab_utils.create_table_as( + incorrect_query, "tmp") self.assertEqual(incorrect_query, updated_incorrect_query) insert_query = "INSERT INTO stomach VALUES (beer, chips);" - updated_insert_query = tasks.create_table_as(insert_query, "tmp") + updated_insert_query = sql_lab_utils.create_table_as( + insert_query, "tmp") self.assertEqual(insert_query, updated_insert_query) multi_line_query = ( "SELECT * FROM planets WHERE\n" "Luke_Father = 'Darth Vader';") - updated_multi_line_query = tasks.create_table_as( + updated_multi_line_query = sql_lab_utils.create_table_as( multi_line_query, "tmp") expected_updated_multi_line_query = ( "CREATE TABLE tmp AS SELECT * FROM planets WHERE\n" @@ -64,7 +77,7 @@ class UtilityFunctionTests(unittest.TestCase): expected_updated_multi_line_query, updated_multi_line_query) - updated_multi_line_query_with_drop 
= tasks.create_table_as( + updated_multi_line_query_with_drop = sql_lab_utils.create_table_as( multi_line_query, "tmp", override=True) expected_updated_multi_line_query_with_drop = ( "DROP TABLE IF EXISTS tmp;\n" @@ -75,12 +88,12 @@ class UtilityFunctionTests(unittest.TestCase): updated_multi_line_query_with_drop) delete_query = "DELETE FROM planet WHERE name = 'Earth'" - updated_delete_query = tasks.create_table_as(delete_query, "tmp") + updated_delete_query = sql_lab_utils.create_table_as(delete_query, "tmp") self.assertEqual(delete_query, updated_delete_query) create_table_as = ( "CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n") - updated_create_table_as = tasks.create_table_as( + updated_create_table_as = sql_lab_utils.create_table_as( create_table_as, "tmp") self.assertEqual(create_table_as, updated_create_table_as) @@ -97,7 +110,7 @@ class UtilityFunctionTests(unittest.TestCase): "(B.TECH ,BE ,Degree ,MCA ,MiBA)\n " "AND Having Brothers= Null AND Sisters =Null" ) - updated_sql_procedure = tasks.create_table_as(sql_procedure, "tmp") + updated_sql_procedure = sql_lab_utils.create_table_as(sql_procedure, "tmp") self.assertEqual(sql_procedure, updated_sql_procedure) multiple_statements = """ @@ -107,7 +120,7 @@ class UtilityFunctionTests(unittest.TestCase): SELECT standard_disclaimer, witty_remark FROM company_requirements; select count(*) from developer_brain; """ - updated_multiple_statements = tasks.create_table_as( + updated_multiple_statements = sql_lab_utils.create_table_as( multiple_statements, "tmp") self.assertEqual(multiple_statements, updated_multiple_statements) @@ -116,36 +129,49 @@ class CeleryTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(CeleryTestCase, self).__init__(*args, **kwargs) self.client = app.test_client() + + def get_query_by_name(self, sql): + session = db.create_scoped_session() + query = session.query(models.Query).filter_by(sql=sql).first() + session.close() + return query + + def 
get_query_by_id(self, id): + session = db.create_scoped_session() + query = session.query(models.Query).filter_by(id=id).first() + session.close() + return query + + @classmethod + def setUpClass(cls): + try: + os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH')) + except OSError as e: + app.logger.warn(str(e)) + try: + os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH')) + except OSError as e: + app.logger.warn(str(e)) + utils.init(caravel) + + worker_command = BASE_DIR + '/bin/caravel worker' + subprocess.Popen( + worker_command, shell=True, stdout=subprocess.PIPE) + admin = appbuilder.sm.find_user('admin') if not admin: appbuilder.sm.add_user( 'admin', 'admin', ' user', 'admin@fab.org', appbuilder.sm.find_role('Admin'), password='general') - utils.init(caravel) - - @classmethod - def setUpClass(cls): - try: - os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH')) - except OSError: - pass - try: - os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH')) - except OSError: - pass - - worker_command = BASE_DIR + '/bin/caravel worker' - subprocess.Popen( - worker_command, shell=True, stdout=subprocess.PIPE) cli.load_examples(load_test_data=True) + @classmethod def tearDownClass(cls): subprocess.call( - "ps auxww | grep 'celeryd' | awk '{print $2}' | " - "xargs kill -9", + "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9", shell=True ) subprocess.call( @@ -160,14 +186,40 @@ class CeleryTestCase(unittest.TestCase): def tearDown(self): pass + def login(self, username='admin', password='general'): + resp = self.client.post( + '/login/', + data=dict(username=username, password=password), + follow_redirects=True) + assert 'Welcome' in resp.data.decode('utf-8') + + def logout(self): + self.client.get('/logout/', follow_redirects=True) + + def run_sql(self, dbid, sql, cta='False', tmp_table='tmp', + async='False'): + self.login() + resp = self.client.post( + '/caravel/sql_json/', + data=dict( + database_id=dbid, + sql=sql, + async=async, + 
select_as_cta=cta, + tmp_table_name=tmp_table, + ), + ) + self.logout() + return json.loads(resp.data.decode('utf-8')) + def test_add_limit_to_the_query(self): - query_session = tasks.get_session() + query_session = sql_lab_utils.create_scoped_session() db_to_query = query_session.query(models.Database).filter_by( id=1).first() eng = db_to_query.get_sqla_engine() select_query = "SELECT * FROM outer_space;" - updated_select_query = tasks.add_limit_to_the_query( + updated_select_query = sql_lab_utils.add_limit_to_the_sql( select_query, 100, eng) # Different DB engines have their own spacing while compiling # the queries, that's why ' '.join(query.split()) is used. @@ -178,7 +230,7 @@ class CeleryTestCase(unittest.TestCase): ) select_query_no_semicolon = "SELECT * FROM outer_space" - updated_select_query_no_semicolon = tasks.add_limit_to_the_query( + updated_select_query_no_semicolon = sql_lab_utils.add_limit_to_the_sql( select_query_no_semicolon, 100, eng) self.assertTrue( "SELECT * FROM (SELECT * FROM outer_space) AS inner_qry " @@ -187,19 +239,19 @@ class CeleryTestCase(unittest.TestCase): ) incorrect_query = "SMTH WRONG SELECT * FROM outer_space" - updated_incorrect_query = tasks.add_limit_to_the_query( + updated_incorrect_query = sql_lab_utils.add_limit_to_the_sql( incorrect_query, 100, eng) self.assertEqual(incorrect_query, updated_incorrect_query) insert_query = "INSERT INTO stomach VALUES (beer, chips);" - updated_insert_query = tasks.add_limit_to_the_query( + updated_insert_query = sql_lab_utils.add_limit_to_the_sql( insert_query, 100, eng) self.assertEqual(insert_query, updated_insert_query) multi_line_query = ( "SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';" ) - updated_multi_line_query = tasks.add_limit_to_the_query( + updated_multi_line_query = sql_lab_utils.add_limit_to_the_sql( multi_line_query, 100, eng) self.assertTrue( "SELECT * FROM (SELECT * FROM planets WHERE " @@ -208,13 +260,13 @@ class CeleryTestCase(unittest.TestCase): ) 
delete_query = "DELETE FROM planet WHERE name = 'Earth'" - updated_delete_query = tasks.add_limit_to_the_query( + updated_delete_query = sql_lab_utils.add_limit_to_the_sql( delete_query, 100, eng) self.assertEqual(delete_query, updated_delete_query) create_table_as = ( "CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n") - updated_create_table_as = tasks.add_limit_to_the_query( + updated_create_table_as = sql_lab_utils.add_limit_to_the_sql( create_table_as, 100, eng) self.assertEqual(create_table_as, updated_create_table_as) @@ -231,168 +283,145 @@ class CeleryTestCase(unittest.TestCase): "(B.TECH ,BE ,Degree ,MCA ,MiBA)\n " "AND Having Brothers= Null AND Sisters = Null" ) - updated_sql_procedure = tasks.add_limit_to_the_query( + updated_sql_procedure = sql_lab_utils.add_limit_to_the_sql( sql_procedure, 100, eng) self.assertEqual(sql_procedure, updated_sql_procedure) - def test_run_async_query_delay_get(self): + def test_run_sync_query(self): main_db = db.session.query(models.Database).filter_by( database_name="main").first() eng = main_db.get_sqla_engine() # Case 1. # DB #0 doesn't exist. - result1 = tasks.get_sql_results.delay( - 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_1').get() - expected_result1 = { - 'error': 'Database with id 0 is missing.', - 'success': False - } - self.assertEqual( - sorted(expected_result1.items()), - sorted(result1.items()) - ) - session1 = db.create_scoped_session() - query1 = session1.query(models.Query).filter_by( - sql='SELECT * FROM dontexist').first() - session1.close() - self.assertIsNone(query1) + sql_dont_exist = 'SELECT * FROM dontexist' + result1 = self.run_sql(0, sql_dont_exist, cta='True') + self.assertEqual(models.QueryStatus.FAILED, result1[u'status']) + self.assertFalse(u'query_id' in result1) + self.assertEqual('Database with id 0 is missing.', result1['error']) + self.assertIsNone(self.get_query_by_name(sql_dont_exist)) # Case 2. 
- session2 = db.create_scoped_session() - query2 = session2.query(models.Query).filter_by( - sql='SELECT * FROM dontexist1').first() - self.assertEqual(models.QueryStatus.FAILED, query2.status) - session2.close() - - result2 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_1').get() + # Table doesn't exist. + result2 = self.run_sql(1, sql_dont_exist, cta='True', ) self.assertTrue('error' in result2) - session2 = db.create_scoped_session() - query2 = session2.query(models.Query).filter_by( - sql='SELECT * FROM dontexist1').first() + self.assertEqual(models.QueryStatus.FAILED, result1[u'status']) + query2 = self.get_query_by_id(result2[u'query_id']) self.assertEqual(models.QueryStatus.FAILED, query2.status) - session2.close() # Case 3. - where_query = ( - "SELECT name FROM ab_permission WHERE name='can_select_star'") - result3 = tasks.get_sql_results.delay( - 1, where_query, 1, tmp_table_name='tmp_3_1').get() - expected_result3 = { - 'tmp_table': 'tmp_3_1', - 'success': True - } - self.assertEqual( - sorted(expected_result3.items()), - sorted(result3.items()) - ) - session3 = db.create_scoped_session() - query3 = session3.query(models.Query).filter_by( - sql=where_query).first() - session3.close() - df3 = pd.read_sql_query(sql="SELECT * FROM tmp_3_1", con=eng) + # Table and DB exists, CTA call to the backend. + sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'" + result3 = self.run_sql( + 1, sql_where, tmp_table='tmp_table_3', cta='True') + self.assertEqual(models.QueryStatus.FINISHED, result3[u'status']) + self.assertIsNone(result3[u'data']) + self.assertIsNone(result3[u'columns']) + query3 = self.get_query_by_id(result3[u'query_id']) + + # Check the data in the tmp table. 
+ df3 = pd.read_sql_query(sql=query3.select_sql, con=eng) data3 = df3.to_dict(orient='records') - self.assertEqual(models.QueryStatus.FINISHED, query3.status) - self.assertEqual([{'name': 'can_select_star'}], data3) + self.assertEqual([{'name': 'can_sql'}], data3) # Case 4. - result4 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM ab_permission WHERE id=666', 1, - tmp_table_name='tmp_4_1').get() - expected_result4 = { - 'tmp_table': 'tmp_4_1', - 'success': True - } - self.assertEqual( - sorted(expected_result4.items()), - sorted(result4.items()) - ) - session4 = db.create_scoped_session() - query4 = session4.query(models.Query).filter_by( - sql='SELECT * FROM ab_permission WHERE id=666').first() - session4.close() - df4 = pd.read_sql_query(sql="SELECT * FROM tmp_4_1", con=eng) - data4 = df4.to_dict(orient='records') + # Table and DB exists, CTA call to the backend, no data. + sql_empty_result = 'SELECT * FROM ab_user WHERE id=666' + result4 = self.run_sql( + 1, sql_empty_result, tmp_table='tmp_table_4', cta='True',) + self.assertEqual(models.QueryStatus.FINISHED, result4[u'status']) + self.assertIsNone(result4[u'data']) + self.assertIsNone(result4[u'columns']) + + query4 = self.get_query_by_id(result4[u'query_id']) self.assertEqual(models.QueryStatus.FINISHED, query4.status) + self.assertTrue("SELECT * \nFROM tmp_table_4" in query4.select_sql) + self.assertTrue("LIMIT 666" in query4.select_sql) + self.assertEqual( + "CREATE TABLE tmp_table_4 AS SELECT * FROM ab_user WHERE id=666", + query4.executed_sql) + self.assertEqual("SELECT * FROM ab_user WHERE id=666", query4.sql) + if eng.name != 'sqlite': + self.assertEqual(0, query4.rows) + self.assertEqual(666, query4.limit) + self.assertEqual(False, query4.limit_used) + self.assertEqual(True, query4.select_as_cta) + self.assertEqual(True, query4.select_as_cta_used) + + # Check the data in the tmp table. 
+ df4 = pd.read_sql_query(sql=query4.select_sql, con=eng) + data4 = df4.to_dict(orient='records') self.assertEqual([], data4) # Case 5. - # Return the data directly if DB select_as_create_table_as is False. - main_db.select_as_create_table_as = False - db.session.commit() - result5 = tasks.get_sql_results.delay( - 1, where_query, 1, tmp_table_name='tmp_5_1').get() - expected_result5 = { - 'columns': ['name'], - 'data': [{'name': 'can_select_star'}], - 'success': True - } - self.assertEqual( - sorted(expected_result5.items()), - sorted(result5.items()) - ) + # Table and DB exists, select without CTA. + result5 = self.run_sql(1, sql_where, tmp_table='tmp_table_5') + self.assertEqual(models.QueryStatus.FINISHED, result5[u'status']) + self.assertEqual([u'name'], result5[u'columns']) + self.assertEqual([{u'name': u'can_sql'}], result5[u'data']) - def test_run_async_query_delay(self): - celery_task1 = tasks.get_sql_results.delay( - 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_2') - celery_task2 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_2') - where_query = ( - "SELECT name FROM ab_permission WHERE name='can_select_star'") - celery_task3 = tasks.get_sql_results.delay( - 1, where_query, 1, tmp_table_name='tmp_3_2') - celery_task4 = tasks.get_sql_results.delay( - 1, 'SELECT * FROM ab_permission WHERE id=666', 1, - tmp_table_name='tmp_4_2') + query5 = self.get_query_by_id(result5[u'query_id']) + self.assertEqual(sql_where, query5.sql) + if eng.name != 'sqlite': + self.assertEqual(1, query5.rows) + self.assertEqual(666, query5.limit) + self.assertEqual(True, query5.limit_used) + self.assertEqual(False, query5.select_as_cta) + self.assertEqual(False, query5.select_as_cta_used) - time.sleep(1) + def test_run_async_query(self): + main_db = db.session.query(models.Database).filter_by( + database_name="main").first() + eng = main_db.get_sqla_engine() - # DB #0 doesn't exist. 
- expected_result1 = { - 'error': 'Database with id 0 is missing.', - 'success': False - } - self.assertEqual( - sorted(expected_result1.items()), - sorted(celery_task1.get().items()) - ) - session2 = db.create_scoped_session() - query2 = session2.query(models.Query).filter_by( - sql='SELECT * FROM dontexist1').first() - self.assertEqual(models.QueryStatus.FAILED, query2.status) - self.assertTrue('error' in celery_task2.get()) - expected_result3 = { - 'tmp_table': 'tmp_3_2', - 'success': True - } - self.assertEqual( - sorted(expected_result3.items()), - sorted(celery_task3.get().items()) - ) - expected_result4 = { - 'tmp_table': 'tmp_4_2', - 'success': True - } - self.assertEqual( - sorted(expected_result4.items()), - sorted(celery_task4.get().items()) - ) + # Schedule queries - session = db.create_scoped_session() - query1 = session.query(models.Query).filter_by( - sql='SELECT * FROM dontexist').first() - self.assertIsNone(query1) - query2 = session.query(models.Query).filter_by( - sql='SELECT * FROM dontexist1').first() - self.assertEqual(models.QueryStatus.FAILED, query2.status) - query3 = session.query(models.Query).filter_by( - sql=where_query).first() - self.assertEqual(models.QueryStatus.FINISHED, query3.status) - query4 = session.query(models.Query).filter_by( - sql='SELECT * FROM ab_permission WHERE id=666').first() - self.assertEqual(models.QueryStatus.FINISHED, query4.status) - session.close() + # Case 1. + # Table and DB exists, async CTA call to the backend. + sql_where = "SELECT name FROM ab_role WHERE name='Admin'" + result1 = self.run_sql( + 1, sql_where, async='True', tmp_table='tmp_async_1', cta='True') + self.assertEqual(models.QueryStatus.SCHEDULED, result1[u'status']) + + # Case 2. + # Table and DB exists, async insert query, no CTAs. 
+ insert_query = "INSERT INTO ab_role VALUES (9, 'fake_role')" + result2 = self.run_sql(1, insert_query, async='True') + self.assertEqual(models.QueryStatus.SCHEDULED, result2[u'status']) + + time.sleep(2) + + # Case 1. + query1 = self.get_query_by_id(result1[u'query_id']) + df1 = pd.read_sql_query(query1.select_sql, con=eng) + self.assertEqual(models.QueryStatus.FINISHED, query1.status) + self.assertEqual([{'name': 'Admin'}], df1.to_dict(orient='records')) + self.assertEqual(models.QueryStatus.FINISHED, query1.status) + self.assertTrue("SELECT * \nFROM tmp_async_1" in query1.select_sql) + self.assertTrue("LIMIT 666" in query1.select_sql) + self.assertEqual( + "CREATE TABLE tmp_async_1 AS SELECT name FROM ab_role " + "WHERE name='Admin'", query1.executed_sql) + self.assertEqual(sql_where, query1.sql) + if eng.name != 'sqlite': + self.assertEqual(1, query1.rows) + self.assertEqual(666, query1.limit) + self.assertEqual(False, query1.limit_used) + self.assertEqual(True, query1.select_as_cta) + self.assertEqual(True, query1.select_as_cta_used) + + # Case 2. 
+ query2 = self.get_query_by_id(result2[u'query_id']) + self.assertEqual(models.QueryStatus.FINISHED, query2.status) + self.assertIsNone(query2.select_sql) + self.assertEqual(insert_query, query2.executed_sql) + self.assertEqual(insert_query, query2.sql) + if eng.name != 'sqlite': + self.assertEqual(1, query2.rows) + self.assertEqual(666, query2.limit) + self.assertEqual(False, query2.limit_used) + self.assertEqual(False, query2.select_as_cta) + self.assertEqual(False, query2.select_as_cta_used) if __name__ == '__main__': diff --git a/tests/core_tests.py b/tests/core_tests.py index 48d26c16e96..ed33c611e0f 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -321,7 +321,7 @@ class CoreTests(CaravelTestCase): ) resp = self.client.post( '/caravel/sql_json/', - data=dict(database_id=dbid, sql=sql), + data=dict(database_id=dbid, sql=sql, select_as_create_as=False), ) self.logout() return json.loads(resp.data.decode('utf-8')) @@ -340,9 +340,9 @@ class CoreTests(CaravelTestCase): db.session.commit() main_db_permission_view = ( db.session.query(ab_models.PermissionView) - .join(ab_models.ViewMenu) - .filter(ab_models.ViewMenu.name == '[main].(id:1)') - .first() + .join(ab_models.ViewMenu) + .filter(ab_models.ViewMenu.name == '[main].(id:1)') + .first() ) astronaut = sm.add_role("Astronaut") sm.add_permission_role(astronaut, main_db_permission_view)