Files
superset2/superset/db_engine_specs/hive.py
Ville Brofeldt 1c4092c61c Add csv upload support for BigQuery (#7756)
* Add extra_require for bigquery to setup.py

* Refactor df_to_db and add df upload capability for BigQuery

* Fix unit tests and clarify kwarg logic

* Fix flake8 errors

* Add minimum versions for bigquery dependencies

* wrap to_gbq in try-catch block and raise error if pandas-gbq is missing

* Fix linting error and make error more generic
2019-06-23 14:20:09 -07:00

369 lines
14 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
import logging
import os
import re
import time
from typing import List
from urllib import parse
from sqlalchemy import Column
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql.expression import ColumnClause
from werkzeug.utils import secure_filename
from superset import app, conf
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.utils import core as utils
# Flask application configuration, looked up by key in this module.
config = app.config

# Local alias so the module can refer to ``QueryStatus`` without the
# ``utils.`` prefix.
QueryStatus = utils.QueryStatus

# Settings read once at import time from ``conf``; ``.get`` yields ``None``
# when a key is absent.
hive_poll_interval = conf.get('HIVE_POLL_INTERVAL')
tracking_url_trans = conf.get('TRACKING_URL_TRANSFORMER')
class HiveEngineSpec(PrestoEngineSpec):
    """Reuses PrestoEngineSpec functionality."""

    engine = 'hive'
    max_column_name_length = 767

    # Scoping regex at class level to avoid recompiling
    # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
    jobs_stats_r = re.compile(
        r'.*INFO.*Total jobs = (?P<max_jobs>[0-9]+)')
    # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
    launching_job_r = re.compile(
        r'.*INFO.*Launching Job (?P<job_number>[0-9]+) out of '
        r'(?P<max_jobs>[0-9]+)')
    # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
    # map = 0%, reduce = 0%
    stage_progress_r = re.compile(
        r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*'
        r'map = (?P<map_progress>[0-9]+)%.*'
        r'reduce = (?P<reduce_progress>[0-9]+)%.*')

    @classmethod
    def patch(cls):
        """Swap pyhive's Thrift bindings and log fetching for patched ones.

        Rebinds ``hive.TCLIService``, ``hive.constants`` and ``hive.ttypes``
        to the modules of the top-level ``TCLIService`` package, and
        replaces ``hive.Cursor.fetch_logs`` with the implementation from
        ``superset.db_engines.hive``.
        """
        from pyhive import hive  # pylint: disable=no-name-in-module
        from superset.db_engines import hive as patched_hive
        from TCLIService import (
            constants as patched_constants,
            ttypes as patched_ttypes,
            TCLIService as patched_TCLIService)

        hive.TCLIService = patched_TCLIService
        hive.constants = patched_constants
        hive.ttypes = patched_ttypes
        hive.Cursor.fetch_logs = patched_hive.fetch_logs

    @classmethod
    def get_all_datasource_names(cls, db, datasource_type: str) \
            -> List[utils.DatasourceName]:
        """Return datasource names via ``BaseEngineSpec``.

        The call goes straight to ``BaseEngineSpec`` instead of through
        ``super()``; presumably this bypasses a Presto-specific override
        (NOTE(review): confirm against ``PrestoEngineSpec``).
        """
        return BaseEngineSpec.get_all_datasource_names(db, datasource_type)

    @classmethod
    def fetch_data(cls, cursor, limit):
        """Fetch result rows from ``cursor``.

        :param cursor: Cursor exposing ``poll()``
        :param limit: Row limit forwarded to the parent implementation
        :return: Rows from the parent implementation, or an empty list when
            pyhive raises ``ProgrammingError``
        :raises Exception: If the polled operation is in the error state
        """
        # Import the exception class explicitly: a bare ``import pyhive``
        # does not guarantee that the ``pyhive.exc`` submodule is loaded.
        from pyhive.exc import ProgrammingError
        from TCLIService import ttypes

        state = cursor.poll()
        if state.operationState == ttypes.TOperationState.ERROR_STATE:
            raise Exception('Query error', state.errorMessage)
        try:
            return super(HiveEngineSpec, cls).fetch_data(cursor, limit)
        except ProgrammingError:
            return []

    @classmethod
    def create_table_from_csv(cls, form, table):
        """Uploads a csv file and creates a superset datasource in Hive.

        The CSV previously saved under ``UPLOAD_FOLDER`` is inspected with
        ``tableschema`` to infer column types, uploaded to S3 with ``boto3``
        and exposed through a ``CREATE TABLE ... LOCATION`` statement.

        :param form: Upload form; ``name``, ``schema``, ``csv_file`` and
            ``con`` are read from it
        :param table: Not used by this implementation
        :raises Exception: If no upload bucket is configured or the table
            namespace is specified in a disallowed way
        """
        def convert_to_hive_type(col_type):
            """maps tableschema's types to hive types"""
            tableschema_to_hive_types = {
                'boolean': 'BOOLEAN',
                'integer': 'INT',
                'number': 'DOUBLE',
                'string': 'STRING',
            }
            # Anything unrecognised is stored as a string.
            return tableschema_to_hive_types.get(col_type, 'STRING')

        bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']
        if not bucket_path:
            logging.info('No upload bucket specified')
            raise Exception(
                'No upload bucket specified. You can specify one in the '
                'config file.')

        table_name = form.name.data
        schema_name = form.schema.data

        forced_namespace = config.get('UPLOADED_CSV_HIVE_NAMESPACE')
        if forced_namespace:
            if '.' in table_name or schema_name:
                # Report the namespace that is actually enforced; the
                # message used to look up an unrelated config key.
                raise Exception(
                    "You can't specify a namespace. "
                    'All tables will be uploaded to the `{}` namespace'.format(
                        forced_namespace))
            full_table_name = '{}.{}'.format(forced_namespace, table_name)
        else:
            if '.' in table_name and schema_name:
                raise Exception(
                    "You can't specify a namespace both in the name of the "
                    'table and in the schema field. Please remove one')
            full_table_name = '{}.{}'.format(
                schema_name, table_name) if schema_name else table_name

        filename = form.csv_file.data.filename
        upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
        # Plain concatenation is kept on purpose so the path stays the same
        # string as before for any ``UPLOAD_FOLDER`` value.
        upload_path = config['UPLOAD_FOLDER'] + secure_filename(filename)

        # Optional dependency
        from tableschema import Table  # pylint: disable=import-error
        hive_table_schema = Table(upload_path).infer()
        schema_definition = ', '.join(
            '`{}` {}'.format(
                column_info['name'],
                convert_to_hive_type(column_info['type']))
            for column_info in hive_table_schema['fields'])

        # Optional dependency
        import boto3  # pylint: disable=import-error
        s3 = boto3.client('s3')
        location = os.path.join(
            's3a://', bucket_path, upload_prefix, table_name)
        s3.upload_file(
            upload_path, bucket_path,
            os.path.join(upload_prefix, table_name, filename))

        # NOTE(review): table and column names are interpolated into the
        # statement unescaped; they come from the upload form and the CSV
        # header.
        sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
            TEXTFILE LOCATION '{location}'
            tblproperties ('skip.header.line.count'='1')"""
        logging.info(form.con.data)
        engine = create_engine(form.con.data.sqlalchemy_uri_decrypted)
        engine.execute(sql)

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        """Render ``dttm`` as a SQL literal for ``target_type``.

        :param target_type: Column type name, compared case-insensitively
        :param dttm: Datetime to render
        :return: ``CAST(... AS DATE)`` or ``CAST(... AS TIMESTAMP)`` for
            those types, otherwise a quoted ``YYYY-MM-DD HH:MM:SS`` string
        """
        tt = target_type.upper()
        if tt == 'DATE':
            return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
        if tt == 'TIMESTAMP':
            return "CAST('{}' AS TIMESTAMP)".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        """Point ``uri`` at ``selected_schema`` when one is given.

        The schema is percent-encoded with no safe characters and assigned
        to ``uri.database``; the same ``uri`` object is returned.
        """
        if selected_schema:
            uri.database = parse.quote(selected_schema, safe='')
        return uri

    @classmethod
    def extract_error_message(cls, e):
        """Return the ``errorMessage="..."`` payload of ``str(e)`` if present.

        The closing quote is the first ``"`` that is not preceded by a
        backslash. Without a match the whole string is returned.
        """
        msg = str(e)
        match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
        if match:
            msg = match.group(1)
        return msg

    @classmethod
    def progress(cls, log_lines):
        """Estimate overall query progress, in percent, from Hive log lines.

        :param log_lines: Iterable of log lines
        :return: Integer percentage that combines the number of the job
            currently running with the mean progress of its stages
        """
        total_jobs = 1  # assuming there's at least 1 job
        current_job = 1
        stages = {}
        for line in log_lines:
            match = cls.jobs_stats_r.match(line)
            if match:
                # ``or 1`` guards the divisions below against "Total jobs = 0"
                total_jobs = int(match.group('max_jobs')) or 1
            match = cls.launching_job_r.match(line)
            if match:
                current_job = int(match.group('job_number'))
                total_jobs = int(match.group('max_jobs')) or 1
                # Stage progress seen so far belongs to the previous job.
                stages = {}
            match = cls.stage_progress_r.match(line)
            if match:
                stage_number = int(match.group('stage_number'))
                map_progress = int(match.group('map_progress'))
                reduce_progress = int(match.group('reduce_progress'))
                stages[stage_number] = (map_progress + reduce_progress) / 2
        logging.info(
            'Progress detail: %s, current job %s, total jobs: %s',
            stages, current_job, total_jobs)

        stage_progress = sum(stages.values()) / len(stages) if stages else 0
        progress = (
            100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
        )
        return int(progress)

    @classmethod
    def get_tracking_url(cls, log_lines):
        """Return the text following ``'Tracking URL = '`` in the first log
        line that contains it, or ``None`` when no line does."""
        lkp = 'Tracking URL = '
        for line in log_lines:
            if lkp in line:
                return line.split(lkp)[1]
        return None

    @classmethod
    def handle_cursor(cls, cursor, query, session):
        """Updates progress information

        Polls ``cursor`` until the operation leaves the initialized/running
        states or the query is marked as stopped. On each iteration the
        query row is reloaded, progress and the tracking URL are updated
        from the fetched logs, and changes are committed.

        :param cursor: Cursor exposing ``poll``, ``cancel``, ``fetch_logs``
        :param query: Query model instance; reloaded through ``session``
        :param session: Database session used to reload and commit
        """
        from pyhive import hive  # pylint: disable=no-name-in-module
        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        while polled.operationState in unfinished_states:
            # Reload the row so a status change made elsewhere is seen.
            query = session.query(type(query)).filter_by(id=query.id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ''
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logging.info('Progress total: %s', progress)
                needs_commit = False
                # Never move the stored progress backwards.
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split('/')[-2]
                        logging.info(
                            'Found the tracking url: %s', tracking_url)
                        tracking_url = tracking_url_trans(tracking_url)
                        logging.info(
                            'Transformation applied: %s', tracking_url)
                        query.tracking_url = tracking_url
                        logging.info('Job id: %s', job_id)
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for job id before logging things out
                    # this allows for prefixing all log lines and becoming
                    # searchable in something like Kibana
                    for log_line in log_lines[last_log_line:]:
                        logging.info('[%s] %s', job_id, log_line)
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(hive_poll_interval)
            polled = cursor.poll()

    @classmethod
    def get_columns(
            cls, inspector: Inspector, table_name: str, schema: str,
    ) -> List[dict]:
        """Return the inspector's column metadata for ``table_name``."""
        return inspector.get_columns(table_name, schema)

    @classmethod
    def where_latest_partition(
            cls, table_name, schema, database, qry, columns=None):
        """Restrict ``qry`` to the latest partition of ``table_name``.

        :param columns: Optional list of column dicts; the filter is only
            applied when the partition column appears in it by ``name``
        :return: ``qry`` with an added ``WHERE`` clause, or ``False`` when
            no partition filter could be applied
        """
        try:
            col_name, value = cls.latest_partition(
                table_name, schema, database, show_first=True)
        except Exception:
            # table is not partitioned
            return False
        if value is not None:
            # ``columns`` defaults to None; treat that as "no columns"
            # instead of failing on iteration.
            for c in columns or []:
                if c.get('name') == col_name:
                    return qry.where(Column(col_name) == value)
        return False

    @classmethod
    def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
        """Build column clauses using ``BaseEngineSpec._get_fields``."""
        return BaseEngineSpec._get_fields(cols)

    @classmethod
    def latest_sub_partition(cls, table_name, schema, database, **kwargs):
        """Not implemented for Hive; always returns ``None``."""
        # TODO(bogdan): implement
        pass

    @classmethod
    def _latest_partition_from_df(cls, df):
        """Hive partitions look like ds={partition name}

        :param df: DataFrame whose first column holds partition strings
        :return: The part after ``=`` of the maximum value in the first
            column, or ``None`` for an empty frame
        """
        if not df.empty:
            # ``.iloc`` replaces ``.ix``, which pandas 1.0 removed.
            return df.iloc[:, 0].max().split('=')[1]
        return None

    @classmethod
    def _partition_query(
            cls, table_name, limit=0, order_by=None, filters=None):
        """Return the ``SHOW PARTITIONS`` statement for ``table_name``.

        ``limit``, ``order_by`` and ``filters`` are accepted but not used.
        """
        return f'SHOW PARTITIONS {table_name}'

    @classmethod
    def select_star(cls, my_db, table_name: str, engine: Engine,
                    schema: str = None, limit: int = 100,
                    show_cols: bool = False, indent: bool = True,
                    latest_partition: bool = True,
                    cols: List[dict] = None) -> str:
        """Build a ``SELECT *`` statement via ``BaseEngineSpec.select_star``.

        ``cols`` defaults to ``None`` rather than a shared mutable list; an
        empty list is passed on when it is omitted.
        """
        if cols is None:
            cols = []
        return BaseEngineSpec.select_star(
            my_db, table_name, engine, schema, limit,
            show_cols, indent, latest_partition, cols)

    @classmethod
    def modify_url_for_impersonation(cls, url, impersonate_user, username):
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.
        :param url: SQLAlchemy URL object
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        """
        # Do nothing in the URL object since instead this should modify
        # the configuration dictionary. See get_configuration_for_impersonation
        pass

    @classmethod
    def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
        """
        Return a configuration dictionary that can be merged with other configs
        that can set the correct properties for impersonating users
        :param uri: URI string
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        :return: Dictionary with configs required for impersonation
        """
        configuration = {}
        url = make_url(uri)
        backend_name = url.get_backend_name()

        # Must be Hive connection, enable impersonation, and set param auth=LDAP|KERBEROS
        if (backend_name == 'hive' and 'auth' in url.query.keys() and
                impersonate_user is True and username is not None):
            configuration['hive.server2.proxy.user'] = username
        return configuration

    @staticmethod
    def execute(cursor, query, async_=False):
        """Run ``query`` on ``cursor``, forwarding ``async_`` as ``async``.

        ``async`` is a reserved word, so it is passed through a dict.
        """
        kwargs = {'async': async_}
        cursor.execute(query, **kwargs)