diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 31f475f405d..7e3811a953f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -60,7 +60,14 @@ from sqlalchemy import ( ) from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.schema import UniqueConstraint -from sqlalchemy.sql import column, ColumnElement, literal_column, table, text +from sqlalchemy.sql import ( + column, + ColumnElement, + literal_column, + quoted_name, + table, + text, +) from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause from sqlalchemy.sql.selectable import Alias, TableClause @@ -912,16 +919,25 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at self, sqla_col: Column, label: Optional[str] = None ) -> Column: """Takes a sqlalchemy column object and adds label info if supported by engine. + also adds quotes to the column if engine is configured for quotes. :param sqla_col: sqlalchemy column instance :param label: alias/label that column is expected to have :return: either a sql alchemy column or label instance if supported by engine """ label_expected = label or sqla_col.name db_engine_spec = self.db_engine_spec + + # add quotes to column + if db_engine_spec.force_column_alias_quotes: + sqla_col = column( + quoted_name(sqla_col.name, True), sqla_col.type, sqla_col.is_literal + ) + # add quotes to tables if db_engine_spec.allows_alias_in_select: label = db_engine_spec.make_label_compatible(label_expected) sqla_col = sqla_col.label(label) + sqla_col.key = label_expected return sqla_col diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index c609e6ba383..2c561600c47 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -25,6 +25,9 @@ import json import logging from typing import Dict, List from urllib.parse import quote + +from sqlalchemy.sql import column, quoted_name, literal_column +from sqlalchemy import select from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, ) @@ -40,7 +43,7 @@ import pandas as pd import sqlalchemy as sqla from sqlalchemy.exc import SQLAlchemyError from superset.models.cache import CacheKey -from superset.utils.core import get_example_database +from superset.utils.core import get_example_database, get_or_create_db from tests.integration_tests.conftest import with_feature_flags from tests.integration_tests.fixtures.energy_dashboard import ( load_energy_table_with_slice, @@ -898,6 +901,58 @@ class TestCore(SupersetTestCase): rendered_query = str(table.get_from_clause()) self.assertEqual(clean_query, rendered_query) + def test_make_column_compatible(self): + """ + DB Eng Specs: Make column compatible + """ + + # with force_column_alias_quotes enabled + snowflake_database = get_or_create_db("snowflake", "snowflake://") + + table = SqlaTable( + table_name="test_columns_with_alias_quotes", database=snowflake_database, + ) + + col = table.make_sqla_column_compatible(column("foo")) + s = select([col]) + self.assertEqual(str(s), 'SELECT "foo" AS "foo"') + + # with literal_column + table = SqlaTable( + table_name="test_columns_with_alias_quotes_on_literal_column", + database=snowflake_database, + ) + + col = table.make_sqla_column_compatible(literal_column("foo")) + s = select([col]) + self.assertEqual(str(s), 'SELECT foo AS "foo"') + + # with force_column_alias_quotes NOT enabled + postgres_database = get_or_create_db("postgresql", "postgresql://") + + table = SqlaTable( + table_name="test_columns_with_no_quotes", database=postgres_database, + ) + + col = table.make_sqla_column_compatible(column("foo")) + s = select([col]) + self.assertEqual(str(s), "SELECT foo AS foo") + + # with literal_column + table = SqlaTable( + table_name="test_columns_with_no_quotes_on_literal_column", + database=postgres_database, + ) + + col = table.make_sqla_column_compatible(literal_column("foo")) + s = select([col]) + self.assertEqual(str(s), "SELECT foo AS foo") + + # cleanup + db.session.delete(snowflake_database) + db.session.delete(postgres_database) + db.session.commit() + def test_slice_payload_no_datasource(self): self.login(username="admin") data = self.get_json_resp("/superset/explore_json/", raise_on_error=False) diff --git a/tests/integration_tests/db_engine_specs/snowflake_tests.py b/tests/integration_tests/db_engine_specs/snowflake_tests.py index 1be0de08e30..75fe7fd92e3 100644 --- a/tests/integration_tests/db_engine_specs/snowflake_tests.py +++ b/tests/integration_tests/db_engine_specs/snowflake_tests.py @@ -16,6 +16,8 @@ # under the License. import json +from sqlalchemy import column + from superset.db_engine_specs.snowflake import SnowflakeEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.core import Database @@ -23,6 +25,20 @@ from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec class TestSnowflakeDbEngineSpec(TestDbEngineSpec): + def test_snowflake_sqla_column_label(self): + """ + DB Eng Specs (snowflake): Test column label + """ + test_cases = { + "Col": "Col", + "SUM(x)": "SUM(x)", + "SUM[x]": "SUM[x]", + "12345_col": "12345_col", + } + for original, expected in test_cases.items(): + actual = SnowflakeEngineSpec.make_label_compatible(column(original).name) + self.assertEqual(actual, expected) + def test_convert_dttm(self): dttm = self.get_dttm()