mirror of
https://github.com/apache/superset.git
synced 2026-05-18 14:25:13 +00:00
Compare commits
9 Commits
fix/mcp-ex
...
rusackas/f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a67c02b86a | ||
|
|
925e20bb04 | ||
|
|
23b8d323a3 | ||
|
|
fc73474d53 | ||
|
|
9dd93d38b6 | ||
|
|
909675c2a9 | ||
|
|
3eb1c35512 | ||
|
|
183ad77ed4 | ||
|
|
a1bf361fe2 |
@@ -17,7 +17,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
@@ -243,6 +245,54 @@ class DatabaseDAO(BaseDAO[Database]):
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_odps_partitioned_table(
|
||||
cls, database: Database, table_name: str
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
This function is used to determine and retrieve
|
||||
partition information of the ODPS table.
|
||||
The return values are whether the partition
|
||||
table is partitioned and the names of all partition fields.
|
||||
"""
|
||||
if not database:
|
||||
raise ValueError("Database not found")
|
||||
if database.backend != "odps":
|
||||
return False, []
|
||||
try:
|
||||
from odps import ODPS
|
||||
except ImportError:
|
||||
logger.warning("pyodps is not installed, cannot check ODPS partition info")
|
||||
return False, []
|
||||
uri = database.sqlalchemy_uri
|
||||
access_key = database.password
|
||||
pattern = re.compile(
|
||||
r"odps://(?P<username>[^:]+):(?P<password>[^@]+)@(?P<project>[^/]+)/(?:\?"
|
||||
r"endpoint=(?P<endpoint>[^&]+))"
|
||||
)
|
||||
if not uri or not isinstance(uri, str):
|
||||
logger.warning(
|
||||
"Invalid or missing sqlalchemy URI, please provide a correct URI"
|
||||
)
|
||||
return False, []
|
||||
if match := pattern.match(unquote(uri)):
|
||||
access_id = match.group("username")
|
||||
project = match.group("project")
|
||||
endpoint = match.group("endpoint")
|
||||
odps_client = ODPS(access_id, access_key, project, endpoint=endpoint)
|
||||
table = odps_client.get_table(table_name)
|
||||
if table.exist_partition:
|
||||
partition_spec = table.table_schema.partitions
|
||||
partition_fields = [partition.name for partition in partition_spec]
|
||||
return True, partition_fields
|
||||
return False, []
|
||||
logger.warning(
|
||||
"ODPS sqlalchemy_uri did not match the expected pattern; "
|
||||
"unable to determine partition info for table %r",
|
||||
table_name,
|
||||
)
|
||||
return False, []
|
||||
|
||||
|
||||
class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]):
|
||||
"""
|
||||
|
||||
@@ -123,7 +123,7 @@ from superset.exceptions import (
|
||||
)
|
||||
from superset.extensions import security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.sql.parse import Table
|
||||
from superset.sql.parse import Partition, Table
|
||||
from superset.superset_typing import FlaskResponse
|
||||
from superset.utils import json
|
||||
from superset.utils.core import (
|
||||
@@ -1079,15 +1079,32 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
parameters = QualifiedTableSchema().load(request.args)
|
||||
except ValidationError as ex:
|
||||
raise InvalidPayloadSchemaError(ex) from ex
|
||||
|
||||
table = Table(parameters["name"], parameters["schema"], parameters["catalog"])
|
||||
table_name = str(parameters["name"])
|
||||
table = Table(table_name, parameters["schema"], parameters["catalog"])
|
||||
try:
|
||||
security_manager.raise_for_access(database=database, table=table)
|
||||
except SupersetSecurityException as ex:
|
||||
# instead of raising 403, raise 404 to hide table existence
|
||||
raise TableNotFoundException("No such table") from ex
|
||||
try:
|
||||
is_partitioned_table, partition_fields = (
|
||||
DatabaseDAO.is_odps_partitioned_table(database, table_name)
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
logger.warning(
|
||||
"Error determining ODPS partition info for table %s: %s; "
|
||||
"falling back to non-partitioned path",
|
||||
table_name,
|
||||
error_msg_from_exception(ex),
|
||||
)
|
||||
is_partitioned_table, partition_fields = False, []
|
||||
partition = Partition(is_partitioned_table, partition_fields)
|
||||
if is_partitioned_table:
|
||||
from superset.db_engine_specs.odps import OdpsEngineSpec
|
||||
|
||||
payload = database.db_engine_spec.get_table_metadata(database, table)
|
||||
payload = OdpsEngineSpec.get_table_metadata(database, table, partition)
|
||||
else:
|
||||
payload = database.db_engine_spec.get_table_metadata(database, table)
|
||||
|
||||
return self.response(200, **payload)
|
||||
|
||||
|
||||
@@ -81,6 +81,14 @@ def load_engine_specs() -> list[type[BaseEngineSpec]]:
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.warning("Unable to load Superset DB engine spec: %s", ep.name)
|
||||
continue
|
||||
# Validate that the engine spec is a proper subclass of BaseEngineSpec
|
||||
if not is_engine_spec(engine_spec):
|
||||
logger.warning(
|
||||
"Skipping invalid DB engine spec %s: "
|
||||
"not a valid BaseEngineSpec subclass",
|
||||
ep.name,
|
||||
)
|
||||
continue
|
||||
engine_specs.append(engine_spec)
|
||||
|
||||
return engine_specs
|
||||
|
||||
192
superset/db_engine_specs/odps.py
Normal file
192
superset/db_engine_specs/odps.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.engine import Dialect
|
||||
|
||||
from superset.databases.schemas import (
|
||||
TableMetadataColumnsResponse,
|
||||
TableMetadataResponse,
|
||||
)
|
||||
from superset.databases.utils import (
|
||||
get_col_type,
|
||||
get_foreign_keys_metadata,
|
||||
get_indexes_metadata,
|
||||
)
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
|
||||
from superset.sql.parse import Partition, SQLScript, Table
|
||||
from superset.superset_typing import ResultSetColumnType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OdpsBaseEngineSpec(BaseEngineSpec):
|
||||
@classmethod
|
||||
def get_table_metadata(
|
||||
cls,
|
||||
database: Database,
|
||||
table: Table,
|
||||
partition: Partition | None = None,
|
||||
) -> TableMetadataResponse:
|
||||
"""
|
||||
Returns basic table metadata
|
||||
:param database: Database instance
|
||||
:param table: A Table instance
|
||||
:param partition: A Table partition info
|
||||
:return: Basic table metadata
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class OdpsEngineSpec(BasicParametersMixin, OdpsBaseEngineSpec):
|
||||
engine = "odps"
|
||||
engine_name = "ODPS (MaxCompute)"
|
||||
default_driver = "odps"
|
||||
|
||||
@classmethod
|
||||
def get_table_metadata(
|
||||
cls, database: Any, table: Table, partition: Partition | None = None
|
||||
) -> TableMetadataResponse:
|
||||
"""
|
||||
Get table metadata information, including type, pk, fks.
|
||||
This function raises SQLAlchemyError when a schema is not found.
|
||||
|
||||
:param partition: The table's partition info
|
||||
:param database: The database model
|
||||
:param table: Table instance
|
||||
:return: Dict table metadata ready for API response
|
||||
"""
|
||||
keys = []
|
||||
columns = database.get_columns(table)
|
||||
primary_key = database.get_pk_constraint(table)
|
||||
if primary_key and primary_key.get("constrained_columns"):
|
||||
primary_key["column_names"] = primary_key.pop("constrained_columns")
|
||||
primary_key["type"] = "pk"
|
||||
keys += [primary_key]
|
||||
foreign_keys = get_foreign_keys_metadata(database, table)
|
||||
indexes = get_indexes_metadata(database, table)
|
||||
keys += foreign_keys + indexes
|
||||
payload_columns: list[TableMetadataColumnsResponse] = []
|
||||
table_comment = database.get_table_comment(table)
|
||||
for col in columns:
|
||||
dtype = get_col_type(col)
|
||||
payload_columns.append(
|
||||
{
|
||||
"name": col["column_name"],
|
||||
"type": dtype.split("(")[0] if "(" in dtype else dtype,
|
||||
"longType": dtype,
|
||||
"keys": [
|
||||
k for k in keys if col["column_name"] in k["column_names"]
|
||||
],
|
||||
"comment": col.get("comment"),
|
||||
}
|
||||
)
|
||||
|
||||
with database.get_sqla_engine(
|
||||
catalog=table.catalog, schema=table.schema
|
||||
) as engine:
|
||||
return {
|
||||
"name": table.table,
|
||||
"columns": payload_columns,
|
||||
"selectStar": cls.select_star(
|
||||
database=database,
|
||||
table=table,
|
||||
dialect=engine.dialect,
|
||||
limit=100,
|
||||
show_cols=False,
|
||||
indent=True,
|
||||
latest_partition=True,
|
||||
cols=columns,
|
||||
partition=partition,
|
||||
),
|
||||
"primaryKey": primary_key,
|
||||
"foreignKeys": foreign_keys,
|
||||
"indexes": keys,
|
||||
"comment": table_comment,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def select_star( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database: Database,
|
||||
table: Table,
|
||||
dialect: Dialect,
|
||||
limit: int = 100,
|
||||
show_cols: bool = False,
|
||||
indent: bool = True,
|
||||
latest_partition: bool = True,
|
||||
cols: list[ResultSetColumnType] | None = None,
|
||||
partition: Partition | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.
|
||||
|
||||
WARNING: expects only unquoted table and schema names.
|
||||
|
||||
:param partition: The table's partition info
|
||||
:param database: Database instance
|
||||
:param table: Table instance
|
||||
:param dialect: SqlAlchemy Dialect instance
|
||||
:param limit: limit to impose on query
|
||||
:param show_cols: Show columns in query; otherwise use "*"
|
||||
:param indent: Add indentation to query
|
||||
:param latest_partition: Only query the latest partition
|
||||
:param cols: Columns to include in query
|
||||
:return: SQL query
|
||||
"""
|
||||
# pylint: disable=redefined-outer-name
|
||||
fields: str | list[Any] = "*"
|
||||
cols = cols or []
|
||||
if (show_cols or latest_partition) and not cols:
|
||||
cols = database.get_columns(table)
|
||||
|
||||
if show_cols:
|
||||
fields = cls._get_fields(cols)
|
||||
full_table_name = cls.quote_table(table, dialect)
|
||||
qry = select(fields).select_from(text(full_table_name))
|
||||
if database.backend == "odps":
|
||||
if (
|
||||
partition is not None
|
||||
and partition.is_partitioned_table
|
||||
and partition.partition_column is not None
|
||||
and len(partition.partition_column) > 0
|
||||
):
|
||||
partition_str = partition.partition_column[0]
|
||||
partition_str_where = f"CAST({partition_str} AS STRING) LIKE '%'"
|
||||
qry = qry.where(text(partition_str_where))
|
||||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
if latest_partition:
|
||||
partition_query = cls.where_latest_partition(
|
||||
database,
|
||||
table,
|
||||
qry,
|
||||
columns=cols,
|
||||
)
|
||||
if partition_query is not None:
|
||||
qry = partition_query
|
||||
sql = database.compile_sqla_query(qry, table.catalog, table.schema)
|
||||
if indent:
|
||||
sql = SQLScript(sql, engine=cls.engine).format()
|
||||
return sql
|
||||
@@ -322,6 +322,34 @@ class Table:
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Partition:
|
||||
"""
|
||||
Partition object, with two attribute keys:
|
||||
is_partitioned_table and partition_column,
|
||||
used to provide partition information
|
||||
Here is an example of an object:
|
||||
{"is_partitioned_table": true, "partition_column": ["month", "day"]}
|
||||
"""
|
||||
|
||||
is_partitioned_table: bool
|
||||
partition_column: list[str] | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return a string representation of the Partition object.
|
||||
"""
|
||||
partition_column_str = (
|
||||
", ".join(map(str, self.partition_column))
|
||||
if self.partition_column
|
||||
else "None"
|
||||
)
|
||||
return (
|
||||
f"Partition(is_partitioned_table={self.is_partitioned_table}, "
|
||||
f"partition_column=[{partition_column_str}])"
|
||||
)
|
||||
|
||||
|
||||
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
|
||||
# an "internal representation", which is the AST of the SQL statement. For most of the
|
||||
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
|
||||
|
||||
@@ -27,6 +27,7 @@ from superset.db_engine_specs.base import (
|
||||
builtin_time_grains,
|
||||
)
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.db_engine_specs.odps import OdpsBaseEngineSpec, OdpsEngineSpec
|
||||
from superset.db_engine_specs.sqlite import SqliteEngineSpec
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql.parse import Table
|
||||
@@ -80,7 +81,11 @@ class SupersetTestCases(SupersetTestCase):
|
||||
time_grains = set(builtin_time_grains.keys())
|
||||
# loop over all subclasses of BaseEngineSpec
|
||||
for engine in load_engine_specs():
|
||||
if engine is not BaseEngineSpec:
|
||||
if (
|
||||
engine is not BaseEngineSpec
|
||||
and engine is not OdpsBaseEngineSpec
|
||||
and engine is not OdpsEngineSpec
|
||||
):
|
||||
# make sure time grain functions have been defined
|
||||
assert len(engine.get_time_grain_expressions()) > 0
|
||||
# make sure all defined time grains are supported
|
||||
|
||||
185
tests/unit_tests/db_engine_specs/test_odps.py
Normal file
185
tests/unit_tests/db_engine_specs/test_odps.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from superset.db_engine_specs.odps import OdpsBaseEngineSpec, OdpsEngineSpec
|
||||
from superset.sql.parse import Partition, Table
|
||||
|
||||
|
||||
def test_odps_base_engine_spec_get_table_metadata_raises() -> None:
|
||||
"""OdpsBaseEngineSpec.get_table_metadata must not be called directly."""
|
||||
with pytest.raises(NotImplementedError):
|
||||
OdpsBaseEngineSpec.get_table_metadata(
|
||||
database=MagicMock(),
|
||||
table=Table("my_table", None, None),
|
||||
)
|
||||
|
||||
|
||||
def test_odps_engine_spec_select_star_no_partition() -> None:
|
||||
"""select_star for a non-partitioned ODPS table produces a plain SELECT *."""
|
||||
database = MagicMock()
|
||||
database.backend = "odps"
|
||||
database.get_columns.return_value = []
|
||||
database.compile_sqla_query = lambda query, catalog, schema: str(
|
||||
query.compile(dialect=sqlite.dialect())
|
||||
)
|
||||
dialect = sqlite.dialect()
|
||||
|
||||
sql = OdpsEngineSpec.select_star(
|
||||
database=database,
|
||||
table=Table("my_table", None, None),
|
||||
dialect=dialect,
|
||||
limit=100,
|
||||
show_cols=False,
|
||||
indent=False,
|
||||
latest_partition=False,
|
||||
partition=None,
|
||||
)
|
||||
|
||||
assert "SELECT" in sql
|
||||
assert "my_table" in sql
|
||||
|
||||
|
||||
def test_odps_engine_spec_select_star_with_partition() -> None:
|
||||
"""select_star for a partitioned ODPS table adds a WHERE clause."""
|
||||
database = MagicMock()
|
||||
database.backend = "odps"
|
||||
database.get_columns.return_value = []
|
||||
database.compile_sqla_query = lambda query, catalog, schema: str(
|
||||
query.compile(dialect=sqlite.dialect())
|
||||
)
|
||||
dialect = sqlite.dialect()
|
||||
partition = Partition(is_partitioned_table=True, partition_column=["month"])
|
||||
|
||||
sql = OdpsEngineSpec.select_star(
|
||||
database=database,
|
||||
table=Table("my_table", None, None),
|
||||
dialect=dialect,
|
||||
limit=100,
|
||||
show_cols=False,
|
||||
indent=False,
|
||||
latest_partition=False,
|
||||
partition=partition,
|
||||
)
|
||||
|
||||
assert "WHERE" in sql
|
||||
|
||||
|
||||
def test_is_odps_partitioned_table_non_odps_backend() -> None:
|
||||
"""Returns (False, []) immediately for non-ODPS databases; no network call made."""
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
||||
database = MagicMock()
|
||||
database.backend = "postgresql"
|
||||
|
||||
result = DatabaseDAO.is_odps_partitioned_table(database, "some_table")
|
||||
|
||||
assert result == (False, [])
|
||||
|
||||
|
||||
def test_is_odps_partitioned_table_missing_pyodps() -> None:
|
||||
"""Returns (False, []) with a warning when pyodps is not installed."""
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
||||
database = MagicMock()
|
||||
database.backend = "odps"
|
||||
database.sqlalchemy_uri = (
|
||||
"odps://mykey:mysecret@myproject/?endpoint=http://service.odps.test"
|
||||
)
|
||||
database.password = "mysecret" # noqa: S105
|
||||
|
||||
with patch.dict("sys.modules", {"odps": None}):
|
||||
result = DatabaseDAO.is_odps_partitioned_table(database, "some_table")
|
||||
|
||||
assert result == (False, [])
|
||||
|
||||
|
||||
def test_is_odps_partitioned_table_uri_no_match(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Logs a warning and returns (False, []) when the URI doesn't match the pattern."""
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
||||
database = MagicMock()
|
||||
database.backend = "odps"
|
||||
database.sqlalchemy_uri = "odps://invalid-uri-format"
|
||||
database.password = "secret" # noqa: S105
|
||||
|
||||
mock_odps_module = MagicMock()
|
||||
with patch.dict("sys.modules", {"odps": mock_odps_module}):
|
||||
with caplog.at_level(logging.WARNING, logger="superset.daos.database"):
|
||||
result = DatabaseDAO.is_odps_partitioned_table(database, "some_table")
|
||||
|
||||
assert result == (False, [])
|
||||
assert "did not match" in caplog.text
|
||||
|
||||
|
||||
def test_is_odps_partitioned_table_partitioned(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Returns (True, [field_names]) for a partitioned ODPS table."""
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
||||
database = MagicMock()
|
||||
database.backend = "odps"
|
||||
database.sqlalchemy_uri = (
|
||||
"odps://mykey:mysecret@myproject/?endpoint=http://service.odps.test"
|
||||
)
|
||||
database.password = "mysecret" # noqa: S105
|
||||
|
||||
mock_partition = MagicMock()
|
||||
mock_partition.name = "month"
|
||||
mock_table = MagicMock()
|
||||
mock_table.exist_partition = True
|
||||
mock_table.table_schema.partitions = [mock_partition]
|
||||
|
||||
mock_odps_client = MagicMock()
|
||||
mock_odps_client.get_table.return_value = mock_table
|
||||
mock_odps_class = MagicMock(return_value=mock_odps_client)
|
||||
|
||||
with patch.dict("sys.modules", {"odps": MagicMock(ODPS=mock_odps_class)}):
|
||||
with patch("superset.daos.database.ODPS", mock_odps_class, create=True):
|
||||
result = DatabaseDAO.is_odps_partitioned_table(database, "my_table")
|
||||
|
||||
assert result == (True, ["month"])
|
||||
|
||||
|
||||
def test_is_odps_partitioned_table_not_partitioned(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Returns (False, []) for a non-partitioned ODPS table."""
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
||||
database = MagicMock()
|
||||
database.backend = "odps"
|
||||
database.sqlalchemy_uri = (
|
||||
"odps://mykey:mysecret@myproject/?endpoint=http://service.odps.test"
|
||||
)
|
||||
database.password = "mysecret" # noqa: S105
|
||||
|
||||
mock_table = MagicMock()
|
||||
mock_table.exist_partition = False
|
||||
mock_odps_client = MagicMock()
|
||||
mock_odps_client.get_table.return_value = mock_table
|
||||
mock_odps_class = MagicMock(return_value=mock_odps_client)
|
||||
|
||||
with patch.dict("sys.modules", {"odps": MagicMock(ODPS=mock_odps_class)}):
|
||||
result = DatabaseDAO.is_odps_partitioned_table(database, "my_table")
|
||||
|
||||
assert result == (False, [])
|
||||
@@ -30,6 +30,7 @@ from superset.sql.parse import (
|
||||
KQLTokenType,
|
||||
KustoKQLStatement,
|
||||
LimitMethod,
|
||||
Partition,
|
||||
process_jinja_sql,
|
||||
remove_quotes,
|
||||
RLSMethod,
|
||||
@@ -137,6 +138,36 @@ def test_table_qualify() -> None:
|
||||
assert qualified.catalog == table.catalog
|
||||
|
||||
|
||||
def test_partition() -> None:
|
||||
"""
|
||||
Test the `Partition` class and its string conversion.
|
||||
"""
|
||||
# Test partitioned table with partition columns
|
||||
partition = Partition(is_partitioned_table=True, partition_column=["col1", "col2"])
|
||||
assert partition.is_partitioned_table is True
|
||||
assert partition.partition_column == ["col1", "col2"]
|
||||
assert (
|
||||
str(partition)
|
||||
== "Partition(is_partitioned_table=True, partition_column=[col1, col2])"
|
||||
)
|
||||
|
||||
# Test non-partitioned table
|
||||
partition_none = Partition(is_partitioned_table=False, partition_column=None)
|
||||
assert partition_none.is_partitioned_table is False
|
||||
assert partition_none.partition_column is None
|
||||
assert (
|
||||
str(partition_none)
|
||||
== "Partition(is_partitioned_table=False, partition_column=[None])"
|
||||
)
|
||||
|
||||
# Test equality
|
||||
partition1 = Partition(is_partitioned_table=True, partition_column=["col1"])
|
||||
partition2 = Partition(is_partitioned_table=True, partition_column=["col1"])
|
||||
partition3 = Partition(is_partitioned_table=True, partition_column=["col2"])
|
||||
assert partition1 == partition2
|
||||
assert partition1 != partition3
|
||||
|
||||
|
||||
def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]:
|
||||
"""
|
||||
Helper function to extract tables from SQL.
|
||||
|
||||
Reference in New Issue
Block a user