mirror of https://github.com/apache/superset.git, synced 2026-05-13 03:45:12 +00:00

Compare commits: fix/mcp-ex...sc-103393- (9 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 89f9ca7364 |  |
|  | b89364531f |  |
|  | 95ad8b81c7 |  |
|  | 6846244d8c |  |
|  | c08706e460 |  |
|  | 5f40f5f1bb |  |
|  | 194831d685 |  |
|  | aaad620cb3 |  |
|  | 6c00ee2eec |  |
```diff
@@ -67,7 +67,7 @@ export const renameOperator: PostProcessingFactory<PostProcessingRename> = (
     [...metricOffsetMap.entries()].forEach(
       ([metricWithOffset, metricOnly]) => {
         const offsetLabel = timeOffsets.find(offset =>
-          metricWithOffset.includes(offset),
+          metricWithOffset.endsWith(`${TIME_COMPARISON_SEPARATOR}${offset}`),
         );
         renamePairs.push([
           formData.comparison_type === ComparisonType.Values
```
```diff
@@ -26,7 +26,7 @@ export const getTimeOffset = (
   timeCompare.find(
     timeOffset =>
       // offset is represented as <offset>, group by list
-      series.name.includes(`${timeOffset},`) ||
+      series.name.startsWith(`${timeOffset},`) ||
       // offset is represented as <metric>__<offset>
       series.name.includes(`__${timeOffset}`) ||
       // offset is represented as <metric>, <offset>
@@ -50,7 +50,9 @@ export const getOriginalSeries = (
   // offset in the middle: <metric>, <offset>, <dimension>
   result = result.replace(`, ${compare},`, ',');
   // offset at start: <offset>, <dimension>
-  result = result.replace(`${compare},`, '');
+  if (result.startsWith(`${compare},`)) {
+    result = result.slice(`${compare},`.length);
+  }
   // offset with double underscore: <metric>__<offset>
   result = result.replace(`__${compare}`, '');
   // offset at end: <metric>, <offset>
```
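The three hunks above fix one failure mode: when two configured offsets share digits, such as `1 year ago` and `11 year ago`, the shorter label is a substring (and a suffix) of the longer one, so `includes`-style matching can pick the wrong offset. Below is a minimal standalone sketch of the collision and of the anchored matching that replaces it, in Python for brevity rather than the TypeScript of the patched files, with the `__` separator taken from the test fixtures that follow:

```python
# Hypothetical stand-ins for the real series/offset values, not Superset code.
offsets = ["1 year ago", "11 year ago"]
series_name = "count(*)__11 year ago"

# Old behaviour: plain substring search returns the first hit, and
# "1 year ago" is a substring of "11 year ago".
loose = next(o for o in offsets if o in series_name)
assert loose == "1 year ago"  # wrong offset

# New behaviour: anchor the comparison to the separator and the string end,
# mirroring endsWith(`${TIME_COMPARISON_SEPARATOR}${offset}`) above.
SEPARATOR = "__"
exact = next(o for o in offsets if series_name.endswith(f"{SEPARATOR}{o}"))
assert exact == "11 year ago"  # correct offset
```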
```diff
@@ -303,6 +303,30 @@ test('should add renameOperator if multiple metrics exist', () => {
   });
 });
 
+test('should correctly match offsets that share a numeric prefix', () => {
+  expect(
+    renameOperator(
+      {
+        ...formData,
+        comparison_type: ComparisonType.Values,
+        time_compare: ['1 year ago', '11 year ago'],
+      },
+      queryObject,
+    ),
+  ).toEqual({
+    operation: 'rename',
+    options: {
+      columns: {
+        'count(*)__1 year ago': '1 year ago',
+        'count(*)__11 year ago': '11 year ago',
+      },
+      inplace: true,
+      level: 0,
+    },
+  });
+});
+
 test('should remove renameOperator', () => {
   expect(
     renameOperator(
```
```diff
@@ -114,3 +114,26 @@ test('hasTimeOffset returns false when series name is not a string', () => {
   const timeCompare = ['1 year ago'];
   expect(hasTimeOffset(series, timeCompare)).toBe(false);
 });
+
+test('getTimeOffset correctly matches offsets that share a numeric prefix', () => {
+  const timeCompare = ['1 year ago', '11 year ago'];
+  expect(
+    getTimeOffset({ name: '11 year ago, Alexander' }, timeCompare),
+  ).toEqual('11 year ago');
+  expect(getTimeOffset({ name: '1 year ago, Alexander' }, timeCompare)).toEqual(
+    '1 year ago',
+  );
+  expect(getTimeOffset({ name: 'Births__11 year ago' }, timeCompare)).toEqual(
+    '11 year ago',
+  );
+});
+
+test('getOriginalSeries correctly strips offsets that share a numeric prefix', () => {
+  const timeCompare = ['1 year ago', '11 year ago'];
+  expect(getOriginalSeries('11 year ago, Alexander', timeCompare)).toEqual(
+    'Alexander',
+  );
+  expect(getOriginalSeries('1 year ago, Alexander', timeCompare)).toEqual(
+    'Alexander',
+  );
+});
```
```diff
@@ -199,12 +199,8 @@ async def get_chart_data( # noqa: C901
 
     if not chart:
         await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
-        safe_id = str(request.identifier)[:200]
         return ChartError(
-            error=(
-                f"No chart found with identifier: {safe_id}."
-                " Use list_charts to get valid chart IDs."
-            ),
+            error=f"No chart found with identifier: {request.identifier}",
             error_type="NotFound",
         )
 
```
```diff
@@ -1192,22 +1192,8 @@ async def _get_chart_preview_internal( # noqa: C901
 
     if not chart:
         await ctx.warning("Chart not found: identifier=%s" % (request.identifier,))
-        safe_id = str(request.identifier)[:200]
-        is_form_data_key = (
-            isinstance(request.identifier, str)
-            and len(request.identifier) > 8
-            and not request.identifier.isdigit()
-        )
-        if is_form_data_key:
-            recovery = (
-                "If using a form_data_key, it may have expired — "
-                "use generate_explore_link to get a fresh key, "
-                "or use list_charts to find a saved chart by ID."
-            )
-        else:
-            recovery = "Use list_charts to get valid chart IDs."
         return ChartError(
-            error=f"No chart found with identifier: {safe_id}. {recovery}",
+            error=f"No chart found with identifier: {request.identifier}",
             error_type="NotFound",
         )
 
```
```diff
@@ -337,18 +337,17 @@ async def update_chart( # noqa: C901
     chart = find_chart_by_identifier(request.identifier)
 
     if not chart:
-        safe_id = str(request.identifier)[:200]
-        not_found_msg = (
-            f"No chart found with identifier: {safe_id}."
-            " Use list_charts to get valid chart IDs."
-        )
         return GenerateChartResponse.model_validate(
             {
                 "chart": None,
                 "error": {
                     "error_type": "NotFound",
-                    "message": not_found_msg,
-                    "details": not_found_msg,
+                    "message": (
+                        f"No chart found with identifier: {request.identifier}"
+                    ),
+                    "details": (
+                        f"No chart found with identifier: {request.identifier}"
+                    ),
                 },
                 "success": False,
                 "schema_version": "2.0",
```
```diff
@@ -334,10 +334,7 @@ def _find_and_authorize_dashboard(
             dashboard=None,
             dashboard_url=None,
             position=None,
-            error=(
-                f"Dashboard with ID {dashboard_id} not found."
-                " Use list_dashboards to get valid dashboard IDs."
-            ),
+            error=f"Dashboard with ID {dashboard_id} not found",
         )
 
     try:
@@ -395,10 +392,7 @@ def add_chart_to_existing_dashboard(
             dashboard=None,
             dashboard_url=None,
             position=None,
-            error=(
-                f"Chart with ID {request.chart_id} not found."
-                " Use list_charts to get valid chart IDs."
-            ),
+            error=f"Chart with ID {request.chart_id} not found",
         )
 
     # Validate dataset access for the chart.
```
```diff
@@ -230,10 +230,7 @@ def generate_dashboard( # noqa: C901
         return GenerateDashboardResponse(
             dashboard=None,
             dashboard_url=None,
-            error=(
-                f"Charts not found: {list(missing_chart_ids)}."
-                " Use list_charts to get valid chart IDs."
-            ),
+            error=f"Charts not found: {list(missing_chart_ids)}",
         )
 
     # Validate dataset access for each chart.
```
```diff
@@ -183,10 +183,7 @@ async def query_dataset( # noqa: C901
     if dataset is None:
         await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
         return DatasetError.create(
-            error=(
-                f"No dataset found with identifier: {request.dataset_id}."
-                " Use list_datasets to get valid dataset IDs."
-            ),
+            error=f"No dataset found with identifier: {request.dataset_id}",
            error_type="NotFound",
        )
 
```
```diff
@@ -100,10 +100,7 @@ async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlRes
         )
         return ExecuteSqlResponse(
             success=False,
-            error=(
-                f"Database with ID {request.database_id} not found."
-                " Use list_databases to get valid database IDs."
-            ),
+            error=f"Database with ID {request.database_id} not found",
             error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR.value,
         )
 
```
```diff
@@ -103,8 +103,7 @@ def open_sql_lab_with_context(
     database = DatabaseDAO.find_by_id(request.database_connection_id)
     if not database:
         error_message = (
-            f"Database with ID {request.database_connection_id} not found."
-            " Use list_databases to get valid database IDs."
+            f"Database with ID {request.database_connection_id} not found"
         )
         return _sanitize_sql_lab_response_for_llm_context(
             SqlLabResponse(
```
```diff
@@ -1887,10 +1887,26 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
                 time_grain
             )
 
-        if join_column_producer and not time_grain:
-            raise QueryObjectValidationError(
-                _("Time Grain must be specified when using Time Shift.")
-            )
+        if not time_grain:
+            has_temporal_join_key = any(
+                pd.api.types.is_datetime64_any_dtype(df[key])
+                for key in join_keys
+                if key in df.columns
+            )
+            if has_temporal_join_key:
+                has_relative_offset = any(
+                    not (
+                        self.is_valid_date_range(offset)
+                        and feature_flag_manager.is_feature_enabled(
+                            "DATE_RANGE_TIMESHIFTS_ENABLED"
+                        )
+                    )
+                    for offset in offset_dfs
+                )
+                if has_relative_offset:
+                    raise QueryObjectValidationError(
+                        _("Time Grain must be specified when using Time Comparison.")
+                    )
 
         for offset, offset_df in offset_dfs.items():
             is_date_range_offset = self.is_valid_date_range(
```
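This hunk relaxes the old blanket requirement: a missing time grain is now only fatal when a join key is actually temporal and at least one offset is relative (date-range offsets behind the `DATE_RANGE_TIMESHIFTS_ENABLED` flag stay allowed). A minimal sketch of just the temporal-key detection, using a hypothetical `needs_time_grain` helper rather than the real `join_offset_dfs` method:

```python
# Standalone sketch of the has_temporal_join_key check above; the helper
# name is hypothetical, not part of Superset's API.
import pandas as pd

def needs_time_grain(df: pd.DataFrame, join_keys: list[str]) -> bool:
    # A time grain is only mandatory when some join key is datetime-typed.
    return any(
        pd.api.types.is_datetime64_any_dtype(df[key])
        for key in join_keys
        if key in df.columns
    )

temporal = pd.DataFrame({"ds": pd.to_datetime(["2021-01-01"]), "D": [1]})
categorical = pd.DataFrame({"country": ["US", "UK"], "metric": [10, 20]})

assert needs_time_grain(temporal, ["ds"])
assert not needs_time_grain(categorical, ["country"])
```

The new `test_join_offset_dfs_*` tests later in this compare exercise these branches, including the case where the temporal key is not the first join key.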
```diff
@@ -3084,14 +3100,6 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
                 sqla_col = self.convert_tbl_column_to_sqla_col(
                     tbl_column=col_obj, template_processor=template_processor
                 )
-                # Parenthesize expression-based columns to prevent operator
-                # precedence issues (e.g. OR in a calculated column breaking
-                # surrounding AND filters). Same pattern as extras.where
-                # wrapping added in PR #38183.
-                if sqla_col is not None and (
-                    (col_obj and col_obj.expression) or is_adhoc_column(flt_col)
-                ):
-                    sqla_col = Grouping(sqla_col)
                 col_type = col_obj.type if col_obj else None
                 col_spec = db_engine_spec.get_column_spec(native_type=col_type)
                 is_list_target = op in (
```
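Note the removed comment's reasoning: SQL's AND binds tighter than OR, so inlining an unparenthesized OR expression into an AND chain changes which rows pass. Python booleans follow the same precedence, which allows a compact standalone illustration; the variables are placeholders for the predicates named in the removed comment:

```python
# Placeholders for SQL predicates; not Superset code.
time_filter = False      # e.g. ds >= '2021-01-01' fails for this row
status_active = False    # status = 'active'
status_pending = True    # status = 'pending'

# Unparenthesized: time_filter AND status_active OR status_pending.
# The OR escapes the AND chain, so the row passes despite the time filter.
assert (time_filter and status_active or status_pending) is True

# What Grouping(sqla_col) produced: the OR stays a single unit.
assert (time_filter and (status_active or status_pending)) is False
```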
```diff
@@ -1782,9 +1782,9 @@ def extract_dataframe_dtypes(
         columns_by_name[column.column_name] = column
 
     generic_types: list[GenericDataType] = []
-    for column in df.columns:
+    for i, column in enumerate(df.columns):
         column_object = columns_by_name.get(str(column))
-        series = df[column]
+        series = df.iloc[:, i]
         inferred_type: str = ""
         if series.isna().all():
             sql_type: Optional[str] = ""
```
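The switch to `enumerate` and positional `iloc` access matters when a result set carries duplicate column names, the case covered by `test_extract_dataframe_dtypes_with_duplicate_columns` at the end of this compare. A short demonstration of the pandas behavior that breaks the label-based version:

```python
# Standalone pandas demonstration, not Superset code.
import pandas as pd

df = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "a"])

# Label lookup on a duplicated name returns a two-column DataFrame, so a
# downstream call like series.isna().all() yields a Series, not a scalar.
assert isinstance(df["a"], pd.DataFrame)

# Positional lookup always yields exactly one column as a Series.
assert isinstance(df.iloc[:, 0], pd.Series)
assert isinstance(df.iloc[:, 2], pd.Series)
```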
```diff
@@ -24,7 +24,7 @@ from flask_babel import gettext as _
 from pandas import DataFrame, MultiIndex
 
 from superset.exceptions import InvalidPostProcessingError
-from superset.utils.core import PostProcessingContributionOrientation
+from superset.utils.core import PostProcessingContributionOrientation, TIME_COMPARISON
 from superset.utils.pandas_postprocessing.utils import validate_column_args
 
 
@@ -130,7 +130,7 @@ def get_column_groups(
     time_shift = None
     if time_shifts and isinstance(col_0, str):
         for ts in time_shifts:
-            if col_0.endswith(ts):
+            if col_0.endswith(TIME_COMPARISON + ts):
                 time_shift = ts
                 break
     if time_shift is not None:
```
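Same collision as in the TypeScript hunks, here on the contribution path: `'a__22 weeks ago'` ends with `'2 weeks ago'`, so a bare `endswith` files the column under the wrong shift group. A standalone sketch, assuming `TIME_COMPARISON` is the `'__'` separator implied by column names like `a__2 weeks ago` in the tests:

```python
# Standalone sketch; the TIME_COMPARISON value is assumed from the fixtures.
TIME_COMPARISON = "__"
col = "a__22 weeks ago"

# Old check: the shorter shift matches as a bare suffix.
assert col.endswith("2 weeks ago")  # true, but the wrong group

# New check: anchoring on the separator keeps the groups distinct.
assert not col.endswith(TIME_COMPARISON + "2 weeks ago")
assert col.endswith(TIME_COMPARISON + "22 weeks ago")
```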
```diff
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 from pandas import DataFrame, Series, Timestamp
 from pandas.testing import assert_frame_equal
 from pytest import fixture, mark  # noqa: PT013
@@ -23,6 +24,7 @@ from superset.common.query_context import QueryContext
 from superset.common.query_context_processor import QueryContextProcessor
 from superset.connectors.sqla.models import BaseDatasource
 from superset.constants import TimeGrain
+from superset.exceptions import QueryObjectValidationError
 from superset.models.helpers import ExploreMixin
 
 # Create processor and bind ExploreMixin methods to datasource
```
```diff
@@ -242,3 +244,55 @@ def test_join_offset_dfs_totals_query_no_dimensions():
     )
 
     assert_frame_equal(expected, result)
+
+
+def test_join_offset_dfs_raises_without_time_grain():
+    """Time comparison with relative offsets requires a time grain."""
+    df = DataFrame({"ds": [Timestamp("2021-01-01")], "D": [1]})
+    offset_df = DataFrame({"ds": [Timestamp("2021-02-01")], "B": [5]})
+    offset_dfs = {"1 year ago": offset_df}
+
+    with pytest.raises(
+        QueryObjectValidationError, match="Time Grain must be specified"
+    ):
+        query_context_processor.join_offset_dfs(
+            df, offset_dfs, time_grain=None, join_keys=["ds"]
+        )
+
+
+def test_join_offset_dfs_allows_non_temporal_join_without_time_grain():
+    """Time comparison without time grain is valid when join keys are non-temporal."""
+    df = DataFrame({"country": ["US", "UK"], "metric": [10, 20]})
+    offset_df = DataFrame({"country": ["US", "UK"], "metric__1 year ago": [8, 15]})
+    offset_dfs = {"1 year ago": offset_df}
+
+    result = query_context_processor.join_offset_dfs(
+        df, offset_dfs, time_grain=None, join_keys=["country"]
+    )
+    assert "metric__1 year ago" in result.columns
+
+
+def test_join_offset_dfs_raises_when_temporal_key_not_first():
+    """Temporal join key detection works even when it's not the first key."""
+    df = DataFrame(
+        {
+            "country": ["US", "UK"],
+            "ds": [Timestamp("2021-01-01"), Timestamp("2021-02-01")],
+            "D": [1, 2],
+        }
+    )
+    offset_df = DataFrame(
+        {
+            "country": ["US", "UK"],
+            "ds": [Timestamp("2021-03-01"), Timestamp("2021-04-01")],
+            "B": [5, 6],
+        }
+    )
+    offset_dfs = {"1 year ago": offset_df}
+
+    with pytest.raises(
+        QueryObjectValidationError, match="Time Grain must be specified"
+    ):
+        query_context_processor.join_offset_dfs(
+            df, offset_dfs, time_grain=None, join_keys=["country", "ds"]
+        )
```
```diff
@@ -298,8 +298,7 @@ class TestOpenSqlLabWithContext:
                 field_path=("title",),
             )
             assert response.error == sanitize_for_llm_context(
-                "Database with ID 404 not found."
-                " Use list_databases to get valid database IDs.",
+                "Database with ID 404 not found",
                 field_path=("error",),
             )
         finally:
```
```diff
@@ -1937,235 +1937,6 @@ def test_extras_having_is_parenthesized(
     )
 
 
-def test_calculated_column_filter_is_parenthesized(
-    database: Database,
-) -> None:
-    """
-    Test that calculated column expressions containing OR are wrapped in
-    parentheses when used in WHERE filters.
-
-    Without parentheses, a calculated column expression like
-    ``status = 'active' OR status = 'pending'`` combined with other filters
-    via AND would produce unexpected evaluation order due to SQL operator
-    precedence (AND binds tighter than OR), potentially dropping time range
-    and other filters. Same class of bug as fixed in PR #38183 for
-    extras.where/having, but on the calculated column filter path.
-    """
-    from superset.connectors.sqla.models import SqlaTable, TableColumn
-
-    table = SqlaTable(
-        database=database,
-        schema=None,
-        table_name="t",
-        columns=[
-            TableColumn(column_name="a", type="INTEGER"),
-            TableColumn(
-                column_name="is_active",
-                expression="status = 'active' OR status = 'pending'",
-                type="BOOLEAN",
-            ),
-        ],
-    )
-
-    sqla_query = table.get_sqla_query(
-        columns=["a"],
-        filter=[
-            {
-                "col": "is_active",
-                "op": "IS TRUE",
-                "val": None,
-            },
-        ],
-        extras={},
-        is_timeseries=False,
-        metrics=[],
-    )
-
-    with database.get_sqla_engine() as engine:
-        sql = str(
-            sqla_query.sqla_query.compile(
-                dialect=engine.dialect,
-                compile_kwargs={"literal_binds": True},
-            )
-        )
-
-    assert "(status = 'active' OR status = 'pending')" in sql, (
-        f"Calculated column expression should be wrapped in parentheses. "
-        f"Generated SQL: {sql}"
-    )
-
-
-def test_calculated_column_nested_or_and_is_parenthesized(
-    database: Database,
-) -> None:
-    """
-    Test that calculated column expressions with nested OR/AND combinations
-    are correctly parenthesized as a single unit in WHERE filters.
-    """
-    from superset.connectors.sqla.models import SqlaTable, TableColumn
-
-    table = SqlaTable(
-        database=database,
-        schema=None,
-        table_name="t",
-        columns=[
-            TableColumn(column_name="a", type="INTEGER"),
-            TableColumn(
-                column_name="is_target",
-                expression=(
-                    "(status = 'active' AND region = 'US') "
-                    "OR (status = 'pending' AND region = 'EU')"
-                ),
-                type="BOOLEAN",
-            ),
-        ],
-    )
-
-    sqla_query = table.get_sqla_query(
-        columns=["a"],
-        filter=[
-            {
-                "col": "is_target",
-                "op": "IS TRUE",
-                "val": None,
-            },
-        ],
-        extras={},
-        is_timeseries=False,
-        metrics=[],
-    )
-
-    with database.get_sqla_engine() as engine:
-        sql = str(
-            sqla_query.sqla_query.compile(
-                dialect=engine.dialect,
-                compile_kwargs={"literal_binds": True},
-            )
-        )
-
-    assert (
-        "((status = 'active' AND region = 'US') "
-        "OR (status = 'pending' AND region = 'EU'))"
-    ) in sql, (
-        f"Nested OR/AND expression should be wrapped in parentheses. "
-        f"Generated SQL: {sql}"
-    )
-
-
-def test_calculated_column_non_boolean_filter_is_parenthesized(
-    database: Database,
-) -> None:
-    """
-    Test that non-boolean calculated column expressions are parenthesized
-    when used with IN filters.
-    """
-    from superset.connectors.sqla.models import SqlaTable, TableColumn
-
-    table = SqlaTable(
-        database=database,
-        schema=None,
-        table_name="t",
-        columns=[
-            TableColumn(column_name="a", type="INTEGER"),
-            TableColumn(
-                column_name="full_name",
-                expression="first_name || ' ' || last_name",
-                type="TEXT",
-            ),
-        ],
-    )
-
-    sqla_query = table.get_sqla_query(
-        columns=["a"],
-        filter=[
-            {
-                "col": "full_name",
-                "op": "IN",
-                "val": ["John Doe", "Jane Doe"],
-            },
-        ],
-        extras={},
-        is_timeseries=False,
-        metrics=[],
-    )
-
-    with database.get_sqla_engine() as engine:
-        sql = str(
-            sqla_query.sqla_query.compile(
-                dialect=engine.dialect,
-                compile_kwargs={"literal_binds": True},
-            )
-        )
-
-    assert "(first_name || ' ' || last_name)" in sql, (
-        f"Non-boolean calculated column should be wrapped in parentheses. "
-        f"Generated SQL: {sql}"
-    )
-
-
-def test_multiple_calculated_columns_each_parenthesized(
-    database: Database,
-) -> None:
-    """
-    Test that multiple calculated columns used as filters are each
-    independently wrapped in parentheses.
-    """
-    from superset.connectors.sqla.models import SqlaTable, TableColumn
-
-    table = SqlaTable(
-        database=database,
-        schema=None,
-        table_name="t",
-        columns=[
-            TableColumn(column_name="a", type="INTEGER"),
-            TableColumn(
-                column_name="is_active",
-                expression="status = 'active' OR status = 'pending'",
-                type="BOOLEAN",
-            ),
-            TableColumn(
-                column_name="is_premium",
-                expression="tier = 'gold' OR tier = 'platinum'",
-                type="BOOLEAN",
-            ),
-        ],
-    )
-
-    sqla_query = table.get_sqla_query(
-        columns=["a"],
-        filter=[
-            {
-                "col": "is_active",
-                "op": "IS TRUE",
-                "val": None,
-            },
-            {
-                "col": "is_premium",
-                "op": "IS TRUE",
-                "val": None,
-            },
-        ],
-        extras={},
-        is_timeseries=False,
-        metrics=[],
-    )
-
-    with database.get_sqla_engine() as engine:
-        sql = str(
-            sqla_query.sqla_query.compile(
-                dialect=engine.dialect,
-                compile_kwargs={"literal_binds": True},
-            )
-        )
-
-    assert "(status = 'active' OR status = 'pending')" in sql, (
-        f"First calculated column should be parenthesized. Generated SQL: {sql}"
-    )
-    assert "(tier = 'gold' OR tier = 'platinum')" in sql, (
-        f"Second calculated column should be parenthesized. Generated SQL: {sql}"
-    )
-
-
 def _run_probe(
     database: Database,
     type_probe_needs_row: bool = False,
```
```diff
@@ -124,3 +124,36 @@ def test_contribution_with_time_shift_columns():
     assert_array_equal(processed_df["a__1 week ago"].tolist(), [0.5, 0.5])
     assert_array_equal(processed_df["b__1 week ago"].tolist(), [0.25, 0.25])
     assert_array_equal(processed_df["c__1 week ago"].tolist(), [0.25, 0.25])
+
+
+def test_contribution_with_numeric_prefix_time_shifts():
+    """Time shifts like '2 weeks ago' and '22 weeks ago' share a numeric suffix;
+    columns must be grouped by their exact offset, not by suffix matching."""
+    df = DataFrame(
+        {
+            DTTM_ALIAS: [
+                datetime(2020, 7, 16, 14, 49),
+                datetime(2020, 7, 16, 14, 50),
+            ],
+            "a": [3, 6],
+            "b": [6, 3],
+            "a__2 weeks ago": [1, 1],
+            "b__2 weeks ago": [1, 1],
+            "a__22 weeks ago": [2, 4],
+            "b__22 weeks ago": [4, 2],
+        }
+    )
+    processed_df = contribution(
+        df,
+        orientation=PostProcessingContributionOrientation.ROW,
+        time_shifts=["2 weeks ago", "22 weeks ago"],
+    )
+    # Non-time-shift columns: a=3,b=6 -> a=1/3, b=2/3; a=6,b=3 -> a=2/3, b=1/3
+    assert_array_equal(processed_df["a"].tolist(), [1 / 3, 2 / 3])
+    assert_array_equal(processed_df["b"].tolist(), [2 / 3, 1 / 3])
+    # "2 weeks ago" group: a=1,b=1 -> 0.5,0.5 each row
+    assert_array_equal(processed_df["a__2 weeks ago"].tolist(), [0.5, 0.5])
+    assert_array_equal(processed_df["b__2 weeks ago"].tolist(), [0.5, 0.5])
+    # "22 weeks ago" group: a=2,b=4 -> 1/3,2/3; a=4,b=2 -> 2/3,1/3
+    assert_array_equal(processed_df["a__22 weeks ago"].tolist(), [1 / 3, 2 / 3])
+    assert_array_equal(processed_df["b__22 weeks ago"].tolist(), [2 / 3, 1 / 3])
```
```diff
@@ -31,6 +31,7 @@ from superset.utils.core import (
     cast_to_boolean,
     check_is_safe_zip,
     DateColumn,
+    extract_dataframe_dtypes,
     FilterOperator,
     generic_find_constraint_name,
     generic_find_fk_constraint_name,
@@ -647,8 +648,9 @@ def test_get_user_agent(mocker: MockerFixture, app_context: None) -> None:
 
 @with_config(
     {
-        "USER_AGENT_FUNC": lambda database,
-        source: f"{database.database_name} {source.name}"
+        "USER_AGENT_FUNC": lambda database, source: (
+            f"{database.database_name} {source.name}"
+        )
     }
 )
 def test_get_user_agent_custom(mocker: MockerFixture, app_context: None) -> None:
@@ -1730,3 +1732,10 @@ def test_markdown_with_markup_wrap() -> None:
 
     assert isinstance(result, Markup)
     assert "<strong>bold</strong>" in str(result)
+
+
+def test_extract_dataframe_dtypes_with_duplicate_columns() -> None:
+    """extract_dataframe_dtypes should not crash on duplicate column names."""
+    df = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "a"])
+    result = extract_dataframe_dtypes(df)
+    assert len(result) == 3
```