Make schema name for the CTA queries and limit configurable (#8867)

* Make schema name configurable

Fixing unit tests

Fix table quoting

Mypy

Split tests out for sqlite

Grant more permissions for mysql user

Postgres doesn't support if not exists

More logging

Commit for table creation

Priviliges for postgres

Update tests

Resolve comments

Lint

No limits for the CTA queries if configures

* CTA -> CTAS and dict -> {}

* Move database creation to the .travis file

* Black

* Move tweaks to travis db setup

* Remove left over version

* Address comments

* Quote table names in the CTAS queries

* Pass tmp_schema_name for the query execution

* Rebase alembic migration

* Switch to python3 mypy

* SQLLAB_CTA_SCHEMA_NAME_FUNC -> SQLLAB_CTAS_SCHEMA_NAME_FUNC

* Black
This commit is contained in:
Bogdan
2020-03-03 09:52:20 -08:00
committed by GitHub
parent 26e916e46b
commit 4e1fa95035
15 changed files with 342 additions and 66 deletions

View File

@@ -17,14 +17,20 @@
# isort:skip_file
"""Unit tests for Superset Celery worker"""
import datetime
import io
import json
import logging
import subprocess
import time
import unittest
import unittest.mock as mock
import flask
import sqlalchemy
from contextlib2 import contextmanager
from flask import current_app
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from tests.test_app import app
from superset import db, sql_lab
@@ -38,11 +44,11 @@ from superset.utils.core import get_example_database
from .base_tests import SupersetTestCase
CELERY_SHORT_SLEEP_TIME = 2
CELERY_SLEEP_TIME = 5
class UtilityFunctionTests(SupersetTestCase):
# TODO(bkyryliuk): support more cases in CTA function.
def test_create_table_as(self):
q = ParsedQuery("SELECT * FROM outer_space;")
@@ -90,6 +96,9 @@ class AppContextTests(SupersetTestCase):
flask._app_ctx_stack.push(popped_app)
CTAS_SCHEMA_NAME = "sqllab_test_db"
class CeleryTestCase(SupersetTestCase):
def get_query_by_name(self, sql):
session = db.session
@@ -159,7 +168,6 @@ class CeleryTestCase(SupersetTestCase):
def test_run_sync_query_cta(self):
main_db = get_example_database()
backend = main_db.backend
db_id = main_db.id
tmp_table_name = "tmp_async_22"
self.drop_table_if_exists(tmp_table_name, main_db)
@@ -172,11 +180,12 @@ class CeleryTestCase(SupersetTestCase):
query2 = self.get_query_by_id(result["query"]["serverId"])
# Check the data in the tmp table.
if backend != "postgresql":
# TODO This test won't work in Postgres
results = self.run_sql(db_id, query2.select_sql, "sdf2134")
self.assertEqual(results["status"], "success")
self.assertGreater(len(results["data"]), 0)
results = self.run_sql(db_id, query2.select_sql, "sdf2134")
self.assertEqual(results["status"], "success")
self.assertGreater(len(results["data"]), 0)
# cleanup tmp table
self.drop_table_if_exists(tmp_table_name, get_example_database())
def test_run_sync_query_cta_no_data(self):
main_db = get_example_database()
@@ -199,15 +208,89 @@ class CeleryTestCase(SupersetTestCase):
db.session.flush()
return self.run_sql(db_id, sql)
def test_run_async_query(self):
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
)
def test_run_sync_query_cta_config(self):
main_db = get_example_database()
db_id = main_db.id
if main_db.backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = "tmp_async_22"
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
self.drop_table_if_exists(expected_full_table_name, main_db)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
result = self.run_sql(
db_id, sql_where, "cid2", tmp_table=tmp_table_name, cta=True
)
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"])
self.assertEqual([], result["data"])
self.assertEqual([], result["columns"])
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(
f"CREATE TABLE {expected_full_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James'",
query.executed_sql,
)
self.assertEqual(
"SELECT *\n" f"FROM {expected_full_table_name}", query.select_sql
)
time.sleep(CELERY_SHORT_SLEEP_TIME)
results = self.run_sql(db_id, query.select_sql)
self.assertEqual(results["status"], "success")
self.drop_table_if_exists(expected_full_table_name, get_example_database())
@mock.patch(
"superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
)
def test_run_async_query_cta_config(self):
main_db = get_example_database()
db_id = main_db.id
if main_db.backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = "sqllab_test_table_async_1"
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{tmp_table_name}"
self.drop_table_if_exists(expected_full_table_name, main_db)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id,
sql_where,
"cid3",
async_=True,
tmp_table="sqllab_test_table_async_1",
cta=True,
)
db.session.close()
time.sleep(CELERY_SLEEP_TIME)
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue(f"FROM {expected_full_table_name}" in query.select_sql)
self.assertEqual(
f"CREATE TABLE {expected_full_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
query.executed_sql,
)
self.drop_table_if_exists(expected_full_table_name, get_example_database())
def test_run_async_cta_query(self):
main_db = get_example_database()
db_id = main_db.id
self.drop_table_if_exists("tmp_async_1", main_db)
table_name = "tmp_async_4"
self.drop_table_if_exists(table_name, main_db)
time.sleep(CELERY_SLEEP_TIME)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True
db_id, sql_where, "cid4", async_=True, tmp_table="tmp_async_4", cta=True
)
db.session.close()
assert result["query"]["state"] in (
@@ -221,9 +304,9 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue("FROM tmp_async_1" in query.select_sql)
self.assertTrue(f"FROM {table_name}" in query.select_sql)
self.assertEqual(
"CREATE TABLE tmp_async_1 AS \n"
f"CREATE TABLE {table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
@@ -234,7 +317,7 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)
def test_run_async_query_with_lower_limit(self):
def test_run_async_cta_query_with_lower_limit(self):
main_db = get_example_database()
db_id = main_db.id
tmp_table = "tmp_async_2"
@@ -242,7 +325,7 @@ class CeleryTestCase(SupersetTestCase):
sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql(
db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True
db_id, sql_where, "id1", async_=True, tmp_table=tmp_table, cta=True
)
db.session.close()
assert result["query"]["state"] in (
@@ -255,14 +338,15 @@ class CeleryTestCase(SupersetTestCase):
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
self.assertIn(f"FROM {tmp_table}", query.select_sql)
self.assertEqual(
f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
self.assertEqual(0, query.rows)
self.assertEqual(1, query.limit)
self.assertEqual(None, query.limit)
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)
@@ -280,9 +364,12 @@ class CeleryTestCase(SupersetTestCase):
with mock.patch.object(
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
) as expand_data:
data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
results, db_engine_spec, False, True
)
(
data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(results, db_engine_spec, False, True)
expand_data.assert_called_once()
self.assertIsInstance(data, list)
@@ -301,9 +388,12 @@ class CeleryTestCase(SupersetTestCase):
with mock.patch.object(
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
) as expand_data:
data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
results, db_engine_spec, True
)
(
data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(results, db_engine_spec, True)
expand_data.assert_not_called()
self.assertIsInstance(data, bytes)
@@ -324,7 +414,12 @@ class CeleryTestCase(SupersetTestCase):
"sql": "SELECT * FROM birth_names LIMIT 100",
"status": QueryStatus.PENDING,
}
serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
(
serialized_data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(
results, db_engine_spec, use_new_deserialization
)
payload = {
@@ -357,7 +452,12 @@ class CeleryTestCase(SupersetTestCase):
"sql": "SELECT * FROM birth_names LIMIT 100",
"status": QueryStatus.PENDING,
}
serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
(
serialized_data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(
results, db_engine_spec, use_new_deserialization
)
payload = {