Files
superset2/superset/examples/generic_loader.py
2026-02-02 09:24:16 -08:00

252 lines
8.9 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Generic Parquet example data loader."""
import logging
from functools import partial
from typing import Any, Callable, Optional
import numpy as np
from sqlalchemy import inspect
from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.examples.helpers import read_example_data
from superset.models.core import Database
from superset.sql.parse import Table
from superset.utils import json
from superset.utils.database import get_example_database
logger = logging.getLogger(__name__)
def serialize_numpy_arrays(obj: Any) -> Any: # noqa: C901
"""Convert numpy arrays to JSON-serializable format."""
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.generic):
# Handle numpy scalar types
return obj.item()
elif isinstance(obj, (list, tuple)):
return [serialize_numpy_arrays(item) for item in obj]
elif isinstance(obj, dict):
return {key: serialize_numpy_arrays(val) for key, val in obj.items()}
return obj
def load_parquet_table( # noqa: C901
parquet_file: str,
table_name: str,
database: Optional[Database] = None,
only_metadata: bool = False,
force: bool = False,
sample_rows: Optional[int] = None,
data_file: Optional[Any] = None,
schema: Optional[str] = None,
uuid: Optional[str] = None,
) -> SqlaTable:
"""Load a Parquet file into the example database.
Args:
parquet_file: Name of the Parquet file (e.g., "birth_names")
table_name: Name for the table in the target database
database: Target database (defaults to example database)
only_metadata: If True, only create metadata without loading data
force: If True, replace existing table
sample_rows: If specified, only load this many rows
data_file: Optional specific file path (Path object) to load from
schema: Schema to load into (defaults to database default schema)
Returns:
The created SqlaTable object
"""
from sqlalchemy import text
if database is None:
database = get_example_database()
# Determine schema - use provided or fall back to database default
with database.get_sqla_engine() as engine:
if schema is None:
schema = inspect(engine).default_schema_name
else:
# Create schema if it doesn't exist (PostgreSQL)
with engine.begin() as conn:
conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
table_exists = database.has_table(Table(table_name, schema=schema))
if table_exists and not force:
logger.info("Table %s already exists, skipping data load", table_name)
tbl = (
db.session.query(SqlaTable)
.filter_by(table_name=table_name, database_id=database.id)
.first()
)
if tbl:
return tbl
# Load data if not metadata only
if not only_metadata:
logger.info("Loading data for %s from %s.parquet", table_name, parquet_file)
# Read from Parquet - use specific file path if provided
if data_file is not None:
pdf = read_example_data(f"file://{data_file}")
else:
pdf = read_example_data(f"examples://{parquet_file}")
# Sample if requested (handle sample_rows=0 correctly)
if sample_rows is not None:
pdf = pdf.head(sample_rows)
# Check for columns with complex types (numpy arrays, nested structures)
for col in pdf.columns:
# Check if any value in the column is a numpy array or nested structure
if pdf[col].dtype == object:
try:
# Check if the first non-null value is complex
sample_val = (
pdf[col].dropna().iloc[0]
if not pdf[col].dropna().empty
else None
)
if sample_val is not None and isinstance(
sample_val, (np.ndarray, list, dict)
):
logger.info("Converting complex column %s to JSON string", col)
# Convert to JSON string for database storage
def safe_serialize(x: Any, column_name: str) -> Optional[str]:
if x is None:
return None
try:
return json.dumps(serialize_numpy_arrays(x))
except (TypeError, ValueError) as e:
logger.warning(
"Failed to serialize value in column %s: %s",
column_name,
e,
)
# Convert to string representation as fallback
return str(x)
# Avoid loop variable binding issues with partial
serialize_col = partial(safe_serialize, column_name=col)
pdf[col] = pdf[col].apply(serialize_col)
except Exception as e:
logger.warning("Could not process column %s: %s", col, e)
# Write to target database
with database.get_sqla_engine() as engine:
pdf.to_sql(
table_name,
engine,
schema=schema,
if_exists="replace",
chunksize=500,
method="multi",
index=False,
)
logger.info("Loaded %d rows into %s", len(pdf), table_name)
# Create or update SqlaTable metadata
tbl = (
db.session.query(SqlaTable)
.filter_by(table_name=table_name, database_id=database.id)
.first()
)
if not tbl:
tbl = SqlaTable(table_name=table_name, database_id=database.id)
# Set the database reference
tbl.database = database
# Set UUID from YAML config so the YAML import path can find this dataset
if uuid and not tbl.uuid:
tbl.uuid = uuid
if not only_metadata:
# Ensure database reference is set before fetching metadata
if not tbl.database:
tbl.database = database
tbl.fetch_metadata()
db.session.merge(tbl)
db.session.commit()
return tbl
def create_generic_loader(
parquet_file: str,
table_name: Optional[str] = None,
description: Optional[str] = None,
sample_rows: Optional[int] = None,
data_file: Optional[Any] = None,
schema: Optional[str] = None,
uuid: Optional[str] = None,
) -> Callable[[Database, SqlaTable], None]:
"""Create a loader function for a specific Parquet file.
This factory function creates loaders that match the existing pattern
used by Superset examples.
Args:
parquet_file: Name of the Parquet file (without .parquet extension)
table_name: Table name (defaults to parquet_file)
description: Description for the dataset
sample_rows: Default number of rows to sample
data_file: Optional specific file path (Path object) for data/ folder pattern
schema: Schema to load into (defaults to database default schema)
Returns:
A loader function with the standard signature
"""
if table_name is None:
table_name = parquet_file
def loader(
only_metadata: bool = False,
force: bool = False,
sample: bool = False,
) -> None:
"""Load the dataset."""
rows = sample_rows if sample and sample_rows is not None else None
tbl = load_parquet_table(
parquet_file=parquet_file,
table_name=table_name,
only_metadata=only_metadata,
force=force,
sample_rows=rows,
data_file=data_file,
schema=schema,
uuid=uuid,
)
if description and tbl:
tbl.description = description
db.session.merge(tbl)
db.session.commit()
# Set function name and docstring
loader.__name__ = f"load_{parquet_file}"
loader.__doc__ = description or f"Load {parquet_file} dataset"
return loader