diff --git a/caravel/config.py b/caravel/config.py
index 1930ffa0521..869dd4c35f9 100644
--- a/caravel/config.py
+++ b/caravel/config.py
@@ -178,6 +178,7 @@ BACKUP_COUNT = 30
# Set this API key to enable Mapbox visualizations
MAPBOX_API_KEY = ""
+
# Maximum number of rows returned in the SQL editor
SQL_MAX_ROW = 1000
@@ -192,10 +193,10 @@ WARNING_MSG = None
"""
# Example:
class CeleryConfig(object):
- BROKER_URL = 'sqla+sqlite:///celerydb.sqlite'
- CELERY_IMPORTS = ('caravel.tasks', )
- CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite'
- CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}}
+ BROKER_URL = 'sqla+sqlite:///celerydb.sqlite'
+ CELERY_IMPORTS = ('caravel.tasks', )
+ CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite'
+ CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}}
CELERY_CONFIG = CeleryConfig
"""
CELERY_CONFIG = None
@@ -207,4 +208,3 @@ except ImportError:
if not CACHE_DEFAULT_TIMEOUT:
CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get('CACHE_DEFAULT_TIMEOUT')
-
diff --git a/caravel/extract_table_names.py b/caravel/extract_table_names.py
new file mode 100644
index 00000000000..4bc57074290
--- /dev/null
+++ b/caravel/extract_table_names.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com
+#
+# This example is part of python-sqlparse and is released under
+# the BSD License: http://www.opensource.org/licenses/bsd-license.php
+#
+# This example illustrates how to extract table names from nested
+# SELECT statements.
+#
+# See:
+# http://groups.google.com/group/sqlparse/browse_thread/thread/b0bd9a022e9d4895
+
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier
+from sqlparse.tokens import Keyword, DML
+
+
+def is_subselect(parsed):
+ if not parsed.is_group():
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() == 'SELECT':
+ return True
+ return False
+
+
+def extract_from_part(parsed):
+ from_seen = False
+ for item in parsed.tokens:
+ if from_seen:
+ if is_subselect(item):
+ for x in extract_from_part(item):
+ yield x
+ elif item.ttype is Keyword:
+ raise StopIteration
+ else:
+ yield item
+ elif item.ttype is Keyword and item.value.upper() == 'FROM':
+ from_seen = True
+
+
+def extract_table_identifiers(token_stream):
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ yield identifier.get_name()
+ elif isinstance(item, Identifier):
+ yield item.get_name()
+ # It's a bug to check for Keyword here, but in the example
+ # above some tables names are identified as keywords...
+ elif item.ttype is Keyword:
+ yield item.value
+
+
+# TODO(bkyryliuk): add logic to support joins and unions.
+def extract_tables(sql):
+ stream = extract_from_part(sqlparse.parse(sql)[0])
+ return list(extract_table_identifiers(stream))
diff --git a/caravel/migrations/versions/ad82a75afd82_add_query_model.py b/caravel/migrations/versions/ad82a75afd82_add_query_model.py
index 4794f416de0..31e41afbada 100644
--- a/caravel/migrations/versions/ad82a75afd82_add_query_model.py
+++ b/caravel/migrations/versions/ad82a75afd82_add_query_model.py
@@ -13,17 +13,27 @@ down_revision = 'f162a1dea4c4'
from alembic import op
import sqlalchemy as sa
+
def upgrade():
op.create_table('query',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('database_id', sa.Integer(), nullable=False),
- sa.Column('tmp_table_name', sa.String(length=64), nullable=True),
+ sa.Column('tmp_table_name', sa.String(length=256), nullable=True),
+ sa.Column('tab_name', sa.String(length=256),nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=16), nullable=True),
- sa.Column('name', sa.String(length=64), nullable=True),
- sa.Column('sql', sa.Text, nullable=True),
+ sa.Column('name', sa.String(length=256), nullable=True),
+ sa.Column('schema', sa.String(length=256), nullable=True),
+ sa.Column('sql', sa.Text(), nullable=True),
+ sa.Column('select_sql', sa.Text(), nullable=True),
+ sa.Column('executed_sql', sa.Text(), nullable=True),
sa.Column('limit', sa.Integer(), nullable=True),
+ sa.Column('limit_used', sa.Boolean(), nullable=True),
+ sa.Column('select_as_cta', sa.Boolean(), nullable=True),
+ sa.Column('select_as_cta_used', sa.Boolean(), nullable=True),
sa.Column('progress', sa.Integer(), nullable=True),
+ sa.Column('rows', sa.Integer(), nullable=True),
+ sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('start_time', sa.DateTime(), nullable=True),
sa.Column('end_time', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['database_id'], [u'dbs.id'], ),
diff --git a/caravel/models.py b/caravel/models.py
index ca57a1f6668..404386ecbc8 100644
--- a/caravel/models.py
+++ b/caravel/models.py
@@ -379,7 +379,7 @@ class Database(Model, AuditMixinNullable):
sqlalchemy_uri = Column(String(1024))
password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
cache_timeout = Column(Integer)
- select_as_create_table_as = Column(Boolean, default=True)
+ select_as_create_table_as = Column(Boolean, default=False)
extra = Column(Text, default=textwrap.dedent("""\
{
"metadata_params": {},
@@ -1734,6 +1734,16 @@ class FavStar(Model):
class QueryStatus:
+ def from_presto_states(self, presto_status):
+ if presto_status.lower() == 'running':
+ return QueryStatus.IN_PROGRESS
+ if presto_status.lower() == 'running':
+ return QueryStatus.IN_PROGRESS
+ if presto_status.lower() == 'running':
+ return QueryStatus.IN_PROGRESS
+ if presto_status.lower() == 'running':
+ return QueryStatus.IN_PROGRESS
+
SCHEDULED = 'SCHEDULED'
CANCELLED = 'CANCELLED'
IN_PROGRESS = 'IN_PROGRESS'
@@ -1752,18 +1762,30 @@ class Query(Model):
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
# Store the tmp table into the DB only if the user asks for it.
- tmp_table_name = Column(String(64))
+ tmp_table_name = Column(String(256))
user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True)
# models.QueryStatus
status = Column(String(16))
- name = Column(String(64))
+ name = Column(String(256))
+ tab_name = Column(String(256))
+ schema = Column(String(256))
sql = Column(Text)
- # Could be configured in the caravel config
+ # Query to retrieve the results,
+ # used only in case of select_as_cta_used is true.
+ select_sql = Column(Text)
+ executed_sql = Column(Text)
+ # Could be configured in the caravel config.
limit = Column(Integer)
+ limit_used = Column(Boolean)
+ select_as_cta = Column(Boolean)
+ select_as_cta_used = Column(Boolean)
# 1..100
progress = Column(Integer)
+ # # of rows in the result set or rows modified.
+ rows = Column(Integer)
+ error_message = Column(Text)
start_time = Column(DateTime)
end_time = Column(DateTime)
diff --git a/caravel/sql_lab.py b/caravel/sql_lab.py
new file mode 100644
index 00000000000..b7b5c3152ab
--- /dev/null
+++ b/caravel/sql_lab.py
@@ -0,0 +1,153 @@
+import celery
+from caravel import models, app, utils, sql_lab_utils
+from datetime import datetime
+
+celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
+
+
+@celery_app.task
+def get_sql_results(query_id):
+ """Executes the sql query returns the results."""
+ sql_manager = QueryRunner(query_id)
+ sql_manager.run_sql()
+ # Return the result for the sync call.
+ # if self.request.called_directly:
+ if sql_manager.query().status == models.QueryStatus.FINISHED:
+ return {
+ 'query_id': sql_manager.query().id,
+ 'status': sql_manager.query().status,
+ 'data': sql_manager.data(),
+ 'columns': sql_manager.columns(),
+ }
+ else:
+ return {
+ 'query_id': sql_manager.query().id,
+ 'status': sql_manager.query().status,
+ 'error': sql_manager.query().error_message,
+ }
+
+
+class QueryRunner:
+ def __init__(self, query_id):
+ self._query_id = query_id
+ # Creates a separate session, reusing the db.session leads to the
+ # concurrency issues.
+ self._session = sql_lab_utils.create_scoped_session()
+ self._query = self._session.query(models.Query).filter_by(
+ id=query_id).first()
+ self._db_to_query = self._session.query(models.Database).filter_by(
+ id=self._query.database_id).first()
+ # Query result.
+ self._data = None
+ self._columns = None
+
+ def _sanity_check(self):
+ if not self._query:
+ self._query.error_message = "Query with id {0} not found.".format(
+ self._query_id)
+ if not self._db_to_query:
+ self._query.error_message = (
+ "Database with id {0} is missing.".format(
+ self._query.database_id)
+ )
+
+ if self._query.error_message:
+ self._query.status = models.QueryStatus.FAILED
+ self._session.flush()
+ return False
+ return True
+
+ def query(self):
+ return self._query
+
+ def data(self):
+ return self._data
+
+ def columns(self):
+ return self._columns
+
+ def run_sql(self):
+ if not self._sanity_check():
+ return self._query.status
+
+ # TODO(bkyryliuk): dump results somewhere for the webserver.
+ engine = self._db_to_query.get_sqla_engine(schema=self._query.schema)
+ self._query.executed_sql = self._query.sql.strip().strip(';')
+
+ # Limit enforced only for retrieving the data, not for the CTA queries.
+ self._query.select_as_cta_used = False
+ self._query.limit_used = False
+ if sql_lab_utils.is_query_select(self._query.sql):
+ if self._query.select_as_cta:
+ if not self._query.tmp_table_name:
+ self._query.tmp_table_name = 'tmp_{}_table_{}'.format(
+ self._query.user_id,
+ self._query.start_time.strftime('%Y_%m_%d_%H_%M_%S'))
+ self._query.executed_sql = sql_lab_utils.create_table_as(
+ self._query.executed_sql, self._query.tmp_table_name)
+ self._query.select_as_cta_used = True
+ elif self._query.limit:
+ self._query.executed_sql = sql_lab_utils.add_limit_to_the_sql(
+ self._query.executed_sql, self._query.limit, engine)
+ self._query.limit_used = True
+
+ # TODO(bkyryliuk): ensure that tmp table was created.
+ # Do not set tmp table name if table wasn't created.
+ if not self._query.select_as_cta_used:
+ self._query.tmp_table_name = None
+ self._get_sql_results(engine)
+
+ self._query.end_time = datetime.now()
+ self._session.flush()
+ return self._query.status
+
+ def _get_sql_results(self, engine):
+ try:
+ result_proxy = engine.execute(
+ self._query.executed_sql, schema=self._query.schema)
+ except Exception as e:
+ self._query.error_message = utils.error_msg_from_exception(e)
+ self._query.status = models.QueryStatus.FAILED
+ return
+
+ cursor = result_proxy.cursor
+ if hasattr(cursor, "poll"):
+ query_stats = cursor.poll()
+ self._query.status = models.QueryStatus.IN_PROGRESS
+ self._session.flush()
+ # poll returns dict -- JSON status information or ``None``
+ # if the query is done
+ # https://github.com/dropbox/PyHive/blob/
+ # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
+ while query_stats:
+ # Update the object and wait for the kill signal.
+ self._session.refresh(self._query)
+ completed_splits = int(query_stats['stats']['completedSplits'])
+ total_splits = int(query_stats['stats']['totalSplits'])
+ progress = 100 * completed_splits / total_splits
+ if progress > self._query.progress:
+ self._query.progress = progress
+
+ self._session.flush()
+ query_stats = cursor.poll()
+ # TODO(b.kyryliuk): check for the kill signal.
+
+ sql_results = sql_lab_utils.fetch_response_from_cursor(
+ result_proxy)
+ self._columns = sql_results['columns']
+ self._data = sql_results['data']
+ self._query.rows = result_proxy.rowcount
+ self._query.progress = 100
+ self._query.status = models.QueryStatus.FINISHED
+ if self._query.rows == -1 and self._data:
+ # Presto doesn't provide result_proxy.row_count
+ self._query.rows = len(self._data)
+
+ # CTAs queries result in 1 cell having the # of the added rows.
+ if self._query.select_as_cta_used:
+ self._query.select_sql = sql_lab_utils.select_star(
+ engine, self._query.tmp_table_name, self._query.limit)
+ else:
+ self._query.tmp_table = None
+
+
diff --git a/caravel/sql_lab_utils.py b/caravel/sql_lab_utils.py
new file mode 100644
index 00000000000..085a953667a
--- /dev/null
+++ b/caravel/sql_lab_utils.py
@@ -0,0 +1,125 @@
+# SQL Lab Utils
+import pandas as pd
+
+import sqlparse
+from caravel import models, app
+from sqlalchemy import create_engine
+from sqlalchemy.orm import scoped_session, sessionmaker
+from sqlalchemy import select, text
+from sqlalchemy.sql.expression import TextAsFrom
+
+
+def create_scoped_session():
+    """Creates new SQLAlchemy scoped_session."""
+    # NOTE(review): autocommit=True means no explicit transaction is held;
+    # callers appear to rely on flush() to persist changes — confirm this
+    # is intended for the worker's write pattern.
+    engine = create_engine(
+        app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True)
+    return scoped_session(sessionmaker(
+        autocommit=True, autoflush=False, bind=engine))
+
+
+def fetch_response_from_cursor(result_proxy):
+ columns = None
+ data = None
+ if result_proxy.cursor:
+ cols = [col[0] for col in result_proxy.cursor.description]
+ data = result_proxy.fetchall()
+ df = pd.DataFrame(data, columns=cols)
+ df = df.fillna(0)
+ columns = [c for c in df.columns]
+ data = df.to_dict(orient='records')
+ return {
+ 'columns': columns,
+ 'data': data,
+ }
+
+
+def is_query_select(sql):
+ try:
+ return sqlparse.parse(sql)[0].get_type() == 'SELECT'
+ # Capture sqlparse exceptions, worker shouldn't fail here.
+ except Exception:
+ # TODO(bkyryliuk): add logging here.
+ return False
+
+
+# if sqlparse provides the stream of tokens but don't provide the API
+# to access the table names, more on it:
+# https://groups.google.com/forum/#!topic/sqlparse/sL2aAi6dSJU
+# https://github.com/andialbrecht/sqlparse/blob/master/examples/
+# extract_table_names.py
+#
+# Another approach would be to run the EXPLAIN on the sql statement:
+# https://prestodb.io/docs/current/sql/explain.html
+# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Explain
+def get_tables():
+    """Retrieves the query names from the query."""
+    # Placeholder stub: always returns None until SQL parsing lands.
+    # TODO(bkyryliuk): implement parsing the sql statement.
+    pass
+
+
+def select_star(engine, table_name, limit):
+ if limit:
+ select_star_sql = select('*').select_from(table_name).limit(limit)
+ else:
+ select_star_sql = select('*').select_from(table_name)
+
+ # SQL code to preview the results
+ return '{}'.format(select_star_sql.compile(
+ engine, compile_kwargs={"literal_binds": True}))
+
+
+def add_limit_to_the_sql(sql, limit, eng):
+ # Treat as single sql statement in case of failure.
+ try:
+ sql_statements = [s for s in sqlparse.split(sql) if s]
+ except Exception as e:
+ app.logger.info(
+ "Statement " + sql + "failed to be transformed to have the limit "
+
+
+ "with the exception" + e.message)
+ return sql
+ if len(sql_statements) == 1 and is_query_select(sql):
+ qry = select('*').select_from(
+ TextAsFrom(text(sql_statements[0]), ['*']).alias(
+ 'inner_qry')).limit(limit)
+ sql_statement = str(qry.compile(
+ eng, compile_kwargs={"literal_binds": True}))
+ return sql_statement
+ return sql
+
+
+# create table works only for the single statement.
+# TODO(bkyryliuk): enforce that all the columns have names. Presto requires it
+# for the CTA operation.
+def create_table_as(sql, table_name, override=False):
+ """Reformats the query into the create table as query.
+
+ Works only for the single select SQL statements, in all other cases
+ the sql query is not modified.
+ :param sql: string, sql query that will be executed
+ :param table_name: string, will contain the results of the query execution
+ :param override, boolean, table table_name will be dropped if true
+ :return: string, create table as query
+ """
+ # TODO(bkyryliuk): drop table if allowed, check the namespace and
+ # the permissions.
+ # Treat as single sql statement in case of failure.
+ try:
+ # Filter out empty statements.
+ sql_statements = [s for s in sqlparse.split(sql) if s]
+ except Exception as e:
+ app.logger.info(
+ "Statement " + sql + "failed to be transformed as create table as "
+ "with the exception" + e.message)
+ return sql
+ if len(sql_statements) == 1 and is_query_select(sql):
+ updated_sql = ''
+ # TODO(bkyryliuk): use sqlalchemy statements for the
+ # the drop and create operations.
+ if override:
+ updated_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name)
+ updated_sql += "CREATE TABLE %s AS %s" % (
+ table_name, sql_statements[0])
+ return updated_sql
+ return sql
diff --git a/caravel/tasks.py b/caravel/tasks.py
deleted file mode 100644
index c48e6699745..00000000000
--- a/caravel/tasks.py
+++ /dev/null
@@ -1,219 +0,0 @@
-import celery
-from caravel import models, app, utils
-from datetime import datetime
-import logging
-from sqlalchemy import create_engine, select, text
-from sqlalchemy.orm import scoped_session, sessionmaker
-from sqlalchemy.sql.expression import TextAsFrom
-import sqlparse
-import pandas as pd
-
-celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
-
-
-def is_query_select(sql):
- try:
- return sqlparse.parse(sql)[0].get_type() == 'SELECT'
- # Capture sqlparse exceptions, worker shouldn't fail here.
- except Exception:
- # TODO(bkyryliuk): add logging here.
- return False
-
-
-# if sqlparse provides the stream of tokens but don't provide the API
-# to access the table names, more on it:
-# https://groups.google.com/forum/#!topic/sqlparse/sL2aAi6dSJU
-# https://github.com/andialbrecht/sqlparse/blob/master/examples/
-# extract_table_names.py
-#
-# Another approach would be to run the EXPLAIN on the sql statement:
-# https://prestodb.io/docs/current/sql/explain.html
-# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Explain
-def get_tables():
- """Retrieves the query names from the query."""
- # TODO(bkyryliuk): implement parsing the sql statement.
- pass
-
-
-def add_limit_to_the_query(sql, limit, eng):
- # Treat as single sql statement in case of failure.
- sql_statements = [sql]
- try:
- sql_statements = [s for s in sqlparse.split(sql) if s]
- except Exception as e:
- logging.info(
- "Statement " + sql + "failed to be transformed to have the limit "
- "with the exception" + e.message)
- return sql
- if len(sql_statements) == 1 and is_query_select(sql):
- qry = select('*').select_from(
- TextAsFrom(text(sql_statements[0]), ['*']).alias(
- 'inner_qry')).limit(limit)
- sql_statement = str(qry.compile(
- eng, compile_kwargs={"literal_binds": True}))
- return sql_statement
- return sql
-
-
-# create table works only for the single statement.
-def create_table_as(sql, table_name, override=False):
- """Reformats the query into the create table as query.
-
- Works only for the single select SQL statements, in all other cases
- the sql query is not modified.
- :param sql: string, sql query that will be executed
- :param table_name: string, will contain the results of the query execution
- :param override, boolean, table table_name will be dropped if true
- :return: string, create table as query
- """
- # TODO(bkyryliuk): drop table if allowed, check the namespace and
- # the permissions.
- # Treat as single sql statement in case of failure.
- sql_statements = [sql]
- try:
- # Filter out empty statements.
- sql_statements = [s for s in sqlparse.split(sql) if s]
- except Exception as e:
- logging.info(
- "Statement " + sql + "failed to be transformed as create table as "
- "with the exception" + e.message)
- return sql
- if len(sql_statements) == 1 and is_query_select(sql):
- updated_sql = ''
- # TODO(bkyryliuk): use sqlalchemy statements for the
- # the drop and create operations.
- if override:
- updated_sql = 'DROP TABLE IF EXISTS {};\n'.format(table_name)
- updated_sql += "CREATE TABLE %s AS %s" % (
- table_name, sql_statements[0])
- return updated_sql
- return sql
-
-
-def get_session():
- """Creates new SQLAlchemy scoped_session."""
- engine = create_engine(
- app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True)
- return scoped_session(sessionmaker(
- autocommit=False, autoflush=False, bind=engine))
-
-
-@celery_app.task
-def get_sql_results(database_id, sql, user_id, tmp_table_name="", schema=None):
- """Executes the sql query returns the results.
-
- :param database_id: integer
- :param sql: string, query that will be executed
- :param user_id: integer
- :param tmp_table_name: name of the table for CTA
- :param schema: string, name of the schema (used in presto)
- :return: dataframe, query result
- """
- # Create a separate session, reusing the db.session leads to the
- # concurrency issues.
- session = get_session()
- try:
- db_to_query = (
- session.query(models.Database).filter_by(id=database_id).first()
- )
- except Exception as e:
- return {
- 'error': utils.error_msg_from_exception(e),
- 'success': False,
- }
- if not db_to_query:
- return {
- 'error': "Database with id {0} is missing.".format(database_id),
- 'success': False,
- }
-
- # TODO(bkyryliuk): provide a way for the user to name the query.
- # TODO(bkyryliuk): run explain query to derive the tables and fill in the
- # table_ids
- # TODO(bkyryliuk): check the user permissions
- # TODO(bkyryliuk): store the tab name in the query model
- limit = app.config.get('SQL_MAX_ROW', None)
- start_time = datetime.now()
- if not tmp_table_name:
- tmp_table_name = 'tmp.{}_table_{}'.format(user_id, start_time)
- query = models.Query(
- user_id=user_id,
- database_id=database_id,
- limit=limit,
- name='{}'.format(start_time),
- sql=sql,
- start_time=start_time,
- tmp_table_name=tmp_table_name,
- status=models.QueryStatus.IN_PROGRESS,
- )
- session.add(query)
- session.commit()
- query_result = get_sql_results_as_dict(
- db_to_query, sql, query.tmp_table_name, schema=schema)
- query.end_time = datetime.now()
- if query_result['success']:
- query.status = models.QueryStatus.FINISHED
- else:
- query.status = models.QueryStatus.FAILED
- session.commit()
- # TODO(bkyryliuk): return the tmp table / query_id
- return query_result
-
-
-# TODO(bkyryliuk): merge the changes made in the carapal first
-# before merging this PR.
-def get_sql_results_as_dict(db_to_query, sql, tmp_table_name, schema=None):
- """Get the SQL query results from the give session and db connection.
-
- :param sql: string, query that will be executed
- :param db_to_query: models.Database to query, cannot be None
- :param tmp_table_name: name of the table for CTA
- :param schema: string, name of the schema (used in presto)
- :return: (dataframe, boolean), results and the status
- """
- eng = db_to_query.get_sqla_engine(schema=schema)
- sql = sql.strip().strip(';')
- # TODO(bkyryliuk): fix this case for multiple statements
- if app.config.get('SQL_MAX_ROW'):
- sql = add_limit_to_the_query(
- sql, app.config.get("SQL_MAX_ROW"), eng)
-
- cta_used = False
- if (app.config.get('SQL_SELECT_AS_CTA') and
- db_to_query.select_as_create_table_as and is_query_select(sql)):
- # TODO(bkyryliuk): figure out if the query is select query.
- sql = create_table_as(sql, tmp_table_name)
- cta_used = True
-
- if cta_used:
- try:
- eng.execute(sql)
- return {
- 'tmp_table': tmp_table_name,
- 'success': True,
- }
- except Exception as e:
- return {
- 'error': utils.error_msg_from_exception(e),
- 'success': False,
- }
-
- # otherwise run regular SQL query.
- # TODO(bkyryliuk): rewrite into eng.execute as queries different from
- # select should be permitted too.
- try:
- df = db_to_query.get_df(sql, schema)
- df = df.fillna(0)
- return {
- 'columns': [c for c in df.columns],
- 'data': df.to_dict(orient='records'),
- 'success': True,
- }
-
- except Exception as e:
- return {
- 'error': utils.error_msg_from_exception(e),
- 'success': False,
- }
-
-
diff --git a/caravel/utils.py b/caravel/utils.py
index 9b784517c57..668c80f4936 100644
--- a/caravel/utils.py
+++ b/caravel/utils.py
@@ -339,10 +339,17 @@ def error_msg_from_exception(e):
Database have different ways to handle exception. This function attempts
to make sense of the exception object and construct a human readable
sentence.
+
+ TODO(bkyryliuk): parse the Presto error message from the connection
+ created via create_engine.
+ engine = create_engine('presto://localhost:3506/silver') -
+ gives an e.message as the str(dict)
+ presto.connect("localhost", port=3506, catalog='silver') - as a dict.
+ The latter version is parsed correctly by this function.
"""
msg = ''
if hasattr(e, 'message'):
- if (type(e.message) is dict):
+ if type(e.message) is dict:
msg = e.message.get('message')
elif e.message:
msg = "{}".format(e.message)
diff --git a/caravel/views.py b/caravel/views.py
index 5bd50c554c9..2fdc418946e 100755
--- a/caravel/views.py
+++ b/caravel/views.py
@@ -33,7 +33,7 @@ from wtforms.validators import ValidationError
import caravel
from caravel import (
- appbuilder, db, models, viz, utils, app, sm, ascii_art, tasks
+ appbuilder, db, models, viz, utils, app, sm, ascii_art, sql_lab
)
config = app.config
@@ -459,11 +459,13 @@ class DatabaseAsync(DatabaseView):
appbuilder.add_view_no_menu(DatabaseAsync)
+
class DatabaseTablesAsync(DatabaseView):
list_columns = ['id', 'all_table_names', 'all_schema_names']
appbuilder.add_view_no_menu(DatabaseTablesAsync)
+
class TableModelView(CaravelModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.SqlaTable)
list_columns = [
@@ -622,7 +624,8 @@ class SliceModelView(CaravelModelView, DeleteMixin): # noqa
url = "/druiddatasourcemodelview/list/"
msg = _(
"Click on a datasource link to create a Slice, "
- "or click on a table link here "
+ "or click on a table link "
+ "here "
"to create a Slice for a table"
)
else:
@@ -904,7 +907,8 @@ class Caravel(BaseCaravelView):
datasource_access = self.can_access(
'datasource_access', datasource.perm)
if not (all_datasource_access or datasource_access):
- flash(__("You don't seem to have access to this datasource"), "danger")
+ flash(__("You don't seem to have access to this datasource"),
+ "danger")
return redirect(error_redirect)
action = request.args.get('action')
@@ -981,7 +985,8 @@ class Caravel(BaseCaravelView):
del d['action']
del d['previous_viz_type']
- as_list = ('metrics', 'groupby', 'columns', 'all_columns', 'mapbox_label', 'order_by_cols')
+ as_list = ('metrics', 'groupby', 'columns', 'all_columns',
+ 'mapbox_label', 'order_by_cols')
for k in d:
v = d.get(k)
if k in as_list and not isinstance(v, list):
@@ -1092,7 +1097,8 @@ class Caravel(BaseCaravelView):
.group_by(Log.dt)
.all()
)
- payload = {str(time.mktime(dt.timetuple())): ccount for dt, ccount in qry if dt}
+ payload = {str(time.mktime(dt.timetuple())):
+ ccount for dt, ccount in qry if dt}
return Response(json.dumps(payload), mimetype="application/json")
@api
@@ -1148,9 +1154,11 @@ class Caravel(BaseCaravelView):
data = json.loads(request.form.get('data'))
session = db.session()
Slice = models.Slice # noqa
- dash = session.query(models.Dashboard).filter_by(id=dashboard_id).first()
+ dash = (
+ session.query(models.Dashboard).filter_by(id=dashboard_id).first())
check_ownership(dash, raise_if_false=True)
- new_slices = session.query(Slice).filter(Slice.id.in_(data['slice_ids']))
+ new_slices = session.query(Slice).filter(
+ Slice.id.in_(data['slice_ids']))
dash.slices += new_slices
session.merge(dash)
session.commit()
@@ -1184,13 +1192,18 @@ class Caravel(BaseCaravelView):
FavStar = models.FavStar # noqa
count = 0
favs = session.query(FavStar).filter_by(
- class_name=class_name, obj_id=obj_id, user_id=g.user.get_id()).all()
+ class_name=class_name, obj_id=obj_id,
+ user_id=g.user.get_id()).all()
if action == 'select':
if not favs:
session.add(
FavStar(
- class_name=class_name, obj_id=obj_id, user_id=g.user.get_id(),
- dttm=datetime.now()))
+ class_name=class_name,
+ obj_id=obj_id,
+ user_id=g.user.get_id(),
+ dttm=datetime.now()
+ )
+ )
count = 1
elif action == 'unselect':
for fav in favs:
@@ -1396,9 +1409,24 @@ class Caravel(BaseCaravelView):
sql = request.form.get('sql')
database_id = request.form.get('database_id')
schema = request.form.get('schema')
+ tab_name = request.form.get('tab_name')
+
+ async = request.form.get('async') == 'True'
+ tmp_table_name = request.form.get('tmp_table_name', None)
+ select_as_cta = request.form.get('select_as_cta') == 'True'
+
session = db.session()
mydb = session.query(models.Database).filter_by(id=database_id).first()
+ if not mydb:
+ return Response(
+ json.dumps({
+ 'error': 'Database with id 0 is missing.',
+ 'status': models.QueryStatus.FAILED,
+ }),
+ status=500,
+ mimetype="application/json")
+
if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm)):
@@ -1406,19 +1434,132 @@ class Caravel(BaseCaravelView):
"SQL Lab requires the `all_datasource_access` or "
"specific DB permission"))
- data = tasks.get_sql_results(database_id, sql, g.user.get_id(),
- schema=schema)
- if 'error' in data:
+ start_time = datetime.now()
+ query_name = '{}_{}_{}'.format(
+ g.user.get_id(), tab_name, start_time.strftime('%M:%S:%f'))
+
+ query = models.Query(
+ database_id=database_id,
+ limit=app.config.get('SQL_MAX_ROW', None),
+ name=query_name,
+ sql=sql,
+ schema=schema,
+ # TODO(bkyryliuk): consider it being DB property.
+ select_as_cta=select_as_cta,
+ start_time=start_time,
+ status=models.QueryStatus.SCHEDULED,
+ tab_name=tab_name,
+ tmp_table_name=tmp_table_name,
+ user_id=g.user.get_id(),
+ )
+ session.add(query)
+ session.commit()
+
+ # Async request.
+ if async:
+ # Ignore the celery future object and the request may time out.
+ sql_lab.get_sql_results.delay(query.id)
+ return Response(json.dumps(
+ {
+ 'query_id': query.id,
+ 'status': query.status,
+ },
+ default=utils.json_int_dttm_ser, allow_nan=False),
+ status=202, # Accepted
+ mimetype="application/json")
+
+ # Sync request.
+ data = sql_lab.get_sql_results(query.id)
+ if data['status'] == models.QueryStatus.FAILED:
return Response(
- json.dumps(data),
+ json.dumps(
+ data, default=utils.json_int_dttm_ser, allow_nan=False),
status=500,
mimetype="application/json")
- if 'tmp_table' in data:
- # TODO(bkyryliuk): add query id to the response and implement the
- # endpoint to poll the status and results.
- return None
- return json.dumps(
- data, default=utils.json_int_dttm_ser, allow_nan=False)
+ return Response(
+ json.dumps(
+ data, default=utils.json_int_dttm_ser, allow_nan=False),
+ status=200,
+ mimetype="application/json")
+
+ @has_access
+ @expose("/queries/", methods=['GET'])
+ @log_this
+ def queries(self):
+ """Runs arbitrary sql and returns and json"""
+ last_updated = request.form.get('timestamp')
+ s = db.session()
+ query = s.query(models.Query).filter_by(id=query_id).first()
+ mydb = s.query(models.Database).filter_by(id=query.database_id).first()
+
+ if not (self.can_access(
+ 'all_datasource_access', 'all_datasource_access') or
+ self.can_access('database_access', mydb.perm)):
+ raise utils.CaravelSecurityException(_(
+ "SQL Lab requires the `all_datasource_access` or "
+ "specific DB permission"))
+
+ if query:
+ return Response(
+ json.dumps({
+ 'status': query.status,
+ 'progress': query.progress
+ }),
+ status=200,
+ mimetype="application/json")
+
+ return Response(
+ json.dumps({
+ 'error': "Query with id {} wasn't found".format(query_id),
+ }),
+ status=404,
+ mimetype="application/json")
+
+ @has_access
+ @expose("/cta_query_results/", methods=['GET'])
+ @log_this
+ def cta_query_results(self):
+ """Runs arbitrary sql and returns and json"""
+ query_id = request.form.get('query_id')
+ s = db.session()
+ query = s.query(models.Query).filter_by(id=query_id).first()
+ mydb = s.query(models.Database).filter_by(id=query.database_id).first()
+
+ if not (self.can_access(
+ 'all_datasource_access', 'all_datasource_access') or
+ self.can_access('database_access', mydb.perm)):
+ raise utils.CaravelSecurityException(_(
+ "SQL Lab requires the `all_datasource_access` or "
+ "specific DB permission"))
+
+ if not query:
+ return Response(
+ json.dumps({
+ 'error': "Query with id {} wasn't found".format(query_id),
+ }),
+ status=404,
+ mimetype="application/json")
+
+ if query.status != models.QueryStatus.FINISHED:
+ return Response(
+ json.dumps({
+ 'error': "Query with id {} not finished yet".format(
+ query_id),
+ }),
+ status=400,
+ mimetype="application/json")
+ try:
+ data = mydb.get_df(query.select_sql, query.schema)
+ return Response(
+ json.dumps(
+ data, default=utils.json_int_dttm_ser, allow_nan=False),
+ status=200,
+ mimetype="application/json")
+ except Exception as e:
+ return Response(
+ json.dumps('error', utils.error_msg_from_exception(e)),
+ status=500,
+ mimetype="application/json")
@has_access
@expose("/refresh_datasources/")
diff --git a/setup.py b/setup.py
index b90bc599afd..fd3f6a987eb 100644
--- a/setup.py
+++ b/setup.py
@@ -30,6 +30,7 @@ setup(
'pandas==0.18.1',
'parsedatetime==2.0.0',
'pydruid==0.3.0',
+ 'PyHive>=0.2.1',
'python-dateutil==2.5.3',
'requests==2.10.0',
'simplejson==3.8.2',
@@ -37,6 +38,8 @@ setup(
'sqlalchemy==1.0.13',
'sqlalchemy-utils==0.32.7',
'sqlparse==0.1.19',
+ 'thrift>=0.9.3',
+ 'thrift-sasl>=0.2.1',
'werkzeug==0.11.10',
],
extras_require={
diff --git a/tests/caravel_test_config.py b/tests/caravel_test_config.py
index a28240a87f7..1c24ee640bb 100644
--- a/tests/caravel_test_config.py
+++ b/tests/caravel_test_config.py
@@ -13,12 +13,13 @@ if 'CARAVEL__SQLALCHEMY_DATABASE_URI' in os.environ:
SQL_CELERY_DB_FILE_PATH = os.path.join(DATA_DIR, 'celerydb.sqlite')
SQL_CELERY_RESULTS_DB_FILE_PATH = os.path.join(DATA_DIR, 'celery_results.sqlite')
SQL_SELECT_AS_CTA = True
+SQL_MAX_ROW = 666
class CeleryConfig(object):
BROKER_URL = 'sqla+sqlite:///' + SQL_CELERY_DB_FILE_PATH
- CELERY_IMPORTS = ('caravel.tasks', )
+ CELERY_IMPORTS = ('caravel.sql_lab', )
CELERY_RESULT_BACKEND = 'db+sqlite:///' + SQL_CELERY_RESULTS_DB_FILE_PATH
- CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}}
+ CELERY_ANNOTATIONS = {'sql_lab.add': {'rate_limit': '10/s'}}
CONCURRENCY = 1
CELERY_CONFIG = CeleryConfig
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index e88ae0fca1c..48a2c004cfe 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -1,36 +1,47 @@
"""Unit tests for Caravel Celery worker"""
-import datetime
import imp
+import json
import subprocess
-import os
-import pandas as pd
+
import time
+
+import os
+
+import pandas as pd
import unittest
import caravel
-from caravel import app, appbuilder, db, models, tasks, utils
+from caravel import app, appbuilder, db, models, sql_lab, sql_lab_utils, utils
-class CeleryConfig(object):
- BROKER_URL = 'sqla+sqlite:////tmp/celerydb.sqlite'
- CELERY_IMPORTS = ('caravel.tasks',)
- CELERY_RESULT_BACKEND = 'db+sqlite:////tmp/celery_results.sqlite'
- CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}}
-app.config['CELERY_CONFIG'] = CeleryConfig
-
BASE_DIR = app.config.get('BASE_DIR')
cli = imp.load_source('cli', BASE_DIR + '/bin/caravel')
+SQL_CELERY_DB_FILE_PATH = '/tmp/celerydb.sqlite'
+SQL_CELERY_RESULTS_DB_FILE_PATH = '/tmp/celery_results.sqlite'
+
+
+class CeleryConfig(object):
+ BROKER_URL = 'sqla+sqlite:///' + SQL_CELERY_DB_FILE_PATH
+ CELERY_IMPORTS = ('caravel.sql_lab', )
+ CELERY_RESULT_BACKEND = 'db+sqlite:///' + SQL_CELERY_RESULTS_DB_FILE_PATH
+ CELERY_ANNOTATIONS = {'sql_lab.add': {'rate_limit': '10/s'}}
+ CONCURRENCY = 1
+app.config['CELERY_CONFIG'] = CeleryConfig
+
+# TODO(bkyryliuk): add ability to run this test separately.
+
class UtilityFunctionTests(unittest.TestCase):
def test_create_table_as(self):
select_query = "SELECT * FROM outer_space;"
- updated_select_query = tasks.create_table_as(select_query, "tmp")
+ updated_select_query = sql_lab_utils.create_table_as(
+ select_query, "tmp")
self.assertEqual(
"CREATE TABLE tmp AS SELECT * FROM outer_space;",
updated_select_query)
- updated_select_query_with_drop = tasks.create_table_as(
+ updated_select_query_with_drop = sql_lab_utils.create_table_as(
select_query, "tmp", override=True)
self.assertEqual(
"DROP TABLE IF EXISTS tmp;\n"
@@ -38,24 +49,26 @@ class UtilityFunctionTests(unittest.TestCase):
updated_select_query_with_drop)
select_query_no_semicolon = "SELECT * FROM outer_space"
- updated_select_query_no_semicolon = tasks.create_table_as(
+ updated_select_query_no_semicolon = sql_lab_utils.create_table_as(
select_query_no_semicolon, "tmp")
self.assertEqual(
"CREATE TABLE tmp AS SELECT * FROM outer_space",
updated_select_query_no_semicolon)
incorrect_query = "SMTH WRONG SELECT * FROM outer_space"
- updated_incorrect_query = tasks.create_table_as(incorrect_query, "tmp")
+ updated_incorrect_query = sql_lab_utils.create_table_as(
+ incorrect_query, "tmp")
self.assertEqual(incorrect_query, updated_incorrect_query)
insert_query = "INSERT INTO stomach VALUES (beer, chips);"
- updated_insert_query = tasks.create_table_as(insert_query, "tmp")
+ updated_insert_query = sql_lab_utils.create_table_as(
+ insert_query, "tmp")
self.assertEqual(insert_query, updated_insert_query)
multi_line_query = (
"SELECT * FROM planets WHERE\n"
"Luke_Father = 'Darth Vader';")
- updated_multi_line_query = tasks.create_table_as(
+ updated_multi_line_query = sql_lab_utils.create_table_as(
multi_line_query, "tmp")
expected_updated_multi_line_query = (
"CREATE TABLE tmp AS SELECT * FROM planets WHERE\n"
@@ -64,7 +77,7 @@ class UtilityFunctionTests(unittest.TestCase):
expected_updated_multi_line_query,
updated_multi_line_query)
- updated_multi_line_query_with_drop = tasks.create_table_as(
+ updated_multi_line_query_with_drop = sql_lab_utils.create_table_as(
multi_line_query, "tmp", override=True)
expected_updated_multi_line_query_with_drop = (
"DROP TABLE IF EXISTS tmp;\n"
@@ -75,12 +88,12 @@ class UtilityFunctionTests(unittest.TestCase):
updated_multi_line_query_with_drop)
delete_query = "DELETE FROM planet WHERE name = 'Earth'"
- updated_delete_query = tasks.create_table_as(delete_query, "tmp")
+ updated_delete_query = sql_lab_utils.create_table_as(delete_query, "tmp")
self.assertEqual(delete_query, updated_delete_query)
create_table_as = (
"CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n")
- updated_create_table_as = tasks.create_table_as(
+ updated_create_table_as = sql_lab_utils.create_table_as(
create_table_as, "tmp")
self.assertEqual(create_table_as, updated_create_table_as)
@@ -97,7 +110,7 @@ class UtilityFunctionTests(unittest.TestCase):
"(B.TECH ,BE ,Degree ,MCA ,MiBA)\n "
"AND Having Brothers= Null AND Sisters =Null"
)
- updated_sql_procedure = tasks.create_table_as(sql_procedure, "tmp")
+ updated_sql_procedure = sql_lab_utils.create_table_as(sql_procedure, "tmp")
self.assertEqual(sql_procedure, updated_sql_procedure)
multiple_statements = """
@@ -107,7 +120,7 @@ class UtilityFunctionTests(unittest.TestCase):
SELECT standard_disclaimer, witty_remark FROM company_requirements;
select count(*) from developer_brain;
"""
- updated_multiple_statements = tasks.create_table_as(
+ updated_multiple_statements = sql_lab_utils.create_table_as(
multiple_statements, "tmp")
self.assertEqual(multiple_statements, updated_multiple_statements)
@@ -116,36 +129,49 @@ class CeleryTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(CeleryTestCase, self).__init__(*args, **kwargs)
self.client = app.test_client()
+
+ def get_query_by_name(self, sql):
+ session = db.create_scoped_session()
+ query = session.query(models.Query).filter_by(sql=sql).first()
+ session.close()
+ return query
+
+ def get_query_by_id(self, id):
+ session = db.create_scoped_session()
+ query = session.query(models.Query).filter_by(id=id).first()
+ session.close()
+ return query
+
+ @classmethod
+ def setUpClass(cls):
+ try:
+ os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH'))
+ except OSError as e:
+ app.logger.warn(str(e))
+ try:
+ os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH'))
+ except OSError as e:
+ app.logger.warn(str(e))
+
utils.init(caravel)
+
+ worker_command = BASE_DIR + '/bin/caravel worker'
+ subprocess.Popen(
+ worker_command, shell=True, stdout=subprocess.PIPE)
+
admin = appbuilder.sm.find_user('admin')
if not admin:
appbuilder.sm.add_user(
'admin', 'admin', ' user', 'admin@fab.org',
appbuilder.sm.find_role('Admin'),
password='general')
- utils.init(caravel)
-
- @classmethod
- def setUpClass(cls):
- try:
- os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH'))
- except OSError:
- pass
- try:
- os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH'))
- except OSError:
- pass
-
- worker_command = BASE_DIR + '/bin/caravel worker'
- subprocess.Popen(
- worker_command, shell=True, stdout=subprocess.PIPE)
cli.load_examples(load_test_data=True)
+
@classmethod
def tearDownClass(cls):
subprocess.call(
- "ps auxww | grep 'celeryd' | awk '{print $2}' | "
- "xargs kill -9",
+ "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9",
shell=True
)
subprocess.call(
@@ -160,14 +186,40 @@ class CeleryTestCase(unittest.TestCase):
def tearDown(self):
pass
+ def login(self, username='admin', password='general'):
+ resp = self.client.post(
+ '/login/',
+ data=dict(username=username, password=password),
+ follow_redirects=True)
+ assert 'Welcome' in resp.data.decode('utf-8')
+
+ def logout(self):
+ self.client.get('/logout/', follow_redirects=True)
+
+ def run_sql(self, dbid, sql, cta='False', tmp_table='tmp',
+ async='False'):
+ self.login()
+ resp = self.client.post(
+ '/caravel/sql_json/',
+ data=dict(
+ database_id=dbid,
+ sql=sql,
+ async=async,
+ select_as_cta=cta,
+ tmp_table_name=tmp_table,
+ ),
+ )
+ self.logout()
+ return json.loads(resp.data.decode('utf-8'))
+
def test_add_limit_to_the_query(self):
- query_session = tasks.get_session()
+ query_session = sql_lab_utils.create_scoped_session()
db_to_query = query_session.query(models.Database).filter_by(
id=1).first()
eng = db_to_query.get_sqla_engine()
select_query = "SELECT * FROM outer_space;"
- updated_select_query = tasks.add_limit_to_the_query(
+ updated_select_query = sql_lab_utils.add_limit_to_the_sql(
select_query, 100, eng)
# Different DB engines have their own spacing while compiling
# the queries, that's why ' '.join(query.split()) is used.
@@ -178,7 +230,7 @@ class CeleryTestCase(unittest.TestCase):
)
select_query_no_semicolon = "SELECT * FROM outer_space"
- updated_select_query_no_semicolon = tasks.add_limit_to_the_query(
+ updated_select_query_no_semicolon = sql_lab_utils.add_limit_to_the_sql(
select_query_no_semicolon, 100, eng)
self.assertTrue(
"SELECT * FROM (SELECT * FROM outer_space) AS inner_qry "
@@ -187,19 +239,19 @@ class CeleryTestCase(unittest.TestCase):
)
incorrect_query = "SMTH WRONG SELECT * FROM outer_space"
- updated_incorrect_query = tasks.add_limit_to_the_query(
+ updated_incorrect_query = sql_lab_utils.add_limit_to_the_sql(
incorrect_query, 100, eng)
self.assertEqual(incorrect_query, updated_incorrect_query)
insert_query = "INSERT INTO stomach VALUES (beer, chips);"
- updated_insert_query = tasks.add_limit_to_the_query(
+ updated_insert_query = sql_lab_utils.add_limit_to_the_sql(
insert_query, 100, eng)
self.assertEqual(insert_query, updated_insert_query)
multi_line_query = (
"SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';"
)
- updated_multi_line_query = tasks.add_limit_to_the_query(
+ updated_multi_line_query = sql_lab_utils.add_limit_to_the_sql(
multi_line_query, 100, eng)
self.assertTrue(
"SELECT * FROM (SELECT * FROM planets WHERE "
@@ -208,13 +260,13 @@ class CeleryTestCase(unittest.TestCase):
)
delete_query = "DELETE FROM planet WHERE name = 'Earth'"
- updated_delete_query = tasks.add_limit_to_the_query(
+ updated_delete_query = sql_lab_utils.add_limit_to_the_sql(
delete_query, 100, eng)
self.assertEqual(delete_query, updated_delete_query)
create_table_as = (
"CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n")
- updated_create_table_as = tasks.add_limit_to_the_query(
+ updated_create_table_as = sql_lab_utils.add_limit_to_the_sql(
create_table_as, 100, eng)
self.assertEqual(create_table_as, updated_create_table_as)
@@ -231,168 +283,145 @@ class CeleryTestCase(unittest.TestCase):
"(B.TECH ,BE ,Degree ,MCA ,MiBA)\n "
"AND Having Brothers= Null AND Sisters = Null"
)
- updated_sql_procedure = tasks.add_limit_to_the_query(
+ updated_sql_procedure = sql_lab_utils.add_limit_to_the_sql(
sql_procedure, 100, eng)
self.assertEqual(sql_procedure, updated_sql_procedure)
- def test_run_async_query_delay_get(self):
+ def test_run_sync_query(self):
main_db = db.session.query(models.Database).filter_by(
database_name="main").first()
eng = main_db.get_sqla_engine()
# Case 1.
# DB #0 doesn't exist.
- result1 = tasks.get_sql_results.delay(
- 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_1').get()
- expected_result1 = {
- 'error': 'Database with id 0 is missing.',
- 'success': False
- }
- self.assertEqual(
- sorted(expected_result1.items()),
- sorted(result1.items())
- )
- session1 = db.create_scoped_session()
- query1 = session1.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist').first()
- session1.close()
- self.assertIsNone(query1)
+ sql_dont_exist = 'SELECT * FROM dontexist'
+ result1 = self.run_sql(0, sql_dont_exist, cta='True')
+ self.assertEqual(models.QueryStatus.FAILED, result1[u'status'])
+ self.assertFalse(u'query_id' in result1)
+ self.assertEqual('Database with id 0 is missing.', result1['error'])
+ self.assertIsNone(self.get_query_by_name(sql_dont_exist))
# Case 2.
- session2 = db.create_scoped_session()
- query2 = session2.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist1').first()
- self.assertEqual(models.QueryStatus.FAILED, query2.status)
- session2.close()
-
- result2 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_1').get()
+ # Table doesn't exist.
+ result2 = self.run_sql(1, sql_dont_exist, cta='True', )
self.assertTrue('error' in result2)
- session2 = db.create_scoped_session()
- query2 = session2.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist1').first()
+ self.assertEqual(models.QueryStatus.FAILED, result1[u'status'])
+ query2 = self.get_query_by_id(result2[u'query_id'])
self.assertEqual(models.QueryStatus.FAILED, query2.status)
- session2.close()
# Case 3.
- where_query = (
- "SELECT name FROM ab_permission WHERE name='can_select_star'")
- result3 = tasks.get_sql_results.delay(
- 1, where_query, 1, tmp_table_name='tmp_3_1').get()
- expected_result3 = {
- 'tmp_table': 'tmp_3_1',
- 'success': True
- }
- self.assertEqual(
- sorted(expected_result3.items()),
- sorted(result3.items())
- )
- session3 = db.create_scoped_session()
- query3 = session3.query(models.Query).filter_by(
- sql=where_query).first()
- session3.close()
- df3 = pd.read_sql_query(sql="SELECT * FROM tmp_3_1", con=eng)
+ # Table and DB exists, CTA call to the backend.
+ sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'"
+ result3 = self.run_sql(
+ 1, sql_where, tmp_table='tmp_table_3', cta='True')
+ self.assertEqual(models.QueryStatus.FINISHED, result3[u'status'])
+ self.assertIsNone(result3[u'data'])
+ self.assertIsNone(result3[u'columns'])
+ query3 = self.get_query_by_id(result3[u'query_id'])
+
+ # Check the data in the tmp table.
+ df3 = pd.read_sql_query(sql=query3.select_sql, con=eng)
data3 = df3.to_dict(orient='records')
- self.assertEqual(models.QueryStatus.FINISHED, query3.status)
- self.assertEqual([{'name': 'can_select_star'}], data3)
+ self.assertEqual([{'name': 'can_sql'}], data3)
# Case 4.
- result4 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM ab_permission WHERE id=666', 1,
- tmp_table_name='tmp_4_1').get()
- expected_result4 = {
- 'tmp_table': 'tmp_4_1',
- 'success': True
- }
- self.assertEqual(
- sorted(expected_result4.items()),
- sorted(result4.items())
- )
- session4 = db.create_scoped_session()
- query4 = session4.query(models.Query).filter_by(
- sql='SELECT * FROM ab_permission WHERE id=666').first()
- session4.close()
- df4 = pd.read_sql_query(sql="SELECT * FROM tmp_4_1", con=eng)
- data4 = df4.to_dict(orient='records')
+ # Table and DB exists, CTA call to the backend, no data.
+ sql_empty_result = 'SELECT * FROM ab_user WHERE id=666'
+ result4 = self.run_sql(
+ 1, sql_empty_result, tmp_table='tmp_table_4', cta='True',)
+ self.assertEqual(models.QueryStatus.FINISHED, result4[u'status'])
+ self.assertIsNone(result4[u'data'])
+ self.assertIsNone(result4[u'columns'])
+
+ query4 = self.get_query_by_id(result4[u'query_id'])
self.assertEqual(models.QueryStatus.FINISHED, query4.status)
+ self.assertTrue("SELECT * \nFROM tmp_table_4" in query4.select_sql)
+ self.assertTrue("LIMIT 666" in query4.select_sql)
+ self.assertEqual(
+ "CREATE TABLE tmp_table_4 AS SELECT * FROM ab_user WHERE id=666",
+ query4.executed_sql)
+ self.assertEqual("SELECT * FROM ab_user WHERE id=666", query4.sql)
+ if eng.name != 'sqlite':
+ self.assertEqual(0, query4.rows)
+ self.assertEqual(666, query4.limit)
+ self.assertEqual(False, query4.limit_used)
+ self.assertEqual(True, query4.select_as_cta)
+ self.assertEqual(True, query4.select_as_cta_used)
+
+ # Check the data in the tmp table.
+ df4 = pd.read_sql_query(sql=query4.select_sql, con=eng)
+ data4 = df4.to_dict(orient='records')
self.assertEqual([], data4)
# Case 5.
- # Return the data directly if DB select_as_create_table_as is False.
- main_db.select_as_create_table_as = False
- db.session.commit()
- result5 = tasks.get_sql_results.delay(
- 1, where_query, 1, tmp_table_name='tmp_5_1').get()
- expected_result5 = {
- 'columns': ['name'],
- 'data': [{'name': 'can_select_star'}],
- 'success': True
- }
- self.assertEqual(
- sorted(expected_result5.items()),
- sorted(result5.items())
- )
+ # Table and DB exists, select without CTA.
+ result5 = self.run_sql(1, sql_where, tmp_table='tmp_table_5')
+ self.assertEqual(models.QueryStatus.FINISHED, result5[u'status'])
+ self.assertEqual([u'name'], result5[u'columns'])
+ self.assertEqual([{u'name': u'can_sql'}], result5[u'data'])
- def test_run_async_query_delay(self):
- celery_task1 = tasks.get_sql_results.delay(
- 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_2')
- celery_task2 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_2')
- where_query = (
- "SELECT name FROM ab_permission WHERE name='can_select_star'")
- celery_task3 = tasks.get_sql_results.delay(
- 1, where_query, 1, tmp_table_name='tmp_3_2')
- celery_task4 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM ab_permission WHERE id=666', 1,
- tmp_table_name='tmp_4_2')
+ query5 = self.get_query_by_id(result5[u'query_id'])
+ self.assertEqual(sql_where, query5.sql)
+ if eng.name != 'sqlite':
+ self.assertEqual(1, query5.rows)
+ self.assertEqual(666, query5.limit)
+ self.assertEqual(True, query5.limit_used)
+ self.assertEqual(False, query5.select_as_cta)
+ self.assertEqual(False, query5.select_as_cta_used)
- time.sleep(1)
+ def test_run_async_query(self):
+ main_db = db.session.query(models.Database).filter_by(
+ database_name="main").first()
+ eng = main_db.get_sqla_engine()
- # DB #0 doesn't exist.
- expected_result1 = {
- 'error': 'Database with id 0 is missing.',
- 'success': False
- }
- self.assertEqual(
- sorted(expected_result1.items()),
- sorted(celery_task1.get().items())
- )
- session2 = db.create_scoped_session()
- query2 = session2.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist1').first()
- self.assertEqual(models.QueryStatus.FAILED, query2.status)
- self.assertTrue('error' in celery_task2.get())
- expected_result3 = {
- 'tmp_table': 'tmp_3_2',
- 'success': True
- }
- self.assertEqual(
- sorted(expected_result3.items()),
- sorted(celery_task3.get().items())
- )
- expected_result4 = {
- 'tmp_table': 'tmp_4_2',
- 'success': True
- }
- self.assertEqual(
- sorted(expected_result4.items()),
- sorted(celery_task4.get().items())
- )
+ # Schedule queries
- session = db.create_scoped_session()
- query1 = session.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist').first()
- self.assertIsNone(query1)
- query2 = session.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist1').first()
- self.assertEqual(models.QueryStatus.FAILED, query2.status)
- query3 = session.query(models.Query).filter_by(
- sql=where_query).first()
- self.assertEqual(models.QueryStatus.FINISHED, query3.status)
- query4 = session.query(models.Query).filter_by(
- sql='SELECT * FROM ab_permission WHERE id=666').first()
- self.assertEqual(models.QueryStatus.FINISHED, query4.status)
- session.close()
+ # Case 1.
+ # Table and DB exists, async CTA call to the backend.
+ sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
+ result1 = self.run_sql(
+ 1, sql_where, async='True', tmp_table='tmp_async_1', cta='True')
+ self.assertEqual(models.QueryStatus.SCHEDULED, result1[u'status'])
+
+ # Case 2.
+ # Table and DB exists, async insert query, no CTAs.
+ insert_query = "INSERT INTO ab_role VALUES (9, 'fake_role')"
+ result2 = self.run_sql(1, insert_query, async='True')
+ self.assertEqual(models.QueryStatus.SCHEDULED, result2[u'status'])
+
+ time.sleep(2)
+
+ # Case 1.
+ query1 = self.get_query_by_id(result1[u'query_id'])
+ df1 = pd.read_sql_query(query1.select_sql, con=eng)
+ self.assertEqual(models.QueryStatus.FINISHED, query1.status)
+ self.assertEqual([{'name': 'Admin'}], df1.to_dict(orient='records'))
+ self.assertEqual(models.QueryStatus.FINISHED, query1.status)
+ self.assertTrue("SELECT * \nFROM tmp_async_1" in query1.select_sql)
+ self.assertTrue("LIMIT 666" in query1.select_sql)
+ self.assertEqual(
+ "CREATE TABLE tmp_async_1 AS SELECT name FROM ab_role "
+ "WHERE name='Admin'", query1.executed_sql)
+ self.assertEqual(sql_where, query1.sql)
+ if eng.name != 'sqlite':
+ self.assertEqual(1, query1.rows)
+ self.assertEqual(666, query1.limit)
+ self.assertEqual(False, query1.limit_used)
+ self.assertEqual(True, query1.select_as_cta)
+ self.assertEqual(True, query1.select_as_cta_used)
+
+ # Case 2.
+ query2 = self.get_query_by_id(result2[u'query_id'])
+ self.assertEqual(models.QueryStatus.FINISHED, query2.status)
+ self.assertIsNone(query2.select_sql)
+ self.assertEqual(insert_query, query2.executed_sql)
+ self.assertEqual(insert_query, query2.sql)
+ if eng.name != 'sqlite':
+ self.assertEqual(1, query2.rows)
+ self.assertEqual(666, query2.limit)
+ self.assertEqual(False, query2.limit_used)
+ self.assertEqual(False, query2.select_as_cta)
+ self.assertEqual(False, query2.select_as_cta_used)
if __name__ == '__main__':
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 48d26c16e96..ed33c611e0f 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -321,7 +321,7 @@ class CoreTests(CaravelTestCase):
)
resp = self.client.post(
'/caravel/sql_json/',
- data=dict(database_id=dbid, sql=sql),
+ data=dict(database_id=dbid, sql=sql, select_as_cta=False),
)
self.logout()
return json.loads(resp.data.decode('utf-8'))
@@ -340,9 +340,9 @@ class CoreTests(CaravelTestCase):
db.session.commit()
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
- .join(ab_models.ViewMenu)
- .filter(ab_models.ViewMenu.name == '[main].(id:1)')
- .first()
+ .join(ab_models.ViewMenu)
+ .filter(ab_models.ViewMenu.name == '[main].(id:1)')
+ .first()
)
astronaut = sm.add_role("Astronaut")
sm.add_permission_role(astronaut, main_db_permission_view)