fix: is_select (#25189)

(cherry picked from commit 2f68010729)
This commit is contained in:
Beto Dealmeida
2023-09-06 11:54:25 -07:00
committed by Michael S. Molina
parent d8c72b86bc
commit 2ac03c3bcc
2 changed files with 47 additions and 34 deletions

View File

@@ -245,46 +245,52 @@ class ParsedQuery:
# make sure we strip comments; prevents a bug with comments in the CTE
parsed = sqlparse.parse(self.strip_comments())
# Check if this is a CTE
if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE:
if sqloxide_parse is not None:
try:
if not self._check_cte_is_select(
sqloxide_parse(self.strip_comments(), dialect="ansi")
):
return False
except ValueError:
# sqloxide was not able to parse the query, so let's continue with
# sqlparse
pass
inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or []
# Check if the inner CTE is not a SELECT
if any(token.ttype == DDL for token in inner_cte) or any(
for statement in parsed:
# Check if this is a CTE
if statement.is_group and statement[0].ttype == Keyword.CTE:
if sqloxide_parse is not None:
try:
if not self._check_cte_is_select(
sqloxide_parse(self.strip_comments(), dialect="ansi")
):
return False
except ValueError:
# sqloxide was not able to parse the query, so let's continue with
# sqlparse
pass
inner_cte = self.get_inner_cte_expression(statement.tokens) or []
# Check if the inner CTE is not a SELECT
if any(token.ttype == DDL for token in inner_cte) or any(
token.ttype == DML and token.normalized != "SELECT"
for token in inner_cte
):
return False
if statement.get_type() == "SELECT":
continue
if statement.get_type() != "UNKNOWN":
return False
# for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
# and no DDL is allowed
if any(token.ttype == DDL for token in statement) or any(
token.ttype == DML and token.normalized != "SELECT"
for token in inner_cte
for token in statement
):
return False
if parsed[0].get_type() == "SELECT":
return True
# return false on `EXPLAIN`, `SET`, `SHOW`, etc.
if statement[0].ttype == Keyword:
return False
if parsed[0].get_type() != "UNKNOWN":
return False
if not any(
token.ttype == DML and token.normalized == "SELECT"
for token in statement
):
return False
# for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
# and no DDL is allowed
if any(token.ttype == DDL for token in parsed[0]) or any(
token.ttype == DML and token.normalized != "SELECT" for token in parsed[0]
):
return False
# return false on `EXPLAIN`, `SET`, `SHOW`, etc.
if parsed[0][0].ttype == Keyword:
return False
return any(
token.ttype == DML and token.normalized == "SELECT" for token in parsed[0]
)
return True
def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
for token in tokens: