feat: add latest partition support for BigQuery (#30760)

This commit is contained in:
Maxime Beauchemin
2025-04-01 17:13:09 -07:00
committed by GitHub
parent a36e636a58
commit c83eda9551
3 changed files with 150 additions and 144 deletions

View File

@@ -1630,6 +1630,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return: SqlAlchemy query with additional where clause referencing the latest
partition
"""
# TODO: Fix circular import caused by importing Database, TableColumn
return None
@classmethod

View File

@@ -27,15 +27,15 @@ from typing import Any, TYPE_CHECKING, TypedDict
import pandas as pd
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
from flask_babel import gettext as __
from marshmallow import fields, Schema
from marshmallow.exceptions import ValidationError
from sqlalchemy import column, types
from sqlalchemy import column, func, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql import column as sql_column, select, sqltypes
from sqlalchemy.sql.expression import table as sql_table
from superset.constants import TimeGrain
from superset.databases.schemas import encrypted_field_properties, EncryptedString
@@ -50,6 +50,11 @@ from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils, json
from superset.utils.hashing import md5_sha_from_str
if TYPE_CHECKING:
from sqlalchemy.sql.expression import Select
logger = logging.getLogger(__name__)
try:
import google.auth
from google.cloud import bigquery
@@ -289,42 +294,80 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return "_" + md5_sha_from_str(label)
@classmethod
@deprecated(deprecated_in="3.0")
def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Normalize raw SQLAlchemy index metadata for cross-engine consistency.

    pybigquery==0.4.15 has been observed returning indexes whose
    ``column_names`` is ``[None]``; strip the ``None`` entries (mutating
    each index dict in place) and drop any index left with no columns.

    :param indexes: Raw indexes as returned by SQLAlchemy
    :return: cleaner, more aligned index definition
    """
    cleaned: list[dict[str, Any]] = []
    for index in indexes:
        kept = [
            name
            for name in (index.get("column_names") or [])
            if name is not None
        ]
        index["column_names"] = kept
        if kept:
            cleaned.append(index)
    return cleaned
@classmethod
def get_indexes(
def where_latest_partition(
cls,
database: Database,
inspector: Inspector,
table: Table,
) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
query: Select,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
if partition_column := cls.get_time_partition_column(database, table):
max_partition_id = cls.get_max_partition_id(database, table)
query = query.where(
column(partition_column) == func.PARSE_DATE("%Y%m%d", max_partition_id)
)
:param database: The database to inspect
:param inspector: The SQLAlchemy inspector
:param table: The table instance to inspect
:returns: The indexes
"""
return query
return cls.normalize_indexes(inspector.get_indexes(table.table, table.schema))
@classmethod
def get_max_partition_id(
    cls,
    database: Database,
    table: Table,
) -> str | None:
    """
    Fetch the highest ``partition_id`` for a table from BigQuery's
    ``INFORMATION_SCHEMA.PARTITIONS`` view.

    Fixed: the return annotation previously claimed ``Select | None``,
    but the method returns the scalar ``partition_id`` value fetched
    from the cursor (or ``None`` when no row comes back).

    :param database: the database the table belongs to
    :param table: the (catalog-/schema-qualified) table to inspect
    :return: the max ``partition_id`` value, or ``None`` if none found
    """
    # INFORMATION_SCHEMA is addressed as "<catalog>.<schema>.INFORMATION_SCHEMA";
    # catalog/schema are optional, so drop any empty parts.
    schema = ".".join(
        part
        for part in (table.catalog, table.schema, "INFORMATION_SCHEMA")
        if part
    )
    # Virtual reference to the PARTITIONS view, declaring only the
    # columns this query touches.
    partitions_table = sql_table(
        "PARTITIONS",
        sql_column("partition_id"),
        sql_column("table_name"),
        schema=schema,
    )
    query = select(
        func.max(partitions_table.c.partition_id).label("max_partition_id")
    ).where(partitions_table.c.table_name == table.table)
    # Render literal BigQuery SQL so it can run through a raw cursor.
    compiled_query = query.compile(
        dialect=database.get_dialect(),
        compile_kwargs={"literal_binds": True},
    )
    with database.get_raw_connection(
        catalog=table.catalog,
        schema=table.schema,
    ) as conn:
        cursor = conn.cursor()
        cursor.execute(str(compiled_query))
        if row := cursor.fetchone():
            return row[0]
    return None
@classmethod
def get_time_partition_column(
    cls,
    database: Database,
    table: Table,
) -> str | None:
    """
    Look up the time-partitioning column of a BigQuery table, if any.

    :param database: the database the table belongs to
    :param table: the table to inspect
    :return: name of the partitioning field, or ``None`` when the
        table is not time-partitioned
    """
    with cls.get_engine(
        database, catalog=table.catalog, schema=table.schema
    ) as engine:
        client = cls._get_client(engine, database)
        bq_table = client.get_table(f"{table.schema}.{table.table}")
        partitioning = bq_table.time_partitioning
        return partitioning.field if partitioning else None
@classmethod
def get_extra_table_metadata(
@@ -332,23 +375,38 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
database: Database,
table: Table,
) -> dict[str, Any]:
indexes = database.get_indexes(table)
if not indexes:
return {}
partitions_columns = [
index.get("column_names", [])
for index in indexes
if index.get("name") == "partition"
]
cluster_columns = [
index.get("column_names", [])
for index in indexes
if index.get("name") == "clustering"
]
return {
"partitions": {"cols": partitions_columns},
"clustering": {"cols": cluster_columns},
}
payload = {}
partition_column = cls.get_time_partition_column(database, table)
with cls.get_engine(
database, catalog=table.catalog, schema=table.schema
) as engine:
if partition_column:
max_partition_id = cls.get_max_partition_id(database, table)
sql = cls.select_star(
database,
table,
engine,
indent=False,
show_cols=False,
latest_partition=True,
)
payload.update(
{
"partitions": {
"cols": [partition_column],
"latest": {partition_column: max_partition_id},
"partitionQuery": sql,
},
"indexes": [
{
"name": "partitioned",
"cols": [partition_column],
"type": "partitioned",
}
],
}
)
return payload
@classmethod
def epoch_to_dttm(cls) -> str:

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from contextlib import contextmanager
import pytest
from pandas import DataFrame
@@ -32,6 +33,15 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
)
@contextmanager
def mock_engine_with_credentials(*args, **kwargs):
    """Yield a mocked SQLAlchemy engine whose dialect carries fake
    BigQuery ``credentials_info``, standing in for ``get_engine``."""
    fake_engine = mock.Mock()
    # Attach the credentials the BigQuery spec reads off the dialect.
    fake_engine.dialect.credentials_info = {"key": "value"}
    yield fake_engine
class TestBigQueryDbEngineSpec(TestDbEngineSpec):
def test_bigquery_sqla_column_label(self):
"""
@@ -111,108 +121,45 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
result = BigQueryEngineSpec.fetch_data(None, 0)
assert result == [1, 2]
def test_get_extra_table_metadata(self):
@mock.patch.object(
BigQueryEngineSpec, "get_engine", side_effect=mock_engine_with_credentials
)
@mock.patch.object(BigQueryEngineSpec, "get_time_partition_column")
@mock.patch.object(BigQueryEngineSpec, "get_max_partition_id")
@mock.patch.object(BigQueryEngineSpec, "quote_table", return_value="`table_name`")
def test_get_extra_table_metadata(
self,
mock_quote_table,
mock_get_max_partition_id,
mock_get_time_partition_column,
mock_get_engine,
):
"""
DB Eng Specs (bigquery): Test extra table metadata
"""
database = mock.Mock()
sql = "SELECT * FROM `table_name`"
database.compile_sqla_query.return_value = sql
tbl = Table("some_table", "some_schema")
# Test no indexes
database.get_indexes = mock.MagicMock(return_value=None)
result = BigQueryEngineSpec.get_extra_table_metadata(
database,
Table("some_table", "some_schema"),
)
mock_get_time_partition_column.return_value = None
mock_get_max_partition_id.return_value = None
result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
assert result == {}
index_metadata = [
{
"name": "clustering",
"column_names": ["c_col1", "c_col2", "c_col3"],
mock_get_time_partition_column.return_value = "ds"
mock_get_max_partition_id.return_value = "19690101"
result = BigQueryEngineSpec.get_extra_table_metadata(database, tbl)
print(result)
assert result == {
"indexes": [{"cols": ["ds"], "name": "partitioned", "type": "partitioned"}],
"partitions": {
"cols": ["ds"],
"latest": {"ds": "19690101"},
"partitionQuery": sql,
},
{
"name": "partition",
"column_names": ["p_col1", "p_col2", "p_col3"],
},
]
expected_result = {
"partitions": {"cols": [["p_col1", "p_col2", "p_col3"]]},
"clustering": {"cols": [["c_col1", "c_col2", "c_col3"]]},
}
database.get_indexes = mock.MagicMock(return_value=index_metadata)
result = BigQueryEngineSpec.get_extra_table_metadata(
database,
Table("some_table", "some_schema"),
)
assert result == expected_result
def test_get_indexes(self):
database = mock.Mock()
inspector = mock.Mock()
schema = "foo"
table_name = "bar"
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": [None],
"unique": False,
}
]
)
assert (
BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
)
== []
)
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
)
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
inspector.get_indexes = mock.Mock(
return_value=[
{
"name": "partition",
"column_names": ["dttm", None],
"unique": False,
}
]
)
assert BigQueryEngineSpec.get_indexes(
database,
inspector,
Table(table_name, schema),
) == [
{
"name": "partition",
"column_names": ["dttm"],
"unique": False,
}
]
@mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine")
@mock.patch("superset.db_engine_specs.bigquery.pandas_gbq")