mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
feat: add latest partition support for BigQuery (#30760)
This commit is contained in:
committed by
GitHub
parent
a36e636a58
commit
c83eda9551
@@ -1630,6 +1630,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
:return: SqlAlchemy query with additional where clause referencing the latest
|
||||
partition
|
||||
"""
|
||||
# TODO: Fix circular import caused by importing Database, TableColumn
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -27,15 +27,15 @@ from typing import Any, TYPE_CHECKING, TypedDict
|
||||
import pandas as pd
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from deprecation import deprecated
|
||||
from flask_babel import gettext as __
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from sqlalchemy import column, types
|
||||
from sqlalchemy import column, func, types
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.sql import sqltypes
|
||||
from sqlalchemy.sql import column as sql_column, select, sqltypes
|
||||
from sqlalchemy.sql.expression import table as sql_table
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.databases.schemas import encrypted_field_properties, EncryptedString
|
||||
@@ -50,6 +50,11 @@ from superset.superset_typing import ResultSetColumnType
|
||||
from superset.utils import core as utils, json
|
||||
from superset.utils.hashing import md5_sha_from_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.sql.expression import Select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import google.auth
|
||||
from google.cloud import bigquery
|
||||
@@ -289,42 +294,80 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
||||
return "_" + md5_sha_from_str(label)
|
||||
|
||||
@classmethod
|
||||
@deprecated(deprecated_in="3.0")
|
||||
def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Normalizes indexes for more consistency across db engines
|
||||
|
||||
:param indexes: Raw indexes as returned by SQLAlchemy
|
||||
:return: cleaner, more aligned index definition
|
||||
"""
|
||||
normalized_idxs = []
|
||||
# Fixing a bug/behavior observed in pybigquery==0.4.15 where
|
||||
# the index's `column_names` == [None]
|
||||
# Here we're returning only non-None indexes
|
||||
for ix in indexes:
|
||||
column_names = ix.get("column_names") or []
|
||||
ix["column_names"] = [col for col in column_names if col is not None]
|
||||
if ix["column_names"]:
|
||||
normalized_idxs.append(ix)
|
||||
return normalized_idxs
|
||||
|
||||
@classmethod
|
||||
def get_indexes(
|
||||
def where_latest_partition(
|
||||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
table: Table,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get the indexes associated with the specified schema/table.
|
||||
query: Select,
|
||||
columns: list[ResultSetColumnType] | None = None,
|
||||
) -> Select | None:
|
||||
if partition_column := cls.get_time_partition_column(database, table):
|
||||
max_partition_id = cls.get_max_partition_id(database, table)
|
||||
query = query.where(
|
||||
column(partition_column) == func.PARSE_DATE("%Y%m%d", max_partition_id)
|
||||
)
|
||||
|
||||
:param database: The database to inspect
|
||||
:param inspector: The SQLAlchemy inspector
|
||||
:param table: The table instance to inspect
|
||||
:returns: The indexes
|
||||
"""
|
||||
return query
|
||||
|
||||
return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema))
|
||||
@classmethod
|
||||
def get_max_partition_id(
|
||||
cls,
|
||||
database: Database,
|
||||
table: Table,
|
||||
) -> Select | None:
|
||||
# Compose schema from catalog and schema
|
||||
schema_parts = []
|
||||
if table.catalog:
|
||||
schema_parts.append(table.catalog)
|
||||
if table.schema:
|
||||
schema_parts.append(table.schema)
|
||||
schema_parts.append("INFORMATION_SCHEMA")
|
||||
schema = ".".join(schema_parts)
|
||||
# Define a virtual table reference to INFORMATION_SCHEMA.PARTITIONS
|
||||
partitions_table = sql_table(
|
||||
"PARTITIONS",
|
||||
sql_column("partition_id"),
|
||||
sql_column("table_name"),
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
# Build the query
|
||||
query = select(
|
||||
func.max(partitions_table.c.partition_id).label("max_partition_id")
|
||||
).where(partitions_table.c.table_name == table.table)
|
||||
|
||||
# Compile to BigQuery SQL
|
||||
compiled_query = query.compile(
|
||||
dialect=database.get_dialect(),
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
|
||||
# Run the query and handle result
|
||||
with database.get_raw_connection(
|
||||
catalog=table.catalog,
|
||||
schema=table.schema,
|
||||
) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(str(compiled_query))
|
||||
if row := cursor.fetchone():
|
||||
return row[0]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_time_partition_column(
|
||||
cls,
|
||||
database: Database,
|
||||
table: Table,
|
||||
) -> str | None:
|
||||
with cls.get_engine(
|
||||
database, catalog=table.catalog, schema=table.schema
|
||||
) as engine:
|
||||
client = cls._get_client(engine, database)
|
||||
bq_table = client.get_table(f"{table.schema}.{table.table}")
|
||||
|
||||
if bq_table.time_partitioning:
|
||||
return bq_table.time_partitioning.field
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_extra_table_metadata(
|
||||
@@ -332,23 +375,38 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
||||
database: Database,
|
||||
table: Table,
|
||||
) -> dict[str, Any]:
|
||||
indexes = database.get_indexes(table)
|
||||
if not indexes:
|
||||
return {}
|
||||
partitions_columns = [
|
||||
index.get("column_names", [])
|
||||
for index in indexes
|
||||
if index.get("name") == "partition"
|
||||
]
|
||||
cluster_columns = [
|
||||
index.get("column_names", [])
|
||||
for index in indexes
|
||||
if index.get("name") == "clustering"
|
||||
]
|
||||
return {
|
||||
"partitions": {"cols": partitions_columns},
|
||||
"clustering": {"cols": cluster_columns},
|
||||
}
|
||||
payload = {}
|
||||
partition_column = cls.get_time_partition_column(database, table)
|
||||
with cls.get_engine(
|
||||
database, catalog=table.catalog, schema=table.schema
|
||||
) as engine:
|
||||
if partition_column:
|
||||
max_partition_id = cls.get_max_partition_id(database, table)
|
||||
sql = cls.select_star(
|
||||
database,
|
||||
table,
|
||||
engine,
|
||||
indent=False,
|
||||
show_cols=False,
|
||||
latest_partition=True,
|
||||
)
|
||||
payload.update(
|
||||
{
|
||||
"partitions": {
|
||||
"cols": [partition_column],
|
||||
"latest": {partition_column: max_partition_id},
|
||||
"partitionQuery": sql,
|
||||
},
|
||||
"indexes": [
|
||||
{
|
||||
"name": "partitioned",
|
||||
"cols": [partition_column],
|
||||
"type": "partitioned",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import unittest.mock as mock
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
from pandas import DataFrame
|
||||
@@ -32,6 +33,15 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mock_engine_with_credentials(*args, **kwargs):
|
||||
engine_mock = mock.Mock()
|
||||
engine_mock.dialect.credentials_info = {
|
||||
"key": "value"
|
||||
} # Add the credentials_info attribute
|
||||
yield engine_mock
|
||||
|
||||
|
||||
class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
||||
def test_bigquery_sqla_column_label(self):
|
||||
"""
|
||||
@@ -111,108 +121,45 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
|
||||
result = BigQueryEngineSpec.fetch_data(None, 0)
|
||||
assert result == [1, 2]
|
||||
|
||||
def test_get_extra_table_metadata(self):
|
||||
@mock.patch.object(
|
||||
BigQueryEngineSpec, "get_engine", side_effect=mock_engine_with_credentials
|
||||
)
|
||||
@mock.patch.object(BigQueryEngineSpec, "get_time_partition_column")
|
||||
@mock.patch.object(BigQueryEngineSpec, "get_max_partition_id")
|
||||
@mock.patch.object(BigQueryEngineSpec, "quote_table", return_value="`table_name`")
|
||||
def test_get_extra_table_metadata(
|
||||
self,
|
||||
mock_quote_table,
|
||||
mock_get_max_partition_id,
|
||||
mock_get_time_partition_column,
|
||||
mock_get_engine,
|
||||
):
|
||||
"""
|
||||
DB Eng Specs (bigquery): Test extra table metadata
|
||||
"""
|
||||
database = mock.Mock()
|
||||
sql = "SELECT * FROM `table_name`"
|
||||
database.compile_sqla_query.return_value = sql
|
||||
tbl = Table("some_table", "some_schema")
|
||||
|
||||
# Test no indexes
|
||||
database.get_indexes = mock.MagicMock(return_value=None)
|
||||
result = BigQueryEngineSpec.get_extra_table_metadata(
|
||||
database,
|
||||
Table("some_table", "some_schema"),
|
||||
)
|
||||
mock_get_time_partition_column.return_value = None
|
||||
mock_get_max_partition_id.return_value = None
|
||||
result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
|
||||
assert result == {}
|
||||
|
||||
index_metadata = [
|
||||
{
|
||||
"name": "clustering",
|
||||
"column_names": ["c_col1", "c_col2", "c_col3"],
|
||||
mock_get_time_partition_column.return_value = "ds"
|
||||
mock_get_max_partition_id.return_value = "19690101"
|
||||
result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
|
||||
print(result)
|
||||
assert result == {
|
||||
"indexes": [{"cols": ["ds"], "name": "partitioned", "type": "partitioned"}],
|
||||
"partitions": {
|
||||
"cols": ["ds"],
|
||||
"latest": {"ds": "19690101"},
|
||||
"partitionQuery": sql,
|
||||
},
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["p_col1", "p_col2", "p_col3"],
|
||||
},
|
||||
]
|
||||
expected_result = {
|
||||
"partitions": {"cols": [["p_col1", "p_col2", "p_col3"]]},
|
||||
"clustering": {"cols": [["c_col1", "c_col2", "c_col3"]]},
|
||||
}
|
||||
database.get_indexes = mock.MagicMock(return_value=index_metadata)
|
||||
result = BigQueryEngineSpec.get_extra_table_metadata(
|
||||
database,
|
||||
Table("some_table", "some_schema"),
|
||||
)
|
||||
assert result == expected_result
|
||||
|
||||
def test_get_indexes(self):
|
||||
database = mock.Mock()
|
||||
inspector = mock.Mock()
|
||||
schema = "foo"
|
||||
table_name = "bar"
|
||||
|
||||
inspector.get_indexes = mock.Mock(
|
||||
return_value=[
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": [None],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert (
|
||||
BigQueryEngineSpec.get_indexes(
|
||||
database,
|
||||
inspector,
|
||||
Table(table_name, schema),
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
inspector.get_indexes = mock.Mock(
|
||||
return_value=[
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm"],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert BigQueryEngineSpec.get_indexes(
|
||||
database,
|
||||
inspector,
|
||||
Table(table_name, schema),
|
||||
) == [
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm"],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
|
||||
inspector.get_indexes = mock.Mock(
|
||||
return_value=[
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm", None],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert BigQueryEngineSpec.get_indexes(
|
||||
database,
|
||||
inspector,
|
||||
Table(table_name, schema),
|
||||
) == [
|
||||
{
|
||||
"name": "partition",
|
||||
"column_names": ["dttm"],
|
||||
"unique": False,
|
||||
}
|
||||
]
|
||||
|
||||
@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
|
||||
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")
|
||||
|
||||
Reference in New Issue
Block a user