mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
252 lines
8.9 KiB
Python
252 lines
8.9 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
"""Generic Parquet example data loader."""
|
|
|
|
import logging
|
|
from functools import partial
|
|
from typing import Any, Callable, Optional
|
|
|
|
import numpy as np
|
|
from sqlalchemy import inspect
|
|
|
|
from superset import db
|
|
from superset.connectors.sqla.models import SqlaTable
|
|
from superset.examples.helpers import read_example_data
|
|
from superset.models.core import Database
|
|
from superset.sql.parse import Table
|
|
from superset.utils import json
|
|
from superset.utils.database import get_example_database
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def serialize_numpy_arrays(obj: Any) -> Any:  # noqa: C901
    """Recursively convert numpy values into JSON-serializable equivalents.

    - ``np.ndarray`` becomes nested Python lists
    - numpy scalars (``np.generic``) become native Python scalars
    - lists and tuples are rebuilt as lists with converted elements
    - dict values are converted recursively (keys are left untouched)
    - any other object is returned unchanged
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # numpy scalar types (np.int64, np.float32, ...) -> builtin scalars
        return obj.item()
    if isinstance(obj, (list, tuple)):
        return [serialize_numpy_arrays(element) for element in obj]
    if isinstance(obj, dict):
        return {k: serialize_numpy_arrays(v) for k, v in obj.items()}
    return obj
|
|
|
|
|
|
def load_parquet_table(  # noqa: C901
    parquet_file: str,
    table_name: str,
    database: Optional[Database] = None,
    only_metadata: bool = False,
    force: bool = False,
    sample_rows: Optional[int] = None,
    data_file: Optional[Any] = None,
    schema: Optional[str] = None,
    uuid: Optional[str] = None,
) -> SqlaTable:
    """Load a Parquet file into the example database.

    Args:
        parquet_file: Name of the Parquet file (e.g., "birth_names")
        table_name: Name for the table in the target database
        database: Target database (defaults to example database)
        only_metadata: If True, only create metadata without loading data
        force: If True, replace existing table
        sample_rows: If specified, only load this many rows
        data_file: Optional specific file path (Path object) to load from
        schema: Schema to load into (defaults to database default schema)
        uuid: Optional dataset UUID to assign to the SqlaTable so the YAML
            import path can find this dataset (only set if not already set)

    Returns:
        The created SqlaTable object
    """
    # Imported locally to keep module import light; sqlalchemy is already a
    # dependency of this file (see `inspect` at top level).
    from sqlalchemy import text

    if database is None:
        database = get_example_database()

    # Determine schema - use provided or fall back to database default
    with database.get_sqla_engine() as engine:
        if schema is None:
            schema = inspect(engine).default_schema_name
        else:
            # Create schema if it doesn't exist (PostgreSQL)
            # NOTE(review): quoting assumes double-quote identifier syntax;
            # presumably only run against backends that support
            # CREATE SCHEMA IF NOT EXISTS — confirm for non-Postgres engines.
            with engine.begin() as conn:
                conn.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))

    table_exists = database.has_table(Table(table_name, schema=schema))
    if table_exists and not force:
        logger.info("Table %s already exists, skipping data load", table_name)
        tbl = (
            db.session.query(SqlaTable)
            .filter_by(table_name=table_name, database_id=database.id)
            .first()
        )
        if tbl:
            return tbl
        # NOTE(review): if the physical table exists but no matching SqlaTable
        # row was found, control falls through below and (unless only_metadata)
        # the data is re-loaded with if_exists="replace", despite the
        # "skipping data load" log line above — confirm this is intended.

    # Load data if not metadata only
    if not only_metadata:
        logger.info("Loading data for %s from %s.parquet", table_name, parquet_file)

        # Read from Parquet - use specific file path if provided
        if data_file is not None:
            pdf = read_example_data(f"file://{data_file}")
        else:
            pdf = read_example_data(f"examples://{parquet_file}")

        # Sample if requested (handle sample_rows=0 correctly)
        if sample_rows is not None:
            pdf = pdf.head(sample_rows)

        # Check for columns with complex types (numpy arrays, nested structures)
        for col in pdf.columns:
            # Check if any value in the column is a numpy array or nested structure
            if pdf[col].dtype == object:
                try:
                    # Check if the first non-null value is complex
                    sample_val = (
                        pdf[col].dropna().iloc[0]
                        if not pdf[col].dropna().empty
                        else None
                    )
                    if sample_val is not None and isinstance(
                        sample_val, (np.ndarray, list, dict)
                    ):
                        logger.info("Converting complex column %s to JSON string", col)

                        # Convert to JSON string for database storage
                        def safe_serialize(x: Any, column_name: str) -> Optional[str]:
                            # Nulls stay null rather than becoming the string "null".
                            if x is None:
                                return None
                            try:
                                return json.dumps(serialize_numpy_arrays(x))
                            except (TypeError, ValueError) as e:
                                logger.warning(
                                    "Failed to serialize value in column %s: %s",
                                    column_name,
                                    e,
                                )
                                # Convert to string representation as fallback
                                return str(x)

                        # Avoid loop variable binding issues with partial
                        serialize_col = partial(safe_serialize, column_name=col)
                        pdf[col] = pdf[col].apply(serialize_col)
                except Exception as e:
                    # Best-effort: a bad column should not abort the whole load.
                    logger.warning("Could not process column %s: %s", col, e)

        # Write to target database
        with database.get_sqla_engine() as engine:
            pdf.to_sql(
                table_name,
                engine,
                schema=schema,
                if_exists="replace",
                chunksize=500,
                method="multi",
                index=False,
            )

        logger.info("Loaded %d rows into %s", len(pdf), table_name)

    # Create or update SqlaTable metadata
    tbl = (
        db.session.query(SqlaTable)
        .filter_by(table_name=table_name, database_id=database.id)
        .first()
    )

    if not tbl:
        tbl = SqlaTable(table_name=table_name, database_id=database.id)
        # Set the database reference
        tbl.database = database

    # Set UUID from YAML config so the YAML import path can find this dataset
    if uuid and not tbl.uuid:
        tbl.uuid = uuid

    if not only_metadata:
        # Ensure database reference is set before fetching metadata
        if not tbl.database:
            tbl.database = database
        tbl.fetch_metadata()

    db.session.merge(tbl)
    db.session.commit()

    return tbl
|
|
|
|
|
|
def create_generic_loader(
    parquet_file: str,
    table_name: Optional[str] = None,
    description: Optional[str] = None,
    sample_rows: Optional[int] = None,
    data_file: Optional[Any] = None,
    schema: Optional[str] = None,
    uuid: Optional[str] = None,
) -> Callable[..., None]:
    """Create a loader function for a specific Parquet file.

    This factory function creates loaders that match the existing pattern
    used by Superset examples.

    Args:
        parquet_file: Name of the Parquet file (without .parquet extension)
        table_name: Table name (defaults to parquet_file)
        description: Description for the dataset
        sample_rows: Default number of rows to sample
        data_file: Optional specific file path (Path object) for data/ folder pattern
        schema: Schema to load into (defaults to database default schema)
        uuid: Optional dataset UUID forwarded to ``load_parquet_table``

    Returns:
        A loader function with the standard signature
        ``loader(only_metadata=False, force=False, sample=False)``.
    """
    if table_name is None:
        table_name = parquet_file

    def loader(
        only_metadata: bool = False,
        force: bool = False,
        sample: bool = False,
    ) -> None:
        """Load the dataset."""
        # Only sample when sampling is requested AND a sample size was
        # configured for this loader; otherwise load everything.
        rows = sample_rows if sample and sample_rows is not None else None

        tbl = load_parquet_table(
            parquet_file=parquet_file,
            table_name=table_name,
            only_metadata=only_metadata,
            force=force,
            sample_rows=rows,
            data_file=data_file,
            schema=schema,
            uuid=uuid,
        )

        if description and tbl:
            tbl.description = description
            db.session.merge(tbl)
            db.session.commit()

    # Set function name and docstring so registration/logging can identify
    # each generated loader.
    loader.__name__ = f"load_{parquet_file}"
    loader.__doc__ = description or f"Load {parquet_file} dataset"

    return loader
|