Merge pull request #5023 from timifasubaa/fix_sqllab_commit

[sqllab] force limit queries only when there is no existing limit
This commit is contained in:
timifasubaa
2018-05-31 11:12:46 -07:00
committed by GitHub
5 changed files with 66 additions and 8 deletions

View File

@@ -196,10 +196,9 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records'))
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue('FROM tmp_async_1' in query.select_sql)
self.assertTrue('LIMIT 666' in query.select_sql)
self.assertEqual(
'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role '
"WHERE name='Admin'", query.executed_sql)
"WHERE name='Admin' LIMIT 666", query.executed_sql)
self.assertEqual(sql_where, query.sql)
self.assertEqual(0, query.rows)
self.assertEqual(666, query.limit)
@@ -207,6 +206,33 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)
def test_run_async_query_with_lower_limit(self):
main_db = self.get_main_database(db.session)
eng = main_db.get_sqla_engine()
sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
result = self.run_sql(
main_db.id, sql_where, '5', async='true', tmp_table='tmp_async_2',
cta='true')
assert result['query']['state'] in (
QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS)
time.sleep(1)
query = self.get_query_by_id(result['query']['serverId'])
df = pd.read_sql_query(query.select_sql, con=eng)
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertEqual([{'name': 'Alpha'}], df.to_dict(orient='records'))
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue('FROM tmp_async_2' in query.select_sql)
self.assertEqual(
'CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role '
"WHERE name='Alpha' LIMIT 1", query.executed_sql)
self.assertEqual(sql_where, query.sql)
self.assertEqual(0, query.rows)
self.assertEqual(1, query.limit)
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)
@staticmethod
def de_unicode_dict(d):
def str_if_basestring(o):

View File

@@ -95,6 +95,19 @@ class DbEngineSpecsTestCase(SupersetTestCase):
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
self.assertEquals(expected_sql, limited)
def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
q0 = 'select * from table'
q1 = 'select * from mytable limit 10'
q2 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20'
q3 = 'select * from (select * from my_subquery limit 10);'
q4 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20;'
self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
def test_wrapped_query(self):
self.sql_limit_regex(
'SELECT * FROM a',