refactor(example_data): replace the way the birth_names data is loaded to DB (#18060)

* refactor: replace the way the birth_names data is loaded to DB

* fix failed unit test

* fix failed unit test

* fix failed tests

* fix wrong flag being passed for datetime type support

* remove unused fixture
This commit is contained in:
ofekisr
2022-01-18 23:21:04 +02:00
committed by GitHub
parent 88db2cc0ab
commit 4675ca31c5
29 changed files with 781 additions and 137 deletions

View File

@@ -14,12 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional
from typing import Callable, List, Optional
import pandas as pd
import pytest
from pandas import DataFrame
from sqlalchemy import DateTime, String
from superset import ConnectorRegistry, db
from superset.connectors.sqla.models import SqlaTable
@@ -28,9 +25,8 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database
from tests.common.example_data_generator.birth_names.birth_names_generator_factory import (
BirthNamesGeneratorFactory,
)
from tests.example_data.data_loading.base_data_loader import DataLoader
from tests.example_data.data_loading.data_definitions.types import Table
from tests.integration_tests.dashboard_utils import create_table_metadata
from tests.integration_tests.test_app import app
@@ -38,31 +34,13 @@ BIRTH_NAMES_TBL_NAME = "birth_names"
# Session-scoped fixture that loads the birth_names table, yields to the
# tests, and removes the table afterwards.
# NOTE(review): this hunk is rendered without +/- markers or indentation, so
# two versions of the fixture appear interleaved below. Presumably the
# zero-argument `def` with the `df.to_sql(...)` body is the removed version and
# the `def` taking `birth_names_table_factory` / `data_loader` is the added
# one -- confirm against the actual commit.
@pytest.fixture(scope="session")
def load_birth_names_data():
# Variant 1: build a DataFrame and write it to the example database directly.
with app.app_context():
database = get_example_database()
df = _get_dataframe(database)
# Explicit SQL column types; on the "presto" backend `ds` is declared as
# String(255) instead of DateTime.
dtype = {
"ds": DateTime if database.backend != "presto" else String(255),
"gender": String(16),
"state": String(10),
"name": String(255),
}
# Write the frame into BIRTH_NAMES_TBL_NAME under the example default schema,
# replacing any existing table and without writing the DataFrame index.
df.to_sql(
BIRTH_NAMES_TBL_NAME,
database.get_sqla_engine(),
if_exists="replace",
chunksize=500,
dtype=dtype,
index=False,
method="multi",
schema=get_example_default_schema(),
)
# Variant 2: the table definition comes from the `birth_names_table_factory`
# callable and loading/removal is delegated to the `data_loader` object.
def load_birth_names_data(
birth_names_table_factory: Callable[[], Table], data_loader: DataLoader
):
birth_names_table: Table = birth_names_table_factory()
data_loader.load_table(birth_names_table)
# Everything after the yield is teardown for the fixture.
yield
# Teardown that drops the table with a raw SQL statement on the example
# database engine (presumably belongs to variant 1 -- TODO confirm).
with app.app_context():
engine = get_example_database().get_sqla_engine()
engine.execute("DROP TABLE IF EXISTS birth_names")
# Teardown that removes the table through the data loader, by table name
# (presumably belongs to variant 2 -- TODO confirm).
data_loader.remove_table(birth_names_table.table_name)
@pytest.fixture()
@@ -137,15 +115,3 @@ def _cleanup(dash_id: int, slices_ids: List[int]) -> None:
for slice_id in slices_ids:
db.session.query(Slice).filter_by(id=slice_id).delete()
db.session.commit()
def _get_dataframe(database: Database) -> DataFrame:
    """Build a DataFrame from the generated birth_names records.

    :param database: database object whose ``backend`` attribute is inspected
    :returns: the records as a DataFrame; when ``database.backend`` is
        ``"presto"`` the ``ds`` column is rendered as
        ``%Y-%m-%d %H:%M:%S`` strings, otherwise it is left untouched
    """
    frame = pd.DataFrame.from_dict(_get_birth_names_data())
    if database.backend != "presto":
        return frame
    # Only the presto backend gets its ``ds`` values formatted as strings.
    frame.ds = frame.ds.dt.strftime("%Y-%m-%d %H:%M:%S")
    return frame
def _get_birth_names_data() -> List[Dict[Any, Any]]:
    """Return every record produced by the birth_names generator as a list."""
    generator = BirthNamesGeneratorFactory.make()
    records = generator.generate()
    # Materialize the generated records so callers receive a concrete list.
    return list(records)