[fix] SQL parsing of table names (#7490)

This commit is contained in:
John Bodley
2019-06-03 11:07:57 -07:00
committed by GitHub
parent 78c1674dc7
commit 45b41aadcc
3 changed files with 73 additions and 15 deletions

View File

@@ -16,10 +16,12 @@
# under the License.
# pylint: disable=C,R,W
import logging
from typing import Optional
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import Keyword, Name
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
ON_KEYWORD = 'ON'
@@ -75,11 +77,34 @@ class ParsedQuery(object):
return statements
@staticmethod
def __get_full_name(tlist: TokenList):
    """
    Return the name of the table described by the given tokens.

    If there are more than two tokens and the second one's value is a dot,
    the values of the first and third tokens are combined as ``first.third``.
    Otherwise the result of ``tlist.get_real_name()`` is returned.

    :param tlist: The SQL tokens
    :returns: The dotted name, or whatever ``get_real_name()`` yields
    """
    parts = tlist.tokens
    # Guard clause: anything that is not "<a>.<b>..." defers to sqlparse.
    is_dotted = len(parts) > 2 and parts[1].value == '.'
    if not is_dotted:
        return tlist.get_real_name()
    first, third = parts[0], parts[2]
    return '{}.{}'.format(first.value, third.value)
def __get_full_name(tlist: TokenList) -> Optional[str]:
    """
    Return the full unquoted table name if valid, i.e., conforms to the following
    [[cluster.]schema.]table construct.

    The candidate tokens must alternate between name/string tokens and ``.``
    punctuation, giving one, two or three dot-separated parts. Any other shape
    is rejected.

    :param tlist: The SQL tokens
    :returns: The valid full table name, or None if the tokens do not conform
    """
    tokens = tlist.tokens

    # Strip the alias if present: only the tokens before the first whitespace
    # token are considered part of the table name.
    if tlist.has_alias():
        ws_idx, _ = tlist.token_next_by(t=Whitespace)

        # sqlparse reports "no match" with an index of None rather than -1, so
        # test for None explicitly instead of relying on `tokens[:None]`
        # happening to return the whole list.
        if ws_idx is not None:
            tokens = tokens[:ws_idx]

    names = tokens[0::2]  # even positions: cluster / schema / table parts
    separators = tokens[1::2]  # odd positions: must all be '.'

    if (
        len(tokens) in (1, 3, 5) and
        all(imt(token, t=[Name, String]) for token in names) and
        all(imt(token, m=(Punctuation, '.')) for token in separators)
    ):
        return '.'.join(remove_quotes(token.value) for token in names)

    return None
@staticmethod
def __is_identifier(token: Token):