fix: save columns reference from sqllab save datasets flow (#24248)

This commit is contained in:
Hugh A. Miles II
2023-06-20 13:54:19 -04:00
committed by GitHub
parent fdef9cbc96
commit 93e1db4bd9
37 changed files with 489 additions and 202 deletions

View File

@@ -126,7 +126,14 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
type_ = group["type"].upper()
children_type = group["children"]
if type_ == "ARRAY":
return [{"name": column["name"], "type": children_type, "is_dttm": False}]
return [
{
"column_name": column["column_name"],
"name": column["column_name"],
"type": children_type,
"is_dttm": False,
}
]
if type_ == "ROW":
nameless_columns = 0
@@ -141,7 +148,8 @@ def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
type_ = parts[0]
nameless_columns += 1
_column: ResultSetColumnType = {
"name": f"{column['name']}.{name.lower()}",
"column_name": f"{column['column_name']}.{name.lower()}",
"name": f"{column['column_name']}.{name.lower()}",
"type": type_,
"is_dttm": False,
}
@@ -482,7 +490,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
schema: str | None,
database: Database,
query: Select,
columns: list[dict[str, Any]] | None = None,
columns: list[ResultSetColumnType] | None = None,
) -> Select | None:
try:
col_names, values = cls.latest_partition(
@@ -496,7 +504,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return None
column_type_by_name = {
column.get("name"): column.get("type") for column in columns or []
column.get("column_name"): column.get("type") for column in columns or []
}
for col_name, value in zip(col_names, values):
@@ -813,14 +821,20 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
) -> dict[str, Any]:
) -> ResultSetColumnType:
"""
Create column info object
:param name: column name
:param data_type: column data type
:return: column info object
"""
return {"name": name, "type": f"{data_type}"}
return {
"column_name": name,
"name": name,
"type": f"{data_type}",
"is_dttm": None,
"type_generic": None,
}
@classmethod
def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
@@ -863,7 +877,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
parent_column_name: str,
parent_data_type: str,
result: list[dict[str, Any]],
result: list[ResultSetColumnType],
) -> None:
"""
Parse a row or array column
@@ -941,7 +955,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# Unquote the column name if necessary
if formatted_parent_column_name != parent_column_name:
for index in range(original_result_len, len(result)):
result[index]["name"] = result[index]["name"].replace(
result[index]["column_name"] = result[index]["column_name"].replace(
formatted_parent_column_name, parent_column_name
)
@@ -965,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[dict[str, Any]]:
) -> list[ResultSetColumnType]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
@@ -976,7 +990,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
(i.e. column name and data type)
"""
columns = cls._show_columns(inspector, table_name, schema)
result: list[dict[str, Any]] = []
result: list[ResultSetColumnType] = []
for column in columns:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
@@ -1003,6 +1017,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
column_info = cls._create_column_info(column.Column, column_type)
column_info["nullable"] = getattr(column, "Null", True)
column_info["default"] = None
column_info["column_name"] = column.Column
result.append(column_info)
return result
@@ -1016,7 +1031,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"')
@classmethod
def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[ColumnClause]:
"""
Format column clauses where names are in quotes and labels are specified
:param cols: columns
@@ -1034,7 +1049,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
dot_regex = re.compile(dot_pattern, re.VERBOSE)
for col in cols:
# get individual column names
col_names = re.split(dot_regex, col["name"])
col_names = re.split(dot_regex, col["column_name"])
# quote each column name if it is not already quoted
for index, col_name in enumerate(col_names):
if not cls._is_column_name_quoted(col_name):
@@ -1044,7 +1059,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
for col_name in col_names
)
# create column clause in the format "name"."name" AS "name.name"
column_clause = literal_column(quoted_col_name).label(col["name"])
column_clause = literal_column(quoted_col_name).label(col["column_name"])
column_clauses.append(column_clause)
return column_clauses
@@ -1059,7 +1074,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
cols: list[dict[str, Any]] | None = None,
cols: list[ResultSetColumnType] | None = None,
) -> str:
"""
Include selecting properties of row objects. We cannot easily break arrays into
@@ -1071,7 +1086,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
if is_feature_enabled("PRESTO_EXPAND_DATA") and show_cols:
dot_regex = r"\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
presto_cols = [
col for col in presto_cols if not re.search(dot_regex, col["name"])
col
for col in presto_cols
if not re.search(dot_regex, col["column_name"])
]
return super().select_star(
database,
@@ -1123,7 +1140,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
current_array_level = None
while to_process:
column, level = to_process.popleft()
if column["name"] not in [column["name"] for column in all_columns]:
if column["column_name"] not in [
column["column_name"] for column in all_columns
]:
all_columns.append(column)
# When unnesting arrays we need to keep track of how many extra rows
@@ -1135,7 +1154,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
unnested_rows: dict[int, int] = defaultdict(int)
current_array_level = level
name = column["name"]
name = column["column_name"]
values: str | list[Any] | None
if column["type"] and column["type"].startswith("ARRAY("):
@@ -1186,10 +1205,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
values = cast(Optional[list[Any]], destringify(values))
row[name] = values
for value, col in zip(values or [], expanded):
row[col["name"]] = value
row[col["column_name"]] = value
data = [
{k["name"]: row.get(k["name"], "") for k in all_columns} for row in data
{k["column_name"]: row.get(k["column_name"], "") for k in all_columns}
for row in data
]
return all_columns, data, expanded_columns