mirror of
https://github.com/apache/superset.git
synced 2026-04-26 03:24:53 +00:00
refactor(example_data): replace the way the birth_names data is loaded to DB (#18060)
* refactor: replace the way the birth_names data is loaded to DB * fix failed unit test * fix failed unit test * fix failed tests * fix pass wrong flag of support datetime type * remove unused fixture
This commit is contained in:
@@ -14,12 +14,9 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import DataFrame
|
||||
from sqlalchemy import DateTime, String
|
||||
|
||||
from superset import ConnectorRegistry, db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
@@ -28,9 +25,8 @@ from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import get_example_default_schema
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.common.example_data_generator.birth_names.birth_names_generator_factory import (
|
||||
BirthNamesGeneratorFactory,
|
||||
)
|
||||
from tests.example_data.data_loading.base_data_loader import DataLoader
|
||||
from tests.example_data.data_loading.data_definitions.types import Table
|
||||
from tests.integration_tests.dashboard_utils import create_table_metadata
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
@@ -38,31 +34,13 @@ BIRTH_NAMES_TBL_NAME = "birth_names"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def load_birth_names_data():
|
||||
with app.app_context():
|
||||
database = get_example_database()
|
||||
df = _get_dataframe(database)
|
||||
dtype = {
|
||||
"ds": DateTime if database.backend != "presto" else String(255),
|
||||
"gender": String(16),
|
||||
"state": String(10),
|
||||
"name": String(255),
|
||||
}
|
||||
|
||||
df.to_sql(
|
||||
BIRTH_NAMES_TBL_NAME,
|
||||
database.get_sqla_engine(),
|
||||
if_exists="replace",
|
||||
chunksize=500,
|
||||
dtype=dtype,
|
||||
index=False,
|
||||
method="multi",
|
||||
schema=get_example_default_schema(),
|
||||
)
|
||||
def load_birth_names_data(
|
||||
birth_names_table_factory: Callable[[], Table], data_loader: DataLoader
|
||||
):
|
||||
birth_names_table: Table = birth_names_table_factory()
|
||||
data_loader.load_table(birth_names_table)
|
||||
yield
|
||||
with app.app_context():
|
||||
engine = get_example_database().get_sqla_engine()
|
||||
engine.execute("DROP TABLE IF EXISTS birth_names")
|
||||
data_loader.remove_table(birth_names_table.table_name)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -137,15 +115,3 @@ def _cleanup(dash_id: int, slices_ids: List[int]) -> None:
|
||||
for slice_id in slices_ids:
|
||||
db.session.query(Slice).filter_by(id=slice_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def _get_dataframe(database: Database) -> DataFrame:
|
||||
data = _get_birth_names_data()
|
||||
df = pd.DataFrame.from_dict(data)
|
||||
if database.backend == "presto":
|
||||
df.ds = df.ds.dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
return df
|
||||
|
||||
|
||||
def _get_birth_names_data() -> List[Dict[Any, Any]]:
|
||||
return list(BirthNamesGeneratorFactory.make().generate())
|
||||
|
||||
Reference in New Issue
Block a user