feat: new CSV upload form and API (#27840)

This commit is contained in:
Daniel Vaz Gaspar
2024-04-15 09:38:51 +01:00
committed by GitHub
parent 40e77be813
commit 54387b4589
30 changed files with 2883 additions and 873 deletions

View File

@@ -31,6 +31,7 @@ from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
from superset import app, event_logger
from superset.commands.database.create import CreateDatabaseCommand
from superset.commands.database.csv_import import CSVImportCommand
from superset.commands.database.delete import DeleteDatabaseCommand
from superset.commands.database.exceptions import (
DatabaseConnectionFailedError,
@@ -66,6 +67,7 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
from superset.databases.decorators import check_table_access
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
from superset.databases.schemas import (
CSVUploadPostSchema,
database_schemas_query_schema,
database_tables_query_schema,
DatabaseConnectionSchema,
@@ -130,6 +132,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"delete_ssh_tunnel",
"schemas_access_for_file_upload",
"get_connection",
"csv_upload",
"oauth2",
}
@@ -241,6 +244,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
openapi_spec_tag = "Database"
openapi_spec_component_schemas = (
CSVUploadPostSchema,
DatabaseConnectionSchema,
DatabaseFunctionNamesResponse,
DatabaseSchemaAccessForFileUploadResponse,
@@ -1336,6 +1340,65 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
command.run()
return self.response(200, message="OK")
@expose("/<int:pk>/csv_upload/", methods=("POST",))
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
    # Log under this endpoint's own name; the previous "import_" suffix was
    # carried over from the import endpoint and mislabelled CSV uploads.
    action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.csv_upload",
    log_to_statsd=False,
)
@requires_form_data
def csv_upload(self, pk: int) -> Response:
    """Upload a CSV file into a database.
    ---
    post:
      summary: Upload a CSV file to a database table
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
      requestBody:
        required: true
        content:
          multipart/form-data:
            schema:
              $ref: '#/components/schemas/CSVUploadPostSchema'
      responses:
        200:
          description: CSV upload response
          content:
            application/json:
              schema:
                type: object
                properties:
                  message:
                    type: string
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    try:
        # Merge the plain form fields with the uploaded file so a single
        # schema load validates everything. A missing file arrives as None
        # and is left for the schema to reject.
        request_form = request.form.to_dict()
        request_form["file"] = request.files.get("file")
        parameters = CSVUploadPostSchema().load(request_form)
        # The full set of validated options is forwarded alongside the
        # table name and file it already contains.
        CSVImportCommand(
            pk,
            parameters["table_name"],
            parameters["file"],
            parameters,
        ).run()
    except ValidationError as error:
        # Schema validation failures are reported to the client as a 400.
        return self.response_400(message=error.messages)
    return self.response(200, message="OK")
@expose("/<int:pk>/function_names/", methods=("GET",))
@protect()
@safe

View File

@@ -19,13 +19,24 @@
import inspect
import json
import os
import re
from typing import Any
from flask import current_app
from flask_babel import lazy_gettext as _
from marshmallow import EXCLUDE, fields, pre_load, Schema, validates_schema
from marshmallow.validate import Length, ValidationError
from marshmallow import (
EXCLUDE,
fields,
post_load,
pre_load,
Schema,
validates,
validates_schema,
)
from marshmallow.validate import Length, OneOf, Range, ValidationError
from sqlalchemy import MetaData
from werkzeug.datastructures import FileStorage
from superset import db, is_feature_enabled
from superset.commands.database.exceptions import DatabaseInvalidError
@@ -980,6 +991,177 @@ class DatabaseConnectionSchema(Schema):
)
class DelimitedListField(fields.List):
    """
    Marshmallow list field fed by a single comma-separated string.

    Form data delivers the value as one string, so it is split on "," and
    the resulting pieces are handed to the regular list deserialization.
    """

    def _deserialize(
        self, value: str, attr: Any, data: Any, **kwargs: Any
    ) -> list[Any]:
        """
        Split ``value`` on commas and deserialize the parts as a list.

        A falsy value (for example an empty string) yields an empty list.
        Raises ``ValidationError`` when an ``AttributeError`` occurs, which
        is what a value without a ``split`` method triggers.
        """
        try:
            items = [] if not value else value.split(",")
            return super()._deserialize(items, attr, data, **kwargs)
        except AttributeError as err:
            raise ValidationError(
                f"{attr} is not a delimited list it has a non string value {value}."
            ) from err
class CSVUploadPostSchema(Schema):
    """
    Schema for CSV Upload

    Validates the multipart form submitted to the CSV upload endpoint: the
    uploaded file itself plus the options that control how it is parsed and
    which table it is written to. List-like options arrive as
    comma-separated strings and ``column_data_types`` as a JSON string.
    """

    # Raw because the value is an uploaded file object, not a serializable type.
    file = fields.Raw(
        required=True,
        metadata={
            "description": "The CSV file to upload",
            "type": "string",
            "format": "text/csv",
        },
    )
    delimiter = fields.String(metadata={"description": "The delimiter of the CSV file"})
    already_exists = fields.String(
        load_default="fail",
        validate=OneOf(choices=("fail", "replace", "append")),
        metadata={
            "description": "What to do if the table already "
            "exists accepts: fail, replace, append"
        },
    )
    # Received as a JSON string; decoded by convert_column_data_types below.
    column_data_types = fields.String(
        metadata={
            "description": "A dictionary with column names and "
            "their data types if you need to change "
            "the defaults. Example: {'user_id':'int'}. "
            "Check Python Pandas library for supported data types"
        }
    )
    column_dates = DelimitedListField(
        fields.String(),
        metadata={
            "description": "A list of column names that should be "
            "parsed as dates. Example: date,timestamp"
        },
    )
    column_labels = fields.String(
        metadata={
            "description": "Column label for index column(s). "
            "If None is given and Dataframe "
            "Index is checked, Index Names are used"
        }
    )
    columns_read = DelimitedListField(
        fields.String(),
        metadata={"description": "A List of the column names that should be read"},
    )
    # NOTE(review): this description is identical to index_column's - confirm
    # that both the field type and the text are what is intended here.
    dataframe_index = fields.String(
        metadata={
            "description": "Column to use as the row labels of the dataframe. "
            "Leave empty if no index column"
        }
    )
    day_first = fields.Boolean(
        metadata={
            "description": "DD/MM format dates, international and European format"
        }
    )
    decimal_character = fields.String(
        metadata={
            "description": "Character to recognize as decimal point. Default is '.'"
        }
    )
    header_row = fields.Integer(
        metadata={
            "description": "Row containing the headers to use as column names "
            "(0 is first line of data). Leave empty if there is no header row."
        }
    )
    index_column = fields.String(
        metadata={
            "description": "Column to use as the row labels of the dataframe. "
            "Leave empty if no index column"
        }
    )
    null_values = DelimitedListField(
        fields.String(),
        metadata={
            "description": "A list of strings that should be treated as null. "
            "Examples: '' for empty strings, 'None', 'N/A'. "
            "Warning: Hive database supports only a single value"
        },
    )
    overwrite_duplicates = fields.Boolean(
        metadata={
            "description": "If duplicate columns are not overridden, "
            "they will be presented as 'X.1, X.2 ...X.x'."
        }
    )
    rows_to_read = fields.Integer(
        metadata={
            "description": "Number of rows to read from the file. "
            "If None, reads all rows."
        },
        allow_none=True,
        validate=Range(min=1),
    )
    schema = fields.String(
        metadata={"description": "The schema to upload the CSV file to."}
    )
    skip_blank_lines = fields.Boolean(
        metadata={"description": "Skip blank lines in the CSV file."}
    )
    skip_initial_space = fields.Boolean(
        metadata={"description": "Skip spaces after delimiter."}
    )
    skip_rows = fields.Integer(
        metadata={"description": "Number of rows to skip at start of file."}
    )
    table_name = fields.String(
        required=True,
        validate=[Length(min=1, max=10000)],
        allow_none=False,
        metadata={"description": "The name of the table to be created/appended"},
    )

    @post_load
    def convert_column_data_types(
        self, data: dict[str, Any], **kwargs: Any
    ) -> dict[str, Any]:
        """
        Decode ``column_data_types`` from its JSON string form.

        Leaves ``data`` untouched when the key is absent or empty.

        :raises ValidationError: if the value is not valid JSON
        """
        if data.get("column_data_types"):
            try:
                data["column_data_types"] = json.loads(data["column_data_types"])
            except json.JSONDecodeError as ex:
                raise ValidationError(
                    "Invalid JSON format for column_data_types"
                ) from ex
        return data

    @validates("file")
    def validate_file_size(self, file: FileStorage) -> None:
        """
        Reject uploads larger than the ``CSV_UPLOAD_MAX_SIZE`` setting.

        A setting of ``None`` disables the check.

        :raises ValidationError: if the file is larger than the limit
        """
        max_size = current_app.config["CSV_UPLOAD_MAX_SIZE"]
        if max_size is None:
            # No limit configured, so there is no need to touch the stream.
            return
        try:
            file.flush()
            size = os.fstat(file.fileno()).st_size
        except (AttributeError, OSError):
            # In-memory streams have no file descriptor (fileno() raises
            # io.UnsupportedOperation, an OSError). Measure by seeking to the
            # end, then restore the position so later reads are unaffected.
            position = file.tell()
            file.seek(0, os.SEEK_END)
            size = file.tell()
            file.seek(position)
        if size > max_size:
            raise ValidationError([_("File size exceeds the maximum allowed size.")])

    @validates("file")
    def validate_file_extension(self, file: FileStorage) -> None:
        """
        Require the file name to end in an allowed CSV extension.

        Allowed extensions are those present in both the
        ``ALLOWED_EXTENSIONS`` and ``CSV_EXTENSIONS`` settings.

        :raises ValidationError: if the name has no extension or the
            extension is not allowed
        """
        allowed_extensions = current_app.config["ALLOWED_EXTENSIONS"].intersection(
            current_app.config["CSV_EXTENSIONS"]
        )
        # An upload may carry no file name at all; treat that the same as a
        # name without an extension instead of failing inside re.match.
        matches = re.match(r".+\.([^.]+)$", file.filename or "")
        if not matches or matches.group(1) not in allowed_extensions:
            raise ValidationError([_("File extension is not allowed.")])
class OAuth2ProviderResponseSchema(Schema):
"""
Schema for the payload sent on OAuth2 redirect.