Mirror of https://github.com/apache/superset.git (synced 2026-05-03 06:54:19 +00:00)

Compare commits: upgrade-sq ... v2021.36.4 (8 commits)
Commits:

- c9545a41e6
- 1a58188232
- 3014da13f1
- 3187c66c3a
- 5cf4d5bb4b
- 5d8e1f5d3a
- 7d35a91642
- cc821bb747
@@ -25,6 +25,9 @@ assists people when migrating to a new version.

 ## Next

 ### Breaking Changes

+- [16711](https://github.com/apache/incubator-superset/pull/16711): The `url_param` Jinja function will now by default escape the result. For instance, the value `O'Brien` will now be changed to `O''Brien`. To disable this behavior, call `url_param` with `escape_result` set to `False`: `url_param("my_key", "my default", escape_result=False)`.
+
 ### Potential Downtime

 ### Deprecations

 ### Other
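The escaping described in the note above is implemented with SQLAlchemy's dialect-aware literal quoting (see the `superset/jinja_context.py` hunks further down). A minimal sketch of the behavior, using the same `String().literal_processor` mechanism; the choice of the `postgresql` dialect mirrors the new integration tests and is otherwise arbitrary:

```python
from sqlalchemy.dialects.postgresql import dialect
from sqlalchemy.types import String


def escape_value(value: str) -> str:
    # literal_processor renders a quoted SQL literal such as 'O''Brien';
    # slicing off the outer quotes keeps only the escaped body, exactly as
    # the new ExtraCache.url_param code path does.
    return String().literal_processor(dialect=dialect())(value=value)[1:-1]


assert escape_value("O'Brien") == "O''Brien"
```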
@@ -613,7 +613,7 @@ describe('async actions', () => {
   describe('queryEditorSetSql', () => {
     describe('with backend persistence flag on', () => {
-      it('does not update the tab state in the backend', () => {
+      it('updates the tab state in the backend', () => {
         expect.assertions(2);

         const sql = 'SELECT * ';
@@ -629,7 +629,7 @@ describe('async actions', () => {
       });
     });
     describe('with backend persistence flag off', () => {
-      it('updates the tab state in the backend', () => {
+      it('does not update the tab state in the backend', () => {
        const backendPersistenceOffMock = jest
          .spyOn(featureFlags, 'isFeatureEnabled')
          .mockImplementation(
@@ -949,6 +949,11 @@ export function queryEditorSetQueryLimit(queryEditor, queryLimit) {

 export function queryEditorSetTemplateParams(queryEditor, templateParams) {
   return function (dispatch) {
+    dispatch({
+      type: QUERY_EDITOR_SET_TEMPLATE_PARAMS,
+      queryEditor,
+      templateParams,
+    });
     const sync = isFeatureEnabled(FeatureFlag.SQLLAB_BACKEND_PERSISTENCE)
       ? SupersetClient.put({
           endpoint: encodeURI(`/tabstateview/${queryEditor.id}`),
@@ -956,24 +961,16 @@ export function queryEditorSetTemplateParams(queryEditor, templateParams) {
         })
       : Promise.resolve();

-    return sync
-      .then(() =>
-        dispatch({
-          type: QUERY_EDITOR_SET_TEMPLATE_PARAMS,
-          queryEditor,
-          templateParams,
-        }),
-      )
-      .catch(() =>
-        dispatch(
-          addDangerToast(
-            t(
-              'An error occurred while setting the tab template parameters. ' +
-                'Please contact your administrator.',
-            ),
-          ),
-        ),
-      );
+    return sync.catch(() =>
+      dispatch(
+        addDangerToast(
+          t(
+            'An error occurred while setting the tab template parameters. ' +
+              'Please contact your administrator.',
+          ),
+        ),
+      ),
+    );
   };
 }
@@ -25,3 +25,13 @@ export interface DatasourcePanelDndItem {
   value: DndItemValue;
   type: DndItemType;
 }
+
+export function isDatasourcePanelDndItem(
+  item: any,
+): item is DatasourcePanelDndItem {
+  return item?.value && item?.type;
+}
+
+export function isSavedMetric(item: any): item is Metric {
+  return item?.metric_name;
+}
@@ -45,10 +45,12 @@ import AdhocMetric from 'src/explore/components/controls/MetricControl/AdhocMetric';
 import {
   DatasourcePanelDndItem,
   DndItemValue,
+  isSavedMetric,
 } from 'src/explore/components/DatasourcePanel/types';
 import { DndItemType } from 'src/explore/components/DndItemType';
 import { ControlComponentProps } from 'src/explore/components/Control';

+const EMPTY_OBJECT = {};
 const DND_ACCEPTED_TYPES = [
   DndItemType.Column,
   DndItemType.Metric,
@@ -78,7 +80,9 @@ export const DndFilterSelect = (props: DndFilterSelectProps) => {
   );
   const [partitionColumn, setPartitionColumn] = useState(undefined);
   const [newFilterPopoverVisible, setNewFilterPopoverVisible] = useState(false);
-  const [droppedItem, setDroppedItem] = useState<DndItemValue | null>(null);
+  const [droppedItem, setDroppedItem] = useState<
+    DndItemValue | typeof EMPTY_OBJECT
+  >({});

   const optionsForSelect = (
     columns: ColumnMeta[],
@@ -342,12 +346,12 @@ export const DndFilterSelect = (props: DndFilterSelectProps) => {
   );

   const handleClickGhostButton = useCallback(() => {
-    setDroppedItem(null);
+    setDroppedItem({});
     togglePopover(true);
   }, [togglePopover]);

   const adhocFilter = useMemo(() => {
-    if (droppedItem?.metric_name) {
+    if (isSavedMetric(droppedItem)) {
       return new AdhocFilter({
         expressionType: EXPRESSION_TYPES.SQL,
         clause: CLAUSES.HAVING,
@@ -34,7 +34,10 @@ import { usePrevious } from 'src/common/hooks/usePrevious';
 import AdhocMetric from 'src/explore/components/controls/MetricControl/AdhocMetric';
 import AdhocMetricPopoverTrigger from 'src/explore/components/controls/MetricControl/AdhocMetricPopoverTrigger';
 import MetricDefinitionValue from 'src/explore/components/controls/MetricControl/MetricDefinitionValue';
-import { DatasourcePanelDndItem } from 'src/explore/components/DatasourcePanel/types';
+import {
+  DatasourcePanelDndItem,
+  isDatasourcePanelDndItem,
+} from 'src/explore/components/DatasourcePanel/types';
 import { DndItemType } from 'src/explore/components/DndItemType';
 import DndSelectLabel from 'src/explore/components/controls/DndColumnSelectControl/DndSelectLabel';
 import { savedMetricType } from 'src/explore/components/controls/MetricControl/types';
@@ -143,9 +146,9 @@ export const DndMetricSelect = (props: any) => {
   const [value, setValue] = useState<ValueType[]>(
     coerceAdhocMetrics(props.value),
   );
-  const [droppedItem, setDroppedItem] = useState<DatasourcePanelDndItem | null>(
-    null,
-  );
+  const [droppedItem, setDroppedItem] = useState<
+    DatasourcePanelDndItem | typeof EMPTY_OBJECT
+  >({});
   const [newMetricPopoverVisible, setNewMetricPopoverVisible] = useState(false);
   const prevColumns = usePrevious(columns);
   const prevSavedMetrics = usePrevious(savedMetrics);
@@ -323,13 +326,16 @@ export const DndMetricSelect = (props: any) => {
   );

   const handleClickGhostButton = useCallback(() => {
-    setDroppedItem(null);
+    setDroppedItem({});
     togglePopover(true);
   }, [togglePopover]);

   const adhocMetric = useMemo(() => {
-    if (droppedItem?.type === DndItemType.Column) {
-      const itemValue = droppedItem?.value as ColumnMeta;
+    if (
+      isDatasourcePanelDndItem(droppedItem) &&
+      droppedItem.type === DndItemType.Column
+    ) {
+      const itemValue = droppedItem.value as ColumnMeta;
       const config: Partial<AdhocMetric> = {
         column: { column_name: itemValue?.column_name },
       };
@@ -31,44 +31,43 @@ import {
   SequentialScheme,
 } from '@superset-ui/core';
 import superset from '@superset-ui/core/lib/color/colorSchemes/categorical/superset';
+import ColorSchemeRegistry from '@superset-ui/core/lib/color/ColorSchemeRegistry';
+
+function registerColorSchemes(
+  registry: ColorSchemeRegistry<unknown>,
+  colorSchemes: (CategoricalScheme | SequentialScheme)[],
+  standardDefaultKey: string,
+) {
+  colorSchemes.forEach(scheme => {
+    registry.registerValue(scheme.id, scheme);
+  });
+
+  const defaultKey =
+    colorSchemes.find(scheme => scheme.isDefault)?.id || standardDefaultKey;
+  registry.setDefaultKey(defaultKey);
+}

 export default function setupColors(
-  extraCategoricalColorSchemas: CategoricalScheme[] = [],
+  extraCategoricalColorSchemes: CategoricalScheme[] = [],
   extraSequentialColorSchemes: SequentialScheme[] = [],
 ) {
   // Register color schemes
-  const categoricalSchemeRegistry = getCategoricalSchemeRegistry();
-
-  if (extraCategoricalColorSchemas?.length > 0) {
-    extraCategoricalColorSchemas.forEach(scheme => {
-      categoricalSchemeRegistry.registerValue(scheme.id, scheme);
-    });
-  }
-
-  [superset, airbnb, categoricalD3, echarts, google, lyft, preset].forEach(
-    group => {
-      group.forEach(scheme => {
-        categoricalSchemeRegistry.registerValue(scheme.id, scheme);
-      });
-    },
+  registerColorSchemes(
+    getCategoricalSchemeRegistry(),
+    [
+      ...superset,
+      ...airbnb,
+      ...categoricalD3,
+      ...echarts,
+      ...google,
+      ...lyft,
+      ...preset,
+      ...extraCategoricalColorSchemes,
+    ],
+    'supersetColors',
   );
+  registerColorSchemes(
+    getSequentialSchemeRegistry(),
+    [...sequentialCommon, ...sequentialD3, ...extraSequentialColorSchemes],
+    'superset_seq_1',
+  );
-  categoricalSchemeRegistry.setDefaultKey('supersetColors');
-
-  const sequentialSchemeRegistry = getSequentialSchemeRegistry();
-
-  if (extraSequentialColorSchemes?.length > 0) {
-    extraSequentialColorSchemes.forEach(scheme => {
-      sequentialSchemeRegistry.registerValue(
-        scheme.id,
-        new SequentialScheme(scheme),
-      );
-    });
-  }
-
-  [sequentialCommon, sequentialD3].forEach(group => {
-    group.forEach(scheme => {
-      sequentialSchemeRegistry.registerValue(scheme.id, scheme);
-    });
-  });
-  sequentialSchemeRegistry.setDefaultKey('superset_seq_1');
 }
@@ -908,14 +908,14 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
       />
     );
   }

-  const message: Array<string> = Object.values(dbErrors);
+  const message: Array<string> =
+    typeof dbErrors === 'object' ? Object.values(dbErrors) : [];
   return (
     <Alert
       type="error"
       css={(theme: SupersetTheme) => antDErrorAlertStyles(theme)}
       message="Database Creation Error"
-      description={message[0]}
+      description={message?.[0] || dbErrors}
     />
   );
 };
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import copy
-import math
 from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING

 from flask_babel import _
@@ -131,15 +130,12 @@ def _get_samples(
     query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
 ) -> Dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
-    row_limit = query_obj.row_limit or math.inf
     query_obj = copy.copy(query_obj)
     query_obj.is_timeseries = False
     query_obj.orderby = []
     query_obj.groupby = []
     query_obj.metrics = []
    query_obj.post_processing = []
-    query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
-    query_obj.row_offset = 0
     query_obj.columns = [o.column_name for o in datasource.columns]
     return _get_full(query_context, query_obj, force_cached)
@@ -100,11 +100,11 @@ class QueryContext:
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.queries = [QueryObject(**query_obj) for query_obj in queries]
         self.force = force
         self.custom_cache_timeout = custom_cache_timeout
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
+        self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING

 from flask_babel import gettext as _
 from pandas import DataFrame
@@ -28,6 +28,7 @@ from superset.exceptions import QueryObjectValidationError
 from superset.typing import Metric, OrderBy
 from superset.utils import pandas_postprocessing
 from superset.utils.core import (
+    apply_max_row_limit,
     ChartDataResultType,
     DatasourceDict,
     DTTM_ALIAS,
@@ -41,6 +42,10 @@ from superset.utils.date_parser import get_since_until, parse_human_timedelta
 from superset.utils.hashing import md5_sha_from_dict
 from superset.views.utils import get_time_range_endpoints

+if TYPE_CHECKING:
+    from superset.common.query_context import QueryContext  # pragma: no cover
+
+
 config = app.config
 logger = logging.getLogger(__name__)

@@ -100,6 +105,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes

     def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
+        query_context: "QueryContext",
         datasource: Optional[DatasourceDict] = None,
         result_type: Optional[ChartDataResultType] = None,
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
@@ -138,7 +144,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.result_type = result_type
+        self.result_type = result_type or query_context.result_type
         self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
         self.annotation_layers = [
             layer
@@ -180,7 +186,12 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             for x in metrics
         ]

-        self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
+        default_row_limit = (
+            config["SAMPLES_ROW_LIMIT"]
+            if self.result_type == ChartDataResultType.SAMPLES
+            else config["ROW_LIMIT"]
+        )
+        self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
         self.row_offset = row_offset or 0
         self.filter = filters or []
         self.timeseries_limit = timeseries_limit
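Taken together with the `_get_samples` hunk above, row-limit handling now lives in `QueryObject.__init__`: the fallback limit depends on the result type, and every resolved limit is capped. A standalone sketch of that resolution, with config values mirroring the defaults in `superset/config.py` (next section) and a local stand-in for `apply_max_row_limit`:

```python
from typing import Optional

# illustrative stand-ins for the superset/config.py defaults
SAMPLES_ROW_LIMIT, ROW_LIMIT, SQL_MAX_ROW = 1000, 50000, 100000


def resolve_row_limit(row_limit: Optional[int], result_type: str) -> int:
    default_row_limit = SAMPLES_ROW_LIMIT if result_type == "samples" else ROW_LIMIT
    requested = row_limit or default_row_limit
    # apply_max_row_limit: non-zero requests are capped at SQL_MAX_ROW
    return min(SQL_MAX_ROW, requested) if requested != 0 else SQL_MAX_ROW


assert resolve_row_limit(None, "samples") == 1000  # samples fall back to SAMPLES_ROW_LIMIT
assert resolve_row_limit(None, "full") == 50000  # other queries fall back to ROW_LIMIT
assert resolve_row_limit(10_000_000, "full") == 100_000  # SQL_MAX_ROW caps any request
```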
@@ -115,9 +115,9 @@ VERSION_SHA = _try_json_readsha(VERSION_INFO_FILE, VERSION_SHA_LENGTH)
 # default viz used in chart explorer
 DEFAULT_VIZ_TYPE = "table"

+# default row limit when requesting chart data
 ROW_LIMIT = 50000
 VIZ_ROW_LIMIT = 10000
-# max rows retreieved when requesting samples from datasource in explore view
+# default row limit when requesting samples from datasource in explore view
 SAMPLES_ROW_LIMIT = 1000
 # max rows retrieved by filter select auto complete
 FILTER_SELECT_ROW_LIMIT = 10000
@@ -427,7 +427,14 @@ FEATURE_FLAGS: Dict[str, bool] = {}
 # feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
 # return feature_flags_dict
 GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
+
+# A function that receives a feature flag name and an optional default value.
+# Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the
+# evaluation of all feature flags when just evaluating a single one.
+#
+# Note that the default `get_feature_flags` will evaluate each feature with this
+# callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC
+# and IS_FEATURE_ENABLED_FUNC in conjunction.
+IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
 # A function that expands/overrides the frontend `bootstrap_data.common` object.
 # Can be used to implement custom frontend functionality,
 # or dynamically change certain configs.
@@ -449,6 +456,7 @@ COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
 #     "id": 'myVisualizationColors',
 #     "description": '',
 #     "label": 'My Visualization Colors',
+#     "isDefault": True,
 #     "colors":
 #      ['#006699', '#009DD9', '#5AAA46', '#44AAAA', '#DDAA77', '#7799BB', '#88AA77',
 #       '#552288', '#5AAA46', '#CC7788', '#EEDD55', '#9977BB', '#BBAA44', '#DDCCDD']
@@ -483,6 +491,7 @@ THEME_OVERRIDES: Dict[str, Any] = {}
 #     "description": '',
 #     "isDiverging": True,
 #     "label": 'My custom warm to hot',
+#     "isDefault": True,
 #     "colors":
 #      ['#552288', '#5AAA46', '#CC7788', '#EEDD55', '#9977BB', '#BBAA44', '#DDCCDD',
 #       '#006699', '#009DD9', '#5AAA46', '#44AAAA', '#DDAA77', '#7799BB', '#88AA77']
@@ -656,9 +665,7 @@ QUERY_LOGGER = None
 # Set this API key to enable Mapbox visualizations
 MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")

-# Maximum number of rows returned from a database
-# in async mode, no more than SQL_MAX_ROW will be returned and stored
-# in the results backend. This also becomes the limit when exporting CSVs
+# Maximum number of rows returned for any analytical database query
 SQL_MAX_ROW = 100000

 # Maximum number of rows displayed in SQL Lab UI
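For reference, a hypothetical `superset_config.py` snippet showing how the new `IS_FEATURE_ENABLED_FUNC` hook might be wired up; the flag-name prefix and lookup logic here are invented for illustration and are not part of the change:

```python
from typing import Optional


def custom_is_feature_enabled(feature: str, default: Optional[bool] = None) -> bool:
    # consult some per-flag source (for example an external flag service)
    # instead of materializing every flag, as GET_FEATURE_FLAGS_FUNC must
    if feature.startswith("BETA_"):
        return True
    return bool(default)


IS_FEATURE_ENABLED_FUNC = custom_is_feature_enabled
```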
@@ -72,6 +72,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
             database.db_engine_spec.mutate_db_for_connection_test(database)
             username = self._actor.username if self._actor is not None else None
             engine = database.get_sqla_engine(user_name=username)
+            event_logger.log_with_context(
+                action="test_connection_attempt",
+                engine=database.db_engine_spec.__name__,
+            )
             with closing(engine.raw_connection()) as conn:
                 try:
                     alive = engine.dialect.do_ping(conn)
@@ -33,6 +33,7 @@ from superset.databases.dao import DatabaseDAO
 from superset.db_engine_specs import get_engine_specs
 from superset.db_engine_specs.base import BasicParametersMixin
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.extensions import event_logger
 from superset.models.core import Database

 BYPASS_VALIDATION_ENGINES = {"bigquery"}
@@ -89,6 +90,7 @@ class ValidateDatabaseParametersCommand(BaseCommand):
             self._properties.get("parameters", {})
         )
         if errors:
+            event_logger.log_with_context(action="validation_error", engine=engine)
             raise InvalidParametersError(errors)

         serialized_encrypted_extra = self._properties.get("encrypted_extra", "{}")
@@ -34,6 +34,8 @@ from flask import current_app, g, has_request_context, request
 from flask_babel import gettext as _
 from jinja2 import DebugUndefined
 from jinja2.sandbox import SandboxedEnvironment
+from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.types import String
 from typing_extensions import TypedDict

 from superset.exceptions import SupersetTemplateException
@@ -95,9 +97,11 @@ class ExtraCache:
         self,
         extra_cache_keys: Optional[List[Any]] = None,
         removed_filters: Optional[List[str]] = None,
+        dialect: Optional[Dialect] = None,
     ):
         self.extra_cache_keys = extra_cache_keys
         self.removed_filters = removed_filters if removed_filters is not None else []
+        self.dialect = dialect

     def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
         """
@@ -145,7 +149,11 @@ class ExtraCache:
         return key

     def url_param(
-        self, param: str, default: Optional[str] = None, add_to_cache_keys: bool = True
+        self,
+        param: str,
+        default: Optional[str] = None,
+        add_to_cache_keys: bool = True,
+        escape_result: bool = True,
     ) -> Optional[str]:
         """
         Read a url or post parameter and use it in your SQL Lab query.
@@ -166,6 +174,7 @@ class ExtraCache:
         :param param: the parameter to lookup
         :param default: the value to return in the absence of the parameter
         :param add_to_cache_keys: Whether the value should be included in the cache key
+        :param escape_result: Should special characters in the result be escaped
         :returns: The URL parameters
         """

@@ -178,6 +187,11 @@ class ExtraCache:
         form_data, _ = get_form_data()
         url_params = form_data.get("url_params") or {}
         result = url_params.get(param, default)
+        if result and escape_result and self.dialect:
+            # use the dialect specific quoting logic to escape string
+            result = String().literal_processor(dialect=self.dialect)(value=result)[
+                1:-1
+            ]
         if add_to_cache_keys:
             self.cache_key_wrapper(result)
         return result
@@ -430,7 +444,11 @@ class BaseTemplateProcessor:

 class JinjaTemplateProcessor(BaseTemplateProcessor):
     def set_context(self, **kwargs: Any) -> None:
         super().set_context(**kwargs)
-        extra_cache = ExtraCache(self._extra_cache_keys, self._removed_filters)
+        extra_cache = ExtraCache(
+            extra_cache_keys=self._extra_cache_keys,
+            removed_filters=self._removed_filters,
+            dialect=self._database.get_dialect(),
+        )
         self._context.update(
             {
                 "url_param": partial(safe_proxy, extra_cache.url_param),
@@ -1761,3 +1761,25 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
         return bool(strtobool(bool_str.lower()))
     except ValueError:
         return False
+
+
+def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
+    """
+    Override row limit if max global limit is defined
+
+    :param limit: requested row limit
+    :param max_limit: Maximum allowed row limit
+    :return: Capped row limit
+
+    >>> apply_max_row_limit(100000, 10)
+    10
+    >>> apply_max_row_limit(10, 100000)
+    10
+    >>> apply_max_row_limit(0, 10000)
+    10000
+    """
+    if max_limit is None:
+        max_limit = current_app.config["SQL_MAX_ROW"]
+    if limit != 0:
+        return min(max_limit, limit)
+    return max_limit
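A self-contained restatement of the helper, so its doctest semantics can be checked without a Flask application context; the default of 100000 mirrors the `SQL_MAX_ROW` setting from `superset/config.py` above:

```python
def apply_max_row_limit(limit: int, max_limit: int = 100000) -> int:
    # limit == 0 means "no limit requested": fall back to the global cap
    return min(max_limit, limit) if limit != 0 else max_limit


assert apply_max_row_limit(100000, 10) == 10  # a cap below the request wins
assert apply_max_row_limit(10, 100000) == 10  # a request below the cap passes through
assert apply_max_row_limit(0, 10000) == 10000  # zero falls back to the cap
```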
@@ -24,24 +24,36 @@ class FeatureFlagManager:
     def __init__(self) -> None:
         super().__init__()
         self._get_feature_flags_func = None
+        self._is_feature_enabled_func = None
         self._feature_flags: Dict[str, Any] = {}

     def init_app(self, app: Flask) -> None:
         self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"]
+        self._is_feature_enabled_func = app.config["IS_FEATURE_ENABLED_FUNC"]
         self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
         self._feature_flags.update(app.config["FEATURE_FLAGS"])

     def get_feature_flags(self) -> Dict[str, Any]:
         if self._get_feature_flags_func:
             return self._get_feature_flags_func(deepcopy(self._feature_flags))
+
+        if callable(self._is_feature_enabled_func):
+            return dict(
+                map(
+                    lambda kv: (kv[0], self._is_feature_enabled_func(kv[0], kv[1])),
+                    self._feature_flags.items(),
+                )
+            )
+
         return self._feature_flags

     def is_feature_enabled(self, feature: str) -> bool:
         """Utility function for checking whether a feature is turned on"""
+        if self._is_feature_enabled_func:
+            return (
+                self._is_feature_enabled_func(feature, self._feature_flags[feature])
+                if feature in self._feature_flags
+                else False
+            )
+
         feature_flags = self.get_feature_flags()

         if feature_flags and feature in feature_flags:
             return feature_flags[feature]

         return False
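A condensed, standalone sketch of the per-flag lookup order the manager now follows; `flags` and `hook` stand in for the manager's internal state, and the expectations mirror the new `TestFeatureFlagBackend` tests further down:

```python
from typing import Any, Callable, Dict, Optional

flags: Dict[str, Any] = {"True_Flag1": False, "Flag4": True}


def dummy_is_feature_enabled(name: str, default: Optional[bool] = None) -> bool:
    # same shape as the IS_FEATURE_ENABLED_FUNC hook exercised in the tests
    return True if name.startswith("True_") else bool(default)


def is_feature_enabled(
    feature: str, hook: Optional[Callable[[str, Optional[bool]], bool]] = None
) -> bool:
    if hook:
        # the per-flag hook wins, but only for flags that are declared at all
        return hook(feature, flags[feature]) if feature in flags else False
    return bool(flags.get(feature, False))


assert is_feature_enabled("True_Flag1", dummy_is_feature_enabled) is True
assert is_feature_enabled("Flag4", dummy_is_feature_enabled) is True
assert is_feature_enabled("True_DoesNotExist", dummy_is_feature_enabled) is False
```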
@@ -23,10 +23,11 @@ from typing import Any, cast, Dict, Optional, TYPE_CHECKING

 from flask import g

-from superset import app, is_feature_enabled
+from superset import is_feature_enabled
 from superset.models.sql_lab import Query
 from superset.sql_parse import CtasMethod
 from superset.utils import core as utils
+from superset.utils.core import apply_max_row_limit
 from superset.utils.dates import now_as_float
 from superset.views.utils import get_cta_schema_name

@@ -97,7 +98,7 @@ class SqlJsonExecutionContext:  # pylint: disable=too-many-instance-attributes

     @staticmethod
     def _get_limit_param(query_params: Dict[str, Any]) -> int:
-        limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
+        limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
         if limit < 0:
             logger.warning(
                 "Invalid limit of %i specified. Defaulting to max limit.", limit
@@ -107,7 +107,7 @@ from superset.typing import FlaskResponse
 from superset.utils import core as utils, csv
 from superset.utils.async_query_manager import AsyncQueryTokenException
 from superset.utils.cache import etag_cache
-from superset.utils.core import ReservedUrlParameters
+from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
 from superset.utils.dates import now_as_float
 from superset.utils.decorators import check_dashboard_access
 from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
@@ -898,8 +898,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return json_error_response(DATASOURCE_MISSING_ERR)

         datasource.raise_for_access()
+        row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
         payload = json.dumps(
-            datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
+            datasource.values_for_column(column, row_limit),
             default=utils.json_int_dttm_ser,
             ignore_nan=True,
         )
@@ -21,7 +21,7 @@ These objects represent the backend of all the visualizations that
 Superset can render.
 """
 import copy
-import inspect
+import dataclasses
 import logging
 import math
 import re
@@ -70,6 +70,7 @@ from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
 from superset.utils import core as utils, csv
 from superset.utils.cache import set_and_log_cache
 from superset.utils.core import (
+    apply_max_row_limit,
     DTTM_ALIAS,
     ExtraFiltersReasonType,
     JS_MAX_INTEGER,
@@ -81,9 +82,6 @@ from superset.utils.date_parser import get_since_until, parse_past_timedelta
 from superset.utils.dates import datetime_to_epoch
 from superset.utils.hashing import md5_sha_from_str

-import dataclasses  # isort:skip
-
-
 if TYPE_CHECKING:
     from superset.connectors.base.models import BaseDatasource

@@ -110,7 +108,7 @@ METRIC_KEYS = [
 FILTER_VALUES_REGEX = re.compile(r"filter_values\(['\"](\w+)['\"]\,")


-class BaseViz:
+class BaseViz:  # pylint: disable=too-many-public-methods

     """All visualizations derive this base class"""

@@ -332,6 +330,7 @@ class BaseViz:
         limit = int(form_data.get("limit") or 0)
         timeseries_limit_metric = form_data.get("timeseries_limit_metric")
         row_limit = int(form_data.get("row_limit") or config["ROW_LIMIT"])
+        row_limit = apply_max_row_limit(row_limit)

         # default order direction
         order_desc = form_data.get("order_desc", True)
@@ -556,7 +555,7 @@ class BaseViz:
             )
             self.errors.append(error)
             self.status = utils.QueryStatus.FAILED
-        except Exception as ex:
+        except Exception as ex:  # pylint: disable=broad-except
             logger.exception(ex)

             error = dataclasses.asdict(
@@ -625,7 +624,7 @@ class BaseViz:
         include_index = not isinstance(df.index, pd.RangeIndex)
         return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"])

-    def get_data(self, df: pd.DataFrame) -> VizData:
+    def get_data(self, df: pd.DataFrame) -> VizData:  # pylint: disable=no-self-use
         return df.to_dict(orient="records")

     @property
@@ -1242,7 +1241,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
         d["orderby"] = [(sort_by, is_asc)]
         return d

-    def to_series(
+    def to_series(  # pylint: disable=too-many-branches
         self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
     ) -> List[Dict[str, Any]]:
         cols = []
@@ -1446,6 +1445,7 @@ class MultiLineViz(NVD3Viz):
         return {}

     def get_data(self, df: pd.DataFrame) -> VizData:
+        # pylint: disable=import-outside-toplevel,too-many-locals
         multiline_fd = self.form_data
         # Late import to avoid circular import issues
         from superset.charts.dao import ChartDAO
@@ -1669,19 +1669,20 @@ class HistogramViz(BaseViz):

     def query_obj(self) -> QueryObjectDict:
         """Returns the query object for this visualization"""
-        d = super().query_obj()
-        d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"]))
+        query_obj = super().query_obj()
         numeric_columns = self.form_data.get("all_columns_x")
         if numeric_columns is None:
             raise QueryObjectValidationError(
                 _("Must have at least one numeric column specified")
             )
-        self.columns = numeric_columns
-        d["columns"] = numeric_columns + self.groupby
+        self.columns = (  # pylint: disable=attribute-defined-outside-init
+            numeric_columns
+        )
+        query_obj["columns"] = numeric_columns + self.groupby
         # override groupby entry to avoid aggregation
-        d["groupby"] = None
-        d["metrics"] = None
-        return d
+        query_obj["groupby"] = None
+        query_obj["metrics"] = None
+        return query_obj

     def labelify(self, keys: Union[List[str], str], column: str) -> str:
         if isinstance(keys, str):
@@ -1751,7 +1752,7 @@ class DistributionBarViz(BaseViz):

         return d

-    def get_data(self, df: pd.DataFrame) -> VizData:
+    def get_data(self, df: pd.DataFrame) -> VizData:  # pylint: disable=too-many-locals
         if df.empty:
             return None

@@ -2061,6 +2062,7 @@ class FilterBoxViz(BaseViz):
         return {}

     def run_extra_queries(self) -> None:
+        # pylint: disable=import-outside-toplevel
         from superset.common.query_context import QueryContext

         qry = super().query_obj()
@@ -2373,6 +2375,7 @@ class DeckGLMultiLayer(BaseViz):
     def get_data(self, df: pd.DataFrame) -> VizData:
         fd = self.form_data
         # Late imports to avoid circular import issues
+        # pylint: disable=import-outside-toplevel
         from superset import db
         from superset.models.slice import Slice

@@ -2393,6 +2396,7 @@ class BaseDeckGLViz(BaseViz):
     spatial_control_keys: List[str] = []

     def get_metrics(self) -> List[str]:
+        # pylint: disable=attribute-defined-outside-init
         self.metric = self.form_data.get("size")
         return [self.metric] if self.metric else []

@@ -2557,15 +2561,18 @@ class DeckScatterViz(BaseDeckGLViz):
     is_timeseries = True

     def query_obj(self) -> QueryObjectDict:
-        fd = self.form_data
-        self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
-        self.point_radius_fixed = fd.get("point_radius_fixed") or {
+        # pylint: disable=attribute-defined-outside-init
+        self.is_timeseries = bool(
+            self.form_data.get("time_grain_sqla") or self.form_data.get("granularity")
+        )
+        self.point_radius_fixed = self.form_data.get("point_radius_fixed") or {
             "type": "fix",
             "value": 500,
         }
         return super().query_obj()

     def get_metrics(self) -> List[str]:
+        # pylint: disable=attribute-defined-outside-init
         self.metric = None
         if self.point_radius_fixed.get("type") == "metric":
             self.metric = self.point_radius_fixed["value"]
@@ -28,9 +28,11 @@ import pytest
 from flask import Response
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_testing import TestCase
+from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.ext.declarative.api import DeclarativeMeta
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import func
+from sqlalchemy.dialects.mysql import dialect

 from tests.integration_tests.test_app import app
 from superset.sql_parse import CtasMethod
@@ -422,7 +424,7 @@ class SupersetTestCase(TestCase):
         self.login(username="admin")
         database_name = "db_for_macros_testing"
         db_id = 200
-        return self.get_or_create(
+        database = self.get_or_create(
             cls=models.Database,
             criteria={"database_name": database_name},
             session=db.session,
@@ -430,7 +432,14 @@ class SupersetTestCase(TestCase):
             id=db_id,
         )

-    def delete_fake_db_for_macros(self):
+        def mock_get_dialect() -> Dialect:
+            return dialect()
+
+        database.get_dialect = mock_get_dialect
+        return database
+
+    @staticmethod
+    def delete_fake_db_for_macros():
         database = (
             db.session.query(Database)
             .filter(Database.database_name == "db_for_macros_testing")
@@ -18,7 +18,7 @@
 """Unit tests for Superset"""
 import json
 import unittest
-from datetime import datetime, timedelta
+from datetime import datetime
 from io import BytesIO
 from typing import Optional
 from unittest import mock
@@ -1203,6 +1203,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
         del request_payload["queries"][0]["row_limit"]
+
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
@@ -1210,11 +1211,46 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch(
-        "superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
     )
-    def test_chart_data_default_sample_limit(self):
+    def test_chart_data_sql_max_row_limit(self):
         """
-        Chart data API: Ensure sample response row count doesn't exceed default limit
+        Chart data API: Ensure row count doesn't exceed max global row limit
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 10)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+    )
+    def test_chart_data_sample_default_limit(self):
+        """
+        Chart data API: Ensure sample response row count defaults to config defaults
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        del request_payload["queries"][0]["row_limit"]
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_actions.config",
+        {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
+    )
+    def test_chart_data_sample_custom_limit(self):
+        """
+        Chart data API: Ensure requested sample response row count is between
+        default and SQL max row limit
+        """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
@@ -1223,6 +1259,24 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
         self.assertEqual(result["rowcount"], 10)

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
+    )
+    def test_chart_data_sql_max_row_sample_limit(self):
+        """
+        Chart data API: Ensure requested sample response row count doesn't
+        exceed SQL max row limit
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+
     def test_chart_data_incorrect_result_type(self):
@@ -16,17 +16,25 @@
 # under the License.
 # isort:skip_file
 """Unit tests for Superset"""
 from typing import Any, Dict, Tuple
 from unittest import mock

+import pytest
+
 from marshmallow import ValidationError
 from tests.integration_tests.test_app import app
 from superset.charts.schemas import ChartDataQueryContextSchema
 from superset.common.query_context import QueryContext
 from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.birth_names_dashboard import (
+    load_birth_names_dashboard_with_slices,
+)
 from tests.integration_tests.fixtures.query_context import get_query_context


 class TestSchema(SupersetTestCase):
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
+    )
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_limit_and_offset(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -36,7 +44,7 @@ class TestSchema(SupersetTestCase):
         payload["queries"][0].pop("row_offset", None)
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
-        self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
+        self.assertEqual(query_object.row_limit, 5000)
         self.assertEqual(query_object.row_offset, 0)

         # Valid limit and offset
@@ -55,12 +63,14 @@ class TestSchema(SupersetTestCase):
         self.assertIn("row_limit", context.exception.messages["queries"][0])
         self.assertIn("row_offset", context.exception.messages["queries"][0])

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_timegrain(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
         payload["queries"][0]["extras"]["time_grain_sqla"] = None
         _ = ChartDataQueryContextSchema().load(payload)

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_series_limit(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -82,6 +92,7 @@ class TestSchema(SupersetTestCase):
         }
         _ = ChartDataQueryContextSchema().load(payload)

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_post_processing_op(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -16,10 +16,16 @@
 # under the License.
 from unittest.mock import patch

-from superset import is_feature_enabled
+from parameterized import parameterized
+
+from superset import get_feature_flags, is_feature_enabled
 from tests.integration_tests.base_tests import SupersetTestCase


+def dummy_is_feature_enabled(feature_flag_name: str, default: bool = True) -> bool:
+    return True if feature_flag_name.startswith("True_") else default
+
+
 class TestFeatureFlag(SupersetTestCase):
     @patch.dict(
         "superset.extensions.feature_flag_manager._feature_flags",
@@ -38,3 +44,40 @@ class TestFeatureFlag(SupersetTestCase):
     def test_feature_flags(self):
         self.assertEqual(is_feature_enabled("foo"), "bar")
         self.assertEqual(is_feature_enabled("super"), "set")
+
+
+@patch.dict(
+    "superset.extensions.feature_flag_manager._feature_flags",
+    {"True_Flag1": False, "True_Flag2": True, "Flag3": False, "Flag4": True},
+    clear=True,
+)
+class TestFeatureFlagBackend(SupersetTestCase):
+    @parameterized.expand(
+        [
+            ("True_Flag1", True),
+            ("True_Flag2", True),
+            ("Flag3", False),
+            ("Flag4", True),
+            ("True_DoesNotExist", False),
+        ]
+    )
+    @patch(
+        "superset.extensions.feature_flag_manager._is_feature_enabled_func",
+        dummy_is_feature_enabled,
+    )
+    def test_feature_flags_override(self, feature_flag_name, expected):
+        self.assertEqual(is_feature_enabled(feature_flag_name), expected)
+
+    @patch(
+        "superset.extensions.feature_flag_manager._is_feature_enabled_func",
+        dummy_is_feature_enabled,
+    )
+    @patch(
+        "superset.extensions.feature_flag_manager._get_feature_flags_func", None,
+    )
+    def test_get_feature_flags(self):
+        feature_flags = get_feature_flags()
+        self.assertEqual(
+            feature_flags,
+            {"True_Flag1": True, "True_Flag2": True, "Flag3": False, "Flag4": True},
+        )
@@ -20,6 +20,7 @@ from typing import Any
 from unittest import mock

 import pytest
+from sqlalchemy.dialects.postgresql import dialect

 import tests.integration_tests.test_app
 from superset import app
@@ -199,6 +200,36 @@ class TestJinja2Context(SupersetTestCase):
             cache = ExtraCache()
             self.assertEqual(cache.url_param("foo"), "bar")

+    def test_url_param_escaped_form_data(self) -> None:
+        with app.test_request_context(
+            query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+        ):
+            cache = ExtraCache(dialect=dialect())
+            self.assertEqual(cache.url_param("foo"), "O''Brien")
+
+    def test_url_param_escaped_default_form_data(self) -> None:
+        with app.test_request_context(
+            query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+        ):
+            cache = ExtraCache(dialect=dialect())
+            self.assertEqual(cache.url_param("bar", "O'Malley"), "O''Malley")
+
+    def test_url_param_unescaped_form_data(self) -> None:
+        with app.test_request_context(
+            query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+        ):
+            cache = ExtraCache(dialect=dialect())
+            self.assertEqual(cache.url_param("foo", escape_result=False), "O'Brien")
+
+    def test_url_param_unescaped_default_form_data(self) -> None:
+        with app.test_request_context(
+            query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})}
+        ):
+            cache = ExtraCache(dialect=dialect())
+            self.assertEqual(
+                cache.url_param("bar", "O'Malley", escape_result=False), "O'Malley"
+            )
+
     def test_safe_proxy_primitive(self) -> None:
         def func(input: Any) -> Any:
             return input
@@ -90,6 +90,7 @@ class TestQueryContext(SupersetTestCase):
         self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
         self.assertEqual(post_proc["options"], payload_post_proc["options"])

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_cache(self):
         table_name = "birth_names"
         table = self.get_table(name=table_name)