diff --git a/superset/config.py b/superset/config.py index 006a7ac78c4..1448129136a 100644 --- a/superset/config.py +++ b/superset/config.py @@ -302,6 +302,7 @@ GET_FEATURE_FLAGS_FUNC = None # --------------------------------------------------- # The file upload folder, when using models with files UPLOAD_FOLDER = BASE_DIR + "/app/static/uploads/" +UPLOAD_CHUNK_SIZE = 4096 # The image upload folder, when using models with images IMG_UPLOAD_FOLDER = BASE_DIR + "/app/static/uploads/" diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index ee3ccff702f..97e86fb9ff9 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -34,7 +34,6 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom from sqlalchemy.types import TypeEngine -from werkzeug.utils import secure_filename from superset import app, sql_parse from superset.utils import core as utils @@ -357,9 +356,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param kwargs: params to be passed to DataFrame.read_csv :return: Pandas DataFrame containing data from csv """ - kwargs["filepath_or_buffer"] = ( - config["UPLOAD_FOLDER"] + kwargs["filepath_or_buffer"] - ) kwargs["encoding"] = "utf-8" kwargs["iterator"] = True chunks = pd.read_csv(**kwargs) @@ -394,7 +390,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods extension is not None and extension[1:] in config["ALLOWED_EXTENSIONS"] ) - filename = secure_filename(form.csv_file.data.filename) + filename = form.csv_file.data.filename + if not _allowed_file(filename): raise Exception("Invalid file type selected") csv_to_df_kwargs = { diff --git a/superset/views/database/views.py b/superset/views/database/views.py index 9a28715f69a..94912e9868b 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. import os +import tempfile +from typing import TYPE_CHECKING from flask import flash, g, redirect from flask_appbuilder import SimpleFormView from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import lazy_gettext as _ -from werkzeug.utils import secure_filename from wtforms.fields import StringField from wtforms.validators import ValidationError @@ -35,6 +36,9 @@ from .forms import CsvToDatabaseForm from .mixins import DatabaseMixin from .validators import schema_allows_csv_upload, sqlalchemy_uri_validator +if TYPE_CHECKING: + from werkzeug.datastructures import FileStorage # pylint: disable=unused-import + config = app.config stats_logger = config["STATS_LOGGER"] @@ -46,6 +50,16 @@ def sqlalchemy_uri_form_validator(_, field: StringField) -> None: sqlalchemy_uri_validator(field.data, exception=ValidationError) +def upload_stream_write(form_file_field: "FileStorage", path: str): + chunk_size = app.config["UPLOAD_CHUNK_SIZE"] + with open(path, "bw") as file_description: + while True: + chunk = form_file_field.stream.read(chunk_size) + if not chunk: + break + file_description.write(chunk) + + class DatabaseView( DatabaseMixin, SupersetModelView, DeleteMixin, YamlExportMixin ): # pylint: disable=too-many-ancestors @@ -92,13 +106,16 @@ class CsvToDatabaseView(SimpleFormView): flash(message, "danger") return redirect("/csvtodatabaseview/form") - csv_file = form.csv_file.data - form.csv_file.data.filename = secure_filename(form.csv_file.data.filename) csv_filename = form.csv_file.data.filename - path = os.path.join(config["UPLOAD_FOLDER"], csv_filename) + extension = os.path.splitext(csv_filename)[1].lower() + path = tempfile.NamedTemporaryFile( + dir=app.config["UPLOAD_FOLDER"], suffix=extension, delete=False + ).name + form.csv_file.data.filename = path + try: utils.ensure_path_exists(config["UPLOAD_FOLDER"]) - csv_file.save(path) + upload_stream_write(form.csv_file.data, path) table_name = form.name.data con = form.data.get("con") @@ -106,7 +123,6 @@ class CsvToDatabaseView(SimpleFormView): db.session.query(models.Database).filter_by(id=con.data.get("id")).one() ) database.db_engine_spec.create_table_from_csv(form, database) - table = ( db.session.query(SqlaTable) .filter_by(