chore: use contextlib.surpress instead of passing on error (#24896)

Co-authored-by: John Bodley <4567245+john-bodley@users.noreply.github.com>
This commit is contained in:
Sebastian Liebscher
2023-08-29 18:09:01 +02:00
committed by GitHub
parent 72150ebadf
commit e585db85b6
18 changed files with 66 additions and 146 deletions

View File

@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
import contextlib
import json
import logging
from typing import Any, TYPE_CHECKING
@@ -223,11 +224,8 @@ class ChartDataRestApi(ChartRestApi):
json_body = request.json
elif request.form.get("form_data"):
# CSV export submits regular form data
try:
with contextlib.suppress(TypeError, json.JSONDecodeError):
json_body = json.loads(request.form["form_data"])
except (TypeError, json.JSONDecodeError):
pass
if json_body is None:
return self.response_400(message=_("Request is not JSON"))
@@ -324,14 +322,10 @@ class ChartDataRestApi(ChartRestApi):
Execute command as an async query.
"""
# First, look for the chart query results in the cache.
result = None
try:
with contextlib.suppress(ChartDataCacheLoadError):
result = command.run(force_cached=True)
if result is not None:
return self._send_chart_response(result)
except ChartDataCacheLoadError:
pass
# Otherwise, kick off a background job to run the chart query.
# Clients will either poll or be notified of query completion,
# at which point they will call the /data/<cache_key> endpoint

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
from typing import Any
from sqlalchemy import MetaData
@@ -221,14 +222,8 @@ def add_types(metadata: MetaData) -> None:
# add a tag for each object type
insert = tag.insert()
for type_ in ObjectTypes.__members__:
try:
db.session.execute(
insert,
name=f"type:{type_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists
with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"type:{type_}", type=TagTypes.type)
add_types_to_charts(metadata, tag, tagged_object, columns)
add_types_to_dashboards(metadata, tag, tagged_object, columns)
@@ -448,11 +443,8 @@ def add_owners(metadata: MetaData) -> None:
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in db.session.execute(ids):
try:
with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"owner:{id_}", type=TagTypes.owner)
except IntegrityError:
pass # already exists
add_owners_to_charts(metadata, tag, tagged_object, columns)
add_owners_to_dashboards(metadata, tag, tagged_object, columns)
add_owners_to_saved_queries(metadata, tag, tagged_object, columns)
@@ -489,15 +481,8 @@ def add_favorites(metadata: MetaData) -> None:
ids = select([users.c.id])
insert = tag.insert()
for (id_,) in db.session.execute(ids):
try:
db.session.execute(
insert,
name=f"favorited_by:{id_}",
type=TagTypes.type,
)
except IntegrityError:
pass # already exists
with contextlib.suppress(IntegrityError): # already exists
db.session.execute(insert, name=f"favorited_by:{id_}", type=TagTypes.type)
favstars = (
select(
[

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import json
import re
import urllib
@@ -557,11 +558,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
except (json.JSONDecodeError, TypeError):
return encrypted_extra
try:
with contextlib.suppress(KeyError):
config["credentials_info"]["private_key"] = PASSWORD_MASK
except KeyError:
pass
return json.dumps(config)
@classmethod

View File

@@ -17,6 +17,7 @@
from __future__ import annotations
import contextlib
import json
import logging
import re
@@ -167,11 +168,8 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
except (TypeError, json.JSONDecodeError):
return encrypted_extra
try:
with contextlib.suppress(KeyError):
config["service_account_info"]["private_key"] = PASSWORD_MASK
except KeyError:
pass
return json.dumps(config)
@classmethod

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import re
from datetime import datetime
from re import Pattern
@@ -258,11 +259,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def _extract_error_message(cls, ex: Exception) -> str:
"""Extract error message for queries"""
message = str(ex)
try:
with contextlib.suppress(AttributeError, KeyError):
if isinstance(ex.args, tuple) and len(ex.args) > 1:
message = ex.args[1]
except (AttributeError, KeyError):
pass
return message
@classmethod

View File

@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import contextlib
import re
import threading
from re import Pattern
@@ -24,8 +25,7 @@ from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.orm import Session
# Need to try-catch here because pyocient may not be installed
try:
with contextlib.suppress(ImportError, RuntimeError): # pyocient may not be installed
# Ensure pyocient inherits Superset's logging level
import geojson
import pyocient
@@ -35,8 +35,6 @@ try:
superset_log_level = app.config["LOG_LEVEL"]
pyocient.logger.setLevel(superset_log_level)
except (ImportError, RuntimeError):
pass
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec

View File

@@ -17,6 +17,7 @@
# pylint: disable=too-many-lines
from __future__ import annotations
import contextlib
import logging
import re
import time
@@ -67,11 +68,8 @@ if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database
# need try/catch because pyhive may not be installed
try:
with contextlib.suppress(ImportError): # pyhive may not be installed
from pyhive.presto import Cursor
except ImportError:
pass
COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
"line (?P<location>.+?): .*Column '(?P<column_name>.+?)' cannot be resolved"
@@ -1274,12 +1272,10 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
with contextlib.suppress(AttributeError):
if cursor.last_query_id:
# pylint: disable=protected-access, line-too-long
return f"{cursor._protocol}://{cursor._host}:{cursor._port}/ui/query.html?{cursor.last_query_id}"
except AttributeError:
pass
return None
@classmethod

View File

@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
import contextlib
import logging
from typing import Any, TYPE_CHECKING
@@ -35,10 +36,8 @@ from superset.utils import core as utils
if TYPE_CHECKING:
from superset.models.core import Database
try:
with contextlib.suppress(ImportError): # trino may not be installed
from trino.dbapi import Cursor
except ImportError:
pass
logger = logging.getLogger(__name__)
@@ -140,12 +139,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
try:
return cursor.info_uri
except AttributeError:
try:
with contextlib.suppress(AttributeError):
conn = cursor.connection
# pylint: disable=protected-access, line-too-long
return f"{conn.http_scheme}://{conn.host}:{conn.port}/ui/query.html?{cursor._query.query_id}"
except AttributeError:
pass
return None
@classmethod

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import logging
from abc import ABC
from typing import Any, cast, Optional
@@ -107,17 +108,15 @@ class GetExploreCommand(BaseCommand, ABC):
)
except SupersetException:
self._datasource_id = None
# fallback unkonw datasource to table type
# fallback unknown datasource to table type
self._datasource_type = SqlaTable.type
datasource: Optional[BaseDatasource] = None
if self._datasource_id is not None:
try:
with contextlib.suppress(DatasourceNotFound):
datasource = DatasourceDAO.get_datasource(
db.session, cast(str, self._datasource_type), self._datasource_id
)
except DatasourceNotFound:
pass
datasource_name = datasource.name if datasource else _("[Missing Dataset]")
viz_type = form_data.get("viz_type")
if not viz_type and datasource and datasource.default_endpoint:

View File

@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
import contextlib
import logging
import os
import sys
@@ -25,7 +26,7 @@ import wtforms_json
from deprecation import deprecated
from flask import Flask, redirect
from flask_appbuilder import expose, IndexView
from flask_babel import gettext as __, lazy_gettext as _
from flask_babel import gettext as __
from flask_compress import Compress
from werkzeug.middleware.proxy_fix import ProxyFix
@@ -594,11 +595,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
self.superset_app.wsgi_app = ChunkedEncodingFix(self.superset_app.wsgi_app)
if self.config["UPLOAD_FOLDER"]:
try:
with contextlib.suppress(OSError):
os.makedirs(self.config["UPLOAD_FOLDER"])
except OSError:
pass
for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
self.superset_app.wsgi_app = middleware(self.superset_app.wsgi_app)

View File

@@ -24,7 +24,7 @@ import json
import logging
import textwrap
from ast import literal_eval
from contextlib import closing, contextmanager, nullcontext
from contextlib import closing, contextmanager, nullcontext, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
@@ -225,7 +225,6 @@ class Database(
@property
def allows_virtual_table_explore(self) -> bool:
extra = self.get_extra()
return bool(extra.get("allows_virtual_table_explore", True))
@property
@@ -235,9 +234,7 @@ class Database(
@property
def disable_data_preview(self) -> bool:
# this will prevent any 'trash value' strings from going through
if self.get_extra().get("disable_data_preview", False) is not True:
return False
return True
return self.get_extra().get("disable_data_preview", False) is True
@property
def data(self) -> dict[str, Any]:
@@ -285,11 +282,8 @@ class Database(
masked_uri = make_url_safe(self.sqlalchemy_uri)
encrypted_config = {}
if (masked_encrypted_extra := self.masked_encrypted_extra) is not None:
try:
with suppress(TypeError, json.JSONDecodeError):
encrypted_config = json.loads(masked_encrypted_extra)
except (TypeError, json.JSONDecodeError):
pass
try:
# pylint: disable=useless-suppression
parameters = self.db_engine_spec.get_parameters_from_uri( # type: ignore
@@ -550,7 +544,7 @@ class Database(
@property
def quote_identifier(self) -> Callable[[str], str]:
"""Add quotes to potential identifiter expressions if needed"""
"""Add quotes to potential identifier expressions if needed"""
return self.get_dialect().identifier_preparer.quote
def get_reserved_words(self) -> set[str]:
@@ -692,7 +686,7 @@ class Database(
"""
try:
with self.get_inspector_with_context() as inspector:
tables = {
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
@@ -700,7 +694,6 @@ class Database(
schema=schema,
)
}
return tables
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
@@ -985,7 +978,6 @@ sqla.event.listen(Database, "after_delete", security_manager.database_after_dele
class Log(Model): # pylint: disable=too-few-public-methods
"""ORM object used to log Superset actions to the database"""
__tablename__ = "logs"

View File

@@ -295,11 +295,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param schema: The Superset schema name
:return: The database specific schema permission
"""
if schema:
return f"[{database}].[{schema}]"
return None
return f"[{database}].[{schema}]" if schema else None
@staticmethod
def get_database_perm(database_id: int, database_name: str) -> str:
@@ -695,7 +691,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
Add the FAB permission/view-menu.
:param permission_name: The FAB permission name
:param view_menu_names: The FAB view-menu name
:param view_menu_name: The FAB view-menu name
:see: SecurityManager.add_permission_view_menu
"""
@@ -2163,8 +2159,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"aud": audience,
"type": "guest",
}
token = self.pyjwt_for_guest_token.encode(claims, secret, algorithm=algo)
return token
return self.pyjwt_for_guest_token.encode(claims, secret, algorithm=algo)
def get_guest_user_from_request(self, req: Request) -> Optional[GuestUser]:
"""
@@ -2230,9 +2225,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return hasattr(user, "is_guest_user") and user.is_guest_user
def get_current_guest_user_if_guest(self) -> Optional[GuestUser]:
if self.is_guest_user():
return g.user
return None
return g.user if self.is_guest_user() else None
def has_guest_access(self, dashboard: "Dashboard") -> bool:
user = self.get_current_guest_user_if_guest()
@@ -2293,8 +2286,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"""
Returns True if the current user is an owner of the resource, False otherwise.
:param resource: The dashboard, dataste, chart, etc. resource
:returns: Whethe the current user is an owner of the resource
:param resource: The dashboard, dataset, chart, etc. resource
:returns: Whether the current user is an owner of the resource
"""
try:
@@ -2308,7 +2301,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"""
Returns True if the current user is an admin user, False otherwise.
:returns: Whehther the current user is an admin user
:returns: Whether the current user is an admin user
"""
return current_app.config["AUTH_ROLE_ADMIN"] in [

View File

@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
import contextlib
import json
import logging
from dataclasses import dataclass
@@ -175,12 +176,10 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
)
def get_query_details(self) -> str:
try:
with contextlib.suppress(DetachedInstanceError):
if hasattr(self, "query"):
if self.query.id:
return f"query '{self.query.id}' - '{self.query.sql}'"
except DetachedInstanceError:
pass
return f"query '{self.sql}'"

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import logging
from abc import ABC, abstractmethod
from typing import Any
@@ -54,7 +55,7 @@ class TemporaryCacheRestApi(BaseSupersetApi, ABC):
allow_browser_login = True
def add_apispec_components(self, api_spec: APISpec) -> None:
try:
with contextlib.suppress(DuplicateComponentNameError):
api_spec.components.schema(
TemporaryCachePostSchema.__name__,
schema=TemporaryCachePostSchema,
@@ -63,8 +64,6 @@ class TemporaryCacheRestApi(BaseSupersetApi, ABC):
TemporaryCachePutSchema.__name__,
schema=TemporaryCachePutSchema,
)
except DuplicateComponentNameError:
pass
super().add_apispec_components(api_spec)
@requires_json

View File

@@ -17,6 +17,7 @@
# pylint: disable=too-many-lines, invalid-name
from __future__ import annotations
import contextlib
import logging
from datetime import datetime
from typing import Any, Callable, cast
@@ -140,16 +141,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
form_data = parse.quote(json.dumps({"slice_id": slice_id}))
endpoint = f"/explore/?form_data={form_data}"
is_standalone_mode = ReservedUrlParameters.is_standalone_mode()
if is_standalone_mode:
if ReservedUrlParameters.is_standalone_mode():
endpoint += f"&{ReservedUrlParameters.STANDALONE}=true"
return redirect(endpoint)
def get_query_string_response(self, viz_obj: BaseViz) -> FlaskResponse:
query = None
try:
query_obj = viz_obj.query_obj()
if query_obj:
if query_obj := viz_obj.query_obj():
query = viz_obj.datasource.get_query_str(query_obj)
except Exception as ex: # pylint: disable=broad-except
err_msg = utils.error_msg_from_exception(ex)
@@ -304,7 +303,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
and response_type == ChartDataResultFormat.JSON
):
# First, look for the chart query results in the cache.
try:
with contextlib.suppress(CacheLoadError):
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
@@ -316,9 +315,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# If the chart query has already been cached, return it immediately.
if payload is not None:
return self.send_data_payload_response(viz_obj, payload)
except CacheLoadError:
pass
# Otherwise, kick off a background job to run the chart query.
# Clients will either poll or be notified of query completion,
# at which point they will call the /explore_json/data/<cache_key>
@@ -411,8 +407,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
slice_id = parsed_form_data.get(
"slice_id", int(request.args.get("slice_id", 0))
)
datasource = parsed_form_data.get("datasource")
if datasource:
if datasource := parsed_form_data.get("datasource"):
datasource_id, datasource_type = datasource.split("__")
parameters = CommandParameters(
datasource_id=datasource_id,
@@ -431,9 +426,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# Return a relative URL
url = parse.urlparse(redirect_url)
if url.query:
return f"{url.path}?{url.query}"
return url.path
return f"{url.path}?{url.query}" if url.query else url.path
@has_access
@event_logger.log_this
@@ -468,8 +461,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
if key is not None:
command = GetExplorePermalinkCommand(key)
try:
permalink_value = command.run()
if permalink_value:
if permalink_value := command.run():
state = permalink_value["state"]
initial_form_data = state["formData"]
url_params = state.get("urlParams")
@@ -522,14 +514,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasource: BaseDatasource | None = None
if datasource_id is not None:
try:
with contextlib.suppress(DatasetNotFoundError):
datasource = DatasourceDAO.get_datasource(
db.session,
DatasourceType("table"),
datasource_id,
)
except DatasetNotFoundError:
pass
datasource_name = datasource.name if datasource else _("[Missing Dataset]")
viz_type = form_data.get("viz_type")
if not viz_type and datasource and datasource.default_endpoint:
@@ -902,8 +893,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
if url_params := state.get("urlParams"):
params = parse.urlencode(url_params)
url = f"{url}&{params}"
hash_ = state.get("anchor", state.get("hash"))
if hash_:
if hash_ := state.get("anchor", state.get("hash")):
url = f"{url}#{hash_}"
return redirect(url)
@@ -960,12 +950,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return self.render_template("superset/public_welcome.html")
return redirect(appbuilder.get_url_for_login)
welcome_dashboard_id = (
if welcome_dashboard_id := (
db.session.query(UserAttribute.welcome_dashboard_id)
.filter_by(user_id=get_user_id())
.scalar()
)
if welcome_dashboard_id:
):
return self.dashboard(dashboard_id_or_slug=str(welcome_dashboard_id))
payload = {
@@ -1005,11 +994,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
}
if form_data := request.form.get("form_data"):
try:
with contextlib.suppress(json.JSONDecodeError):
payload["requested_query"] = json.loads(form_data)
except json.JSONDecodeError:
pass
payload["user"] = bootstrap_user_data(g.user, include_perms=True)
bootstrap_data = json.dumps(
payload, default=utils.pessimistic_json_iso_dttm_ser

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import logging
from collections import defaultdict
from functools import wraps
@@ -323,7 +324,7 @@ def get_dashboard_extra_filters(
):
return []
try:
with contextlib.suppress(json.JSONDecodeError):
# does this dashboard have default filters?
json_metadata = json.loads(dashboard.json_metadata)
default_filters = json.loads(json_metadata.get("default_filters", "null"))
@@ -340,9 +341,6 @@ def get_dashboard_extra_filters(
and isinstance(default_filters, dict)
):
return build_extra_filters(layout, filter_scopes, default_filters, slice_id)
except json.JSONDecodeError:
pass
return []

View File

@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel
from contextlib import nullcontext
from contextlib import nullcontext, suppress
from typing import Optional, Union
import pandas as pd
@@ -164,11 +164,8 @@ def test_execute_query_failed_no_retry(mocker: MockFixture, app_context: None) -
command = AlertCommand(report_schedule=mocker.Mock())
try:
with suppress(AlertQueryTimeout):
command.validate()
except AlertQueryTimeout:
pass
assert execute_query_mock.call_count == 1
@@ -189,10 +186,7 @@ def test_execute_query_failed_max_retries(
command = AlertCommand(report_schedule=mocker.Mock())
try:
with suppress(AlertQueryError):
command.validate()
except AlertQueryError:
pass
# Should match the value defined in superset_test_config.py
assert execute_query_mock.call_count == 3

View File

@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
from typing import Callable, ContextManager
import pytest
@@ -42,20 +43,16 @@ def force_async_run(allow_run_async: bool):
def non_async_example_db(app_context):
gen = force_async_run(False)
yield next(gen)
try:
with contextlib.suppress(StopIteration):
next(gen)
except StopIteration:
pass
@pytest.fixture
def async_example_db(app_context):
gen = force_async_run(True)
yield next(gen)
try:
with contextlib.suppress(StopIteration):
next(gen)
except StopIteration:
pass
@pytest.fixture