fix: customize column description limit size in db_engine_spec (#34808)

This commit is contained in:
JUST.in DO IT
2025-08-22 10:00:39 -07:00
committed by GitHub
parent 0a45a89786
commit 75af53dc3d
9 changed files with 67 additions and 20 deletions

View File

@@ -148,11 +148,12 @@ def get_columns_description(
try:
with database.get_raw_connection(catalog=catalog, schema=schema) as conn:
cursor = conn.cursor()
query = database.apply_limit_to_sql(query, limit=1)
limit = database.get_column_description_limit_size()
query = database.apply_limit_to_sql(query, limit=limit)
mutated_query = database.mutate_sql_based_on_config(query)
cursor.execute(mutated_query)
db_engine_spec.execute(cursor, mutated_query, database)
result = db_engine_spec.fetch_data(cursor, limit=1)
result = db_engine_spec.fetch_data(cursor, limit=limit)
result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
return result_set.columns
except Exception as ex:

View File

@@ -1978,6 +1978,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
return []
@classmethod
def get_column_description_limit_size(cls) -> int:
    """
    Return the row limit applied to the sample ``SELECT`` query that is
    issued to fetch column metadata.

    Engine specs can override this when a single row is not enough to
    infer column descriptions — TODO confirm against subclasses.

    :return: the number of rows to fetch (defaults to 1)
    """
    return 1
@staticmethod
def pyodbc_rows_to_tuples(data: list[Any]) -> list[tuple[Any, ...]]:
"""

View File

@@ -852,6 +852,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
return script.format()
def get_column_description_limit_size(self) -> int:
    """Delegate to this database's engine spec for the row limit used
    when sampling a query to describe its columns."""
    spec = self.db_engine_spec
    return spec.get_column_description_limit_size()
def safe_sqlalchemy_uri(self) -> str:
return self.sqlalchemy_uri

View File

@@ -135,21 +135,21 @@ class SupersetResultSet:
if data and (not isinstance(data, list) or not isinstance(data[0], tuple)):
data = [tuple(row) for row in data]
array = np.array(data, dtype=numpy_dtype)
if array.size > 0:
for column in column_names:
try:
pa_data.append(pa.array(array[column].tolist()))
except (
pa.lib.ArrowInvalid,
pa.lib.ArrowTypeError,
pa.lib.ArrowNotImplementedError,
ValueError,
TypeError, # this is super hackey,
# https://issues.apache.org/jira/browse/ARROW-7855
):
# attempt serialization of values as strings
stringified_arr = stringify_values(array[column])
pa_data.append(pa.array(stringified_arr.tolist()))
for column in column_names:
try:
pa_data.append(pa.array(array[column].tolist()))
except (
pa.lib.ArrowInvalid,
pa.lib.ArrowTypeError,
pa.lib.ArrowNotImplementedError,
ValueError,
TypeError, # this is super hackey,
# https://issues.apache.org/jira/browse/ARROW-7855
):
# attempt serialization of values as strings
stringified_arr = stringify_values(array[column])
pa_data.append(pa.array(stringified_arr.tolist()))
if pa_data: # pylint: disable=too-many-nested-blocks
for i, column in enumerate(column_names):

View File

@@ -217,7 +217,8 @@ def test_run_sync_query_cta_no_data(test_client):
sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
result = run_sql(test_client, sql_empty_result)
assert QueryStatus.SUCCESS == result["query"]["state"]
assert ([], []) == (result["data"], result["columns"])
assert [] == result["data"]
assert len(result["columns"]) > 0
query = get_query_by_id(result["query"]["serverId"])
assert QueryStatus.SUCCESS == query.status

View File

@@ -718,7 +718,14 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
},
)
assert rv.status_code == 200
assert rv.json["result"]["colnames"] == []
assert rv.json["result"]["colnames"] == [
"col1",
"col2",
"col3",
"col4",
"col5",
"col6",
]
assert rv.json["result"]["rowcount"] == 0

View File

@@ -116,6 +116,10 @@ class SupersetTestCases(SupersetTestCase):
)
assert base_result_expected == base_result
def test_get_column_description_limit_size(self):
    """The base engine spec defaults the column-description row limit to 1."""
    limit = BaseEngineSpec.get_column_description_limit_size()
    assert limit == 1
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_column_datatype_to_string(self):
example_db = get_example_database()

View File

@@ -308,4 +308,4 @@ class TestSupersetResultSet(SupersetTestCase):
("emptytwo", "int", None, None, None, None, True),
]
results = SupersetResultSet(data, cursor_descr, BaseEngineSpec)
assert results.columns == []
assert len(results.columns) == 2

View File

@@ -26,6 +26,7 @@ from pytest_mock import MockerFixture
from superset.db_engine_specs.base import BaseEngineSpec
from superset.result_set import stringify_values, SupersetResultSet
from superset.superset_typing import DbapiResult
def test_column_names_as_bytes() -> None:
@@ -164,3 +165,23 @@ def test_timezone_series(mocker: MockerFixture) -> None:
[pd.Timestamp("2023-01-01 00:00:00+0000", tz="UTC")]
]
logger.exception.assert_not_called()
def test_get_column_description_from_empty_data_using_cursor_description(
    mocker: MockerFixture,
) -> None:
    """
    Test that column metadata is derived from the cursor description
    when the result data is empty, without logging any exception.
    """
    logger = mocker.patch("superset.result_set.logger")
    empty_data: DbapiResult = []
    cursor_description = [
        (b"__time", "datetime", None, None, None, None, 1, 0, 255),
    ]
    result_set = SupersetResultSet(
        empty_data,
        cursor_description,  # type: ignore
        BaseEngineSpec,
    )
    column_names = [col.get("column_name") for col in result_set.columns]
    assert "__time" in column_names
    logger.exception.assert_not_called()