fix: adds the ability to disallow SQL functions per engine (#28639)

This commit is contained in:
Daniel Vaz Gaspar
2024-05-29 10:51:28 +01:00
committed by GitHub
parent 6575cacc5d
commit 5dfbab5424
7 changed files with 119 additions and 15 deletions

View File

@@ -39,6 +39,7 @@ from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
Parenthesis,
@@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
return cte, remainder
def check_sql_functions_exist(
sql: str, function_list: set[str], engine: str | None = None
) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
:param sql: The SQL statement
:param function_list: The list of functions to search for
:param engine: The engine to use for parsing the SQL statement
"""
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
@@ -743,6 +757,34 @@ class ParsedQuery:
self._tables = self._extract_tables_from_sql()
return self._tables
def _check_functions_exist_in_token(
self, token: Token, functions: set[str]
) -> bool:
if (
isinstance(token, Function)
and token.get_name() is not None
and token.get_name().lower() in functions
):
return True
if hasattr(token, "tokens"):
for inner_token in token.tokens:
if self._check_functions_exist_in_token(inner_token, functions):
return True
return False
def check_functions_exist(self, functions: set[str]) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
:param functions: A set of functions to search for
:return: True if the statement contains any of the specified functions
"""
for statement in self._parsed:
for token in statement.tokens:
if self._check_functions_exist_in_token(token, functions):
return True
return False
def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.