mirror of
https://github.com/apache/superset.git
synced 2026-04-20 16:44:46 +00:00
feat: new CSV upload form and API (#27840)
This commit is contained in:
committed by
GitHub
parent
40e77be813
commit
54387b4589
@@ -31,6 +31,7 @@ from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
|
||||
|
||||
from superset import app, event_logger
|
||||
from superset.commands.database.create import CreateDatabaseCommand
|
||||
from superset.commands.database.csv_import import CSVImportCommand
|
||||
from superset.commands.database.delete import DeleteDatabaseCommand
|
||||
from superset.commands.database.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
@@ -66,6 +67,7 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
|
||||
from superset.databases.decorators import check_table_access
|
||||
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
|
||||
from superset.databases.schemas import (
|
||||
CSVUploadPostSchema,
|
||||
database_schemas_query_schema,
|
||||
database_tables_query_schema,
|
||||
DatabaseConnectionSchema,
|
||||
@@ -130,6 +132,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
"delete_ssh_tunnel",
|
||||
"schemas_access_for_file_upload",
|
||||
"get_connection",
|
||||
"csv_upload",
|
||||
"oauth2",
|
||||
}
|
||||
|
||||
@@ -241,6 +244,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
|
||||
openapi_spec_tag = "Database"
|
||||
openapi_spec_component_schemas = (
|
||||
CSVUploadPostSchema,
|
||||
DatabaseConnectionSchema,
|
||||
DatabaseFunctionNamesResponse,
|
||||
DatabaseSchemaAccessForFileUploadResponse,
|
||||
@@ -1336,6 +1340,65 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
command.run()
|
||||
return self.response(200, message="OK")
|
||||
|
||||
@expose("/<int:pk>/csv_upload/", methods=("POST",))
|
||||
@protect()
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
@requires_form_data
|
||||
def csv_upload(self, pk: int) -> Response:
|
||||
"""Upload a CSV file into a database.
|
||||
---
|
||||
post:
|
||||
summary: Upload a CSV file to a database table
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
multipart/form-data:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CSVUploadPostSchema'
|
||||
responses:
|
||||
200:
|
||||
description: CSV upload response
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
message:
|
||||
type: string
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
request_form = request.form.to_dict()
|
||||
request_form["file"] = request.files.get("file")
|
||||
parameters = CSVUploadPostSchema().load(request_form)
|
||||
CSVImportCommand(
|
||||
pk,
|
||||
parameters["table_name"],
|
||||
parameters["file"],
|
||||
parameters,
|
||||
).run()
|
||||
except ValidationError as error:
|
||||
return self.response_400(message=error.messages)
|
||||
return self.response(200, message="OK")
|
||||
|
||||
@expose("/<int:pk>/function_names/", methods=("GET",))
|
||||
@protect()
|
||||
@safe
|
||||
|
||||
@@ -19,13 +19,24 @@
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from flask_babel import lazy_gettext as _
|
||||
from marshmallow import EXCLUDE, fields, pre_load, Schema, validates_schema
|
||||
from marshmallow.validate import Length, ValidationError
|
||||
from marshmallow import (
|
||||
EXCLUDE,
|
||||
fields,
|
||||
post_load,
|
||||
pre_load,
|
||||
Schema,
|
||||
validates,
|
||||
validates_schema,
|
||||
)
|
||||
from marshmallow.validate import Length, OneOf, Range, ValidationError
|
||||
from sqlalchemy import MetaData
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from superset import db, is_feature_enabled
|
||||
from superset.commands.database.exceptions import DatabaseInvalidError
|
||||
@@ -980,6 +991,177 @@ class DatabaseConnectionSchema(Schema):
|
||||
)
|
||||
|
||||
|
||||
class DelimitedListField(fields.List):
|
||||
"""
|
||||
Special marshmallow field for handling delimited lists.
|
||||
formData expects a string, so we need to deserialize it into a list.
|
||||
"""
|
||||
|
||||
def _deserialize(
|
||||
self, value: str, attr: Any, data: Any, **kwargs: Any
|
||||
) -> list[Any]:
|
||||
try:
|
||||
values = value.split(",") if value else []
|
||||
return super()._deserialize(values, attr, data, **kwargs)
|
||||
except AttributeError as exc:
|
||||
raise ValidationError(
|
||||
f"{attr} is not a delimited list it has a non string value {value}."
|
||||
) from exc
|
||||
|
||||
|
||||
class CSVUploadPostSchema(Schema):
|
||||
"""
|
||||
Schema for CSV Upload
|
||||
"""
|
||||
|
||||
file = fields.Raw(
|
||||
required=True,
|
||||
metadata={
|
||||
"description": "The CSV file to upload",
|
||||
"type": "string",
|
||||
"format": "text/csv",
|
||||
},
|
||||
)
|
||||
delimiter = fields.String(metadata={"description": "The delimiter of the CSV file"})
|
||||
already_exists = fields.String(
|
||||
load_default="fail",
|
||||
validate=OneOf(choices=("fail", "replace", "append")),
|
||||
metadata={
|
||||
"description": "What to do if the table already "
|
||||
"exists accepts: fail, replace, append"
|
||||
},
|
||||
)
|
||||
column_data_types = fields.String(
|
||||
metadata={
|
||||
"description": "A dictionary with column names and "
|
||||
"their data types if you need to change "
|
||||
"the defaults. Example: {'user_id':'int'}. "
|
||||
"Check Python Pandas library for supported data types"
|
||||
}
|
||||
)
|
||||
column_dates = DelimitedListField(
|
||||
fields.String(),
|
||||
metadata={
|
||||
"description": "A list of column names that should be "
|
||||
"parsed as dates. Example: date,timestamp"
|
||||
},
|
||||
)
|
||||
column_labels = fields.String(
|
||||
metadata={
|
||||
"description": "Column label for index column(s). "
|
||||
"If None is given and Dataframe"
|
||||
"Index is checked, Index Names are used"
|
||||
}
|
||||
)
|
||||
columns_read = DelimitedListField(
|
||||
fields.String(),
|
||||
metadata={"description": "A List of the column names that should be read"},
|
||||
)
|
||||
dataframe_index = fields.String(
|
||||
metadata={
|
||||
"description": "Column to use as the row labels of the dataframe. "
|
||||
"Leave empty if no index column"
|
||||
}
|
||||
)
|
||||
day_first = fields.Boolean(
|
||||
metadata={
|
||||
"description": "DD/MM format dates, international and European format"
|
||||
}
|
||||
)
|
||||
decimal_character = fields.String(
|
||||
metadata={
|
||||
"description": "Character to recognize as decimal point. Default is '.'"
|
||||
}
|
||||
)
|
||||
header_row = fields.Integer(
|
||||
metadata={
|
||||
"description": "Row containing the headers to use as column names"
|
||||
"(0 is first line of data). Leave empty if there is no header row."
|
||||
}
|
||||
)
|
||||
index_column = fields.String(
|
||||
metadata={
|
||||
"description": "Column to use as the row labels of the dataframe. "
|
||||
"Leave empty if no index column"
|
||||
}
|
||||
)
|
||||
null_values = DelimitedListField(
|
||||
fields.String(),
|
||||
metadata={
|
||||
"description": "A list of strings that should be treated as null. "
|
||||
"Examples: '' for empty strings, 'None', 'N/A',"
|
||||
"Warning: Hive database supports only a single value"
|
||||
},
|
||||
)
|
||||
overwrite_duplicates = fields.Boolean(
|
||||
metadata={
|
||||
"description": "If duplicate columns are not overridden,"
|
||||
"they will be presented as 'X.1, X.2 ...X.x'."
|
||||
}
|
||||
)
|
||||
rows_to_read = fields.Integer(
|
||||
metadata={
|
||||
"description": "Number of rows to read from the file. "
|
||||
"If None, reads all rows."
|
||||
},
|
||||
allow_none=True,
|
||||
validate=Range(min=1),
|
||||
)
|
||||
schema = fields.String(
|
||||
metadata={"description": "The schema to upload the CSV file to."}
|
||||
)
|
||||
skip_blank_lines = fields.Boolean(
|
||||
metadata={"description": "Skip blank lines in the CSV file."}
|
||||
)
|
||||
skip_initial_space = fields.Boolean(
|
||||
metadata={"description": "Skip spaces after delimiter."}
|
||||
)
|
||||
skip_rows = fields.Integer(
|
||||
metadata={"description": "Number of rows to skip at start of file."}
|
||||
)
|
||||
table_name = fields.String(
|
||||
required=True,
|
||||
validate=[Length(min=1, max=10000)],
|
||||
allow_none=False,
|
||||
metadata={"description": "The name of the table to be created/appended"},
|
||||
)
|
||||
|
||||
@post_load
|
||||
def convert_column_data_types(
|
||||
self, data: dict[str, Any], **kwargs: Any
|
||||
) -> dict[str, Any]:
|
||||
if "column_data_types" in data and data["column_data_types"]:
|
||||
try:
|
||||
data["column_data_types"] = json.loads(data["column_data_types"])
|
||||
except json.JSONDecodeError as ex:
|
||||
raise ValidationError(
|
||||
"Invalid JSON format for column_data_types"
|
||||
) from ex
|
||||
return data
|
||||
|
||||
@validates("file")
|
||||
def validate_file_size(self, file: FileStorage) -> None:
|
||||
file.flush()
|
||||
size = os.fstat(file.fileno()).st_size
|
||||
if (
|
||||
current_app.config["CSV_UPLOAD_MAX_SIZE"] is not None
|
||||
and size > current_app.config["CSV_UPLOAD_MAX_SIZE"]
|
||||
):
|
||||
raise ValidationError([_("File size exceeds the maximum allowed size.")])
|
||||
|
||||
@validates("file")
|
||||
def validate_file_extension(self, file: FileStorage) -> None:
|
||||
allowed_extensions = current_app.config["ALLOWED_EXTENSIONS"].intersection(
|
||||
current_app.config["CSV_EXTENSIONS"]
|
||||
)
|
||||
matches = re.match(r".+\.([^.]+)$", file.filename)
|
||||
if not matches:
|
||||
raise ValidationError([_("File extension is not allowed.")])
|
||||
extension = matches.group(1)
|
||||
if extension not in allowed_extensions:
|
||||
raise ValidationError([_("File extension is not allowed.")])
|
||||
|
||||
|
||||
class OAuth2ProviderResponseSchema(Schema):
|
||||
"""
|
||||
Schema for the payload sent on OAuth2 redirect.
|
||||
|
||||
Reference in New Issue
Block a user