# Source code for ibmcloudsql.sql_magic

# ------------------------------------------------------------------------------
# Copyright IBM Corp. 2020
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
# flake8: noqa E203
import re
from functools import wraps

import sqlparse

try:
    from exceptions import UnsupportedStorageFormatException
except Exception:
    from .exceptions import UnsupportedStorageFormatException

def format_sql(sql_stmt):
    """Format an SQL string to ensure proper content for string comparison.

    COS URLs (``cos://...``) are temporarily replaced by ``{key_<n>}``
    placeholders so that ``sqlparse.format`` (which upper-cases keywords and
    lower-cases identifiers) cannot alter them; they are restored afterwards.

    Parameters
    ----------
    sql_stmt: str
        The SQL statement to normalize.

    Returns
    -------
    str
        The statement with upper-case keywords, lower-case identifiers,
        comments stripped and re-indented, with COS URLs left verbatim.
    """
    cos_url = re.compile(r"(cos)://[^\s]*")
    mapping = {}
    index = 0
    url = cos_url.search(sql_stmt)
    while url:
        key = "key_{num}".format(num=index)
        mapping[key] = url.group(0)
        sql_stmt = cos_url.sub("{" + key + "}", sql_stmt.rstrip(), count=1)
        url = cos_url.search(sql_stmt)
        index = index + 1
    sql_stmt = sqlparse.format(
        sql_stmt,
        keyword_case="upper",
        strip_comments=True,
        reindent=True,
        identifier_case="lower",
        strip_whitespace=True,
        indent_columns=False,
        space_around_operators=False,
        output_format="sql",
        truncate_strings=None,
        indent_tabs=False,
        comma_first=False,
        right_margin=None,
    )
    if mapping:
        # Restore only our own placeholders, in one pass. Using str.format()
        # here would choke on (or rewrite) any other braces in the SQL, such
        # as JSON or regex string literals.
        sql_stmt = re.sub(
            r"\{(key_[0-9]+)\}",
            lambda m: mapping.get(m.group(1), m.group(0)),
            sql_stmt,
        )
    return sql_stmt
class TimeSeriesTransformInput:
    """
    This class contains methods that support the transformation of
    user-friendly arguments of time-series functions into values that
    IBM Cloud SQL accepts.
    """

    # Regexes for TS_SEGMENT_BY_TIME(<column>, <window>, <step>) calls where at
    # least one of <window>/<step> is symbolic (e.g. ``hour``, ``15minute``,
    # ``PT1H``). They are applied in this order: (symbol, symbol),
    # (number, symbol), (symbol, number). The column name may contain letters,
    # digits and underscores.
    _TS_SEGMENT_PATTERNS = (
        re.compile(
            r"TS_SEGMENT_BY_TIME\s*\(([a-zA-Z0-9_]+)\s*,\s*"
            r"(?P<window>[0-9]*_?[a-zA-Z][a-zA-Z_0-9]+)\s*,\s*"
            r"(?P<step>[0-9]*_?[a-zA-Z][a-zA-Z_0-9]+)",
            re.MULTILINE,
        ),
        re.compile(
            r"TS_SEGMENT_BY_TIME\s*\(([a-zA-Z0-9_]+)\s*,\s*"
            r"(?P<window>[0-9]+)\s*,\s*"
            r"(?P<step>[0-9]*_?[a-zA-Z][a-zA-Z_0-9]+)",
            re.MULTILINE,
        ),
        re.compile(
            r"TS_SEGMENT_BY_TIME\s*\(([a-zA-Z0-9_]+)\s*,\s*"
            r"(?P<window>[0-9]*_?[a-zA-Z][a-zA-Z_0-9]+)\s*,\s*"
            r"(?P<step>[0-9]+)",
            re.MULTILINE,
        ),
    )
    # Splits a symbolic token such as ``15minute`` or ``2_hour`` into an
    # optional leading count and a unit word.
    _DURATION_TOKEN = re.compile(r"(?P<number>[0-9]*)_?(?P<unit>[a-zA-Z]+)")

    @classmethod
    def transform_sql(cls, f):
        """
        Decorator that normalizes time-series arguments before calling ``f``.

        The first positional argument of the wrapped method must have a
        ``_sql_stmt`` string attribute. Its TS_SEGMENT_BY_TIME calls are
        rewritten in place by :meth:`ts_segment_by_time`, then ``f`` is called
        and its result returned.
        """

        @wraps(f)
        def wrapped(*args, **kwargs):
            self = args[0]
            self._sql_stmt = cls.ts_segment_by_time(self._sql_stmt)
            return f(*args, **kwargs)

        return wrapped

    @classmethod
    def _duration_to_ms(cls, token):
        """Convert one window/step argument of TS_SEGMENT_BY_TIME to milliseconds.

        Parameters
        ----------
        token: str
            A plain integer (taken as milliseconds already); a unit word
            optionally prefixed by a count (``hour``, ``per_day``, ``2hour``,
            ``15minute``); or an ISO 8601 duration (``PT1H``).

        Returns
        -------
        int
            The duration in milliseconds.

        Raises
        ------
        ValueError
            If an ``Xminute`` count does not divide 60, or the token is neither
            a known unit nor a valid ISO 8601 duration.
        """
        if token.isdigit():
            return int(token)
        finding = cls._DURATION_TOKEN.match(token)
        if finding is None:
            raise ValueError("%s unsupported" % token)
        count = int(finding.group("number")) if finding.group("number") else 1
        unit = finding.group("unit")
        if "minute" in unit:
            # Only minute counts that evenly divide an hour are accepted.
            if count == 0 or 60 % count != 0:
                raise ValueError(
                    "%s unsupported: the minute count must divide 60" % token
                )
            minutes = count
        elif token in ("per_hour", "hour") or "hour" in unit:
            minutes = 60 * count
        elif token in ("per_day", "day") or "day" in unit:
            minutes = 60 * 24 * count
        elif token in ("per_week", "week") or "week" in unit:
            minutes = 60 * 24 * 7 * count
        else:
            import isodate  # optional dependency, only needed for ISO 8601 input

            try:
                duration = isodate.parse_duration(token)
            except isodate.ISO8601Error:
                raise ValueError("%s unsupported" % token) from None
            # Convert straight to milliseconds so that sub-minute durations
            # such as PT30S are not truncated to whole minutes.
            return round(duration.total_seconds() * 1000)
        return minutes * 60 * 1000

    @classmethod
    def ts_segment_by_time(cls, sql_stmt):
        """
        Revise the arguments of TS_SEGMENT_BY_TIME to comply with IBM Cloud SQL.

        Notes
        -----
        The TS_SEGMENT_BY_TIME function supported by IBM Cloud SQL only accepts
        numeric window/step values, which is not user-friendly for units like
        hours, days or minutes. SQLBuilder allows constructing the SQL query
        string using the values below. Every occurrence of the function name,
        in any letter case, is upper-cased.

        1. values: `per_hour`, `hour`, `per_day`, `day`, `per_week`, `week`,
           `minute`, `Xminute` (X must be a divisor of 60), optionally prefixed
           by a count, e.g. `2hour`

        .. code-block:: console

            ts_segment_by_time(ts, per_hour, per_hour)
            ts_segment_by_time(ts, hour, hour)

        2. or ISO 8601 durations: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W

        Examples
        --------
        .. code-block:: python

            ts_segment_by_time(ts, PT1H, PT1H)

        into

        .. code-block:: python

            ts_segment_by_time(ts, 3600000, 3600000)

        as ts_segment_by_time operates at the millisecond level,
        hour = 60*60*1000 milliseconds.

        Raises
        ------
        ValueError
            If a symbolic window/step argument cannot be converted.
        """
        sql_stmt = re.sub(
            r"ts_segment_by_time", "TS_SEGMENT_BY_TIME", sql_stmt, flags=re.IGNORECASE
        )
        for pattern in cls._TS_SEGMENT_PATTERNS:
            match = pattern.search(sql_stmt)
            while match:
                window_start, window_end = match.span("window")
                step_start, step_end = match.span("step")
                window_ms = cls._duration_to_ms(match.group("window"))
                step_ms = cls._duration_to_ms(match.group("step"))
                sql_stmt = (
                    sql_stmt[:window_start]
                    + str(window_ms)
                    + sql_stmt[window_end:step_start]
                    + str(step_ms)
                    + sql_stmt[step_end:]
                )
                # Both arguments are now numeric, so this call no longer
                # matches the pattern; look for the next occurrence.
                match = pattern.search(sql_stmt)
        return sql_stmt
class TimeSeriesSchema:
    """
    Keep track of the columns that matter for time-series handling.

    At present this only records the names of columns whose values are
    UNIX-time values.
    """

    def __init__(self):
        """Start with an empty list of UNIX-time columns.

        Reference values, read as milliseconds since the epoch:
        99999999999 (11 digits) is Sat Mar 03 1973 09:46:39 UTC,
        999999999999 (12 digits) is Sun Sep 09 2001 01:46:39 UTC,
        9999999999999 (13 digits) is Sat Nov 20 2286 17:46:39 UTC,
        100000000000000 (15 digits) is Wed Nov 16 5138 09:46:40 UTC.
        """
        unixtime_columns = []
        self._unixtime_columns = unixtime_columns

    def _get_columns_in_unixtime(self):
        return self._unixtime_columns

    def _set_columns_in_unixtime(self, column_list):
        self._unixtime_columns = column_list

    columns_in_unixtime = property(
        _get_columns_in_unixtime,
        _set_columns_in_unixtime,
        doc="Names of the columns whose values are in UNIX timestamp format.",
    )
class SQLBuilder(TimeSeriesSchema):
    """
    The class supports constructing a full SQL query statement.

    Each ``*_`` method appends one clause to the statement under construction
    and returns ``self``, so calls can be chained, e.g.
    ``SQLBuilder().select_("*").from_table_("t").where_("x > 1").get_sql()``.
    """

    def __init__(self):
        super().__init__()
        # contain the sql string that we can evoke using run() or submit()
        self._sql_stmt = ""
        self.supported_format_types = ["PARQUET", "CSV", "JSON"]
        self.supported_join_types = [
            "INNER",
            "CROSS",
            "OUTER",
            "LEFT",
            "LEFT OUTER",
            "LEFT SEMI",
            "RIGHT",
            "RIGHT OUTER",
            "FULL",
            "FULL OUTER",
            "ANTI",
            "LEFT ANTI",
        ]
        self._reset_builder_state()

    def _reset_builder_state(self):
        """Clear every clause-tracking flag and the current virtual table name."""
        # Clause presence is tracked with flags instead of searching the SQL
        # text, because keywords such as WHERE / GROUP BY / PARTITION may
        # legitimately appear inside WITH sub-queries or in column names.
        self._has_stored_location = False
        self._has_with_clause = False
        self._has_select_clause = False
        self._has_from_clause = False
        self._has_where_clause = False
        self._has_group_by_clause = False
        self._has_order_by_clause = False
        self._has_partition_clause = False
        self._current_vtable_name = ""

    def print_sql(self):
        """print() sql string"""
        # Print directly; no module-level print_sql helper exists in this
        # module to delegate to.
        print(self.get_sql())

    @TimeSeriesTransformInput.transform_sql
    def get_sql(self):
        """Return the current sql string, normalized by ``format_sql``.

        The decorator first rewrites TS_SEGMENT_BY_TIME arguments in place.
        """
        return format_sql(self._sql_stmt)

    def reset_(self):
        """Reset the builder and return the sql string built so far."""
        res = self.get_sql()
        self._sql_stmt = ""
        self._reset_builder_state()
        return res

    def with_(self, table_name, sql_stmt):
        """WITH <table> AS (<sql>) [, <table> AS (<sql>)]

        The first call opens the WITH clause; later calls append further named
        sub-queries. ``table_name`` becomes the current virtual table name.
        """
        if self._has_with_clause:
            self._sql_stmt += ", " + table_name + " AS (" + sql_stmt + "\n) "
        else:
            self._sql_stmt += " WITH " + table_name + " AS (" + sql_stmt + "\n) "
            self._has_with_clause = True
        self._current_vtable_name = table_name
        return self

    def select_(self, columns):
        """
        SELECT <columns>

        Parameters
        ---------------
        columns: str
            a string representing a comma-separated list of columns

        Raises
        ------
        ValueError
            If a SELECT clause has already been added.
        """
        if self._has_select_clause:
            raise ValueError("A SELECT clause has already been defined")
        self._sql_stmt += "SELECT " + columns
        self._has_select_clause = True
        return self

    def from_table_(self, table, alias=None):
        """
        FROM <table> [alias] [, <table> [alias]]
        """
        self._sql_stmt += ", " if self._has_from_clause else " FROM "
        self._sql_stmt += table
        if alias:
            # NOTE: it's ok for not using 'AS'
            self._sql_stmt += " " + alias.strip()
        self._has_from_clause = True
        return self

    def from_cos_(self, cos_url, format_type="parquet", alias=None, delimiter=None):
        """
        FROM <cos-url> STORED AS <format_type> [FIELDS TERMINATED BY '<delimiter>'] [alias]
        [, <cos-url> STORED AS <format_type> ... [alias]]

        Parameters
        ----------
        cos_url: str
            The COS URL of the data to read.
        format_type: str, optional
            Storage format of the data; emitted lower-cased (default "parquet").
        alias: str, optional
            Alias for the data source.
        delimiter: str, optional
            Field delimiter; only allowed when ``format_type`` is csv or textfile.

        Raises
        ------
        ValueError
            If ``delimiter`` is given for a format other than csv/textfile.
        """
        format_type = format_type.strip().lower()
        # Validate before touching the statement so a failure leaves it intact.
        if delimiter is not None and format_type not in ("csv", "textfile"):
            raise ValueError(
                "'delimiter' is only supported for the csv and textfile formats"
            )
        self._sql_stmt += ", " if self._has_from_clause else " FROM "
        self._sql_stmt += cos_url + " STORED AS " + format_type
        if delimiter is not None:
            self._sql_stmt += " FIELDS TERMINATED BY '{}'".format(delimiter)
        if alias:
            self._sql_stmt += " " + alias.strip()
        self._has_from_clause = True
        return self

    def from_view_(self, sql_stmt):
        """
        FROM (<sql>) [, (<sql>)]
        """
        self._sql_stmt += ", " if self._has_from_clause else " FROM "
        self._sql_stmt += "( \n" + sql_stmt + " \n) "
        self._has_from_clause = True
        return self

    def where_(self, condition):
        """
        WHERE <condition> [AND <condition>]

        Repeated calls combine the conditions with AND. Conditions are inserted
        verbatim, so wrap any condition that contains OR in parentheses.
        """
        if self._has_where_clause:
            self._sql_stmt += " AND " + condition
        else:
            self._sql_stmt += " WHERE " + condition
            self._has_where_clause = True
        # NOTE(review): kept from the original — after WHERE, a later from_*_
        # call starts a new FROM keyword; confirm this is intended.
        self._has_from_clause = False
        return self

    def join_cos_(self, cos_url, condition, typ="inner", alias=None):
        """
        [typ] JOIN <cos-url> [AS alias] ON <condition>
        """
        return self.join_table_(cos_url, condition, typ=typ, alias=alias)

    def join_table_(self, table, condition, typ="inner", alias=None):
        """
        [typ] JOIN <table> [AS alias] ON <condition>

        NOTE: [typ] is a value (case-insensitive) in the list below

        .. code-block:: python

            ["INNER", "CROSS", "OUTER", "LEFT", "LEFT OUTER", "LEFT SEMI",
            "RIGHT", "RIGHT OUTER", "FULL", "FULL OUTER", "ANTI", "LEFT ANTI"]

        Raises
        ------
        ValueError
            If ``typ`` is not one of the supported join types.
        """
        # Trim and collapse inner whitespace so "left   outer" is accepted.
        join_type = re.sub(r"\s+", " ", typ.strip())
        if join_type.upper() not in self.supported_join_types:
            msg = "Wrong 'typ', use a value in " + str(self.supported_join_types)
            raise ValueError(msg)
        self._sql_stmt += " " + join_type + " JOIN " + table
        if alias is not None:
            self._sql_stmt += " AS " + alias
        self._sql_stmt += " ON " + condition
        return self

    def order_by_(self, columns):
        """
        ORDER BY <columns> [, <columns>]

        Parameters
        ---------------
        columns: str
            a string representing a comma-separated list of columns
        """
        if self._has_order_by_clause:
            self._sql_stmt += ", " + columns
        else:
            self._sql_stmt += " ORDER BY " + columns
            self._has_order_by_clause = True
        return self

    def group_by_(self, columns):
        """
        GROUP BY <columns> [, <columns>]
        """
        if self._has_group_by_clause:
            self._sql_stmt += ", " + columns
        else:
            self._sql_stmt += " GROUP BY " + columns
            self._has_group_by_clause = True
        return self

    def store_at_(self, cos_url, format_type="CSV"):
        """
        INTO <cos-url> STORED AS <type>

        Raises
        ------
        ValueError
            If an INTO clause has already been added.
        UnsupportedStorageFormatException
            If ``format_type`` is not in ``supported_format_types``.
        """
        if self._has_stored_location:
            raise ValueError("An INTO clause has already been defined")
        storage_format = format_type.strip().upper()
        # Validate before touching the statement so a failure leaves it intact.
        if storage_format not in self.supported_format_types:
            raise UnsupportedStorageFormatException(
                "ERROR: unsupported type {}".format(format_type)
            )
        self._sql_stmt += " INTO " + cos_url + " STORED AS " + storage_format
        self._has_stored_location = True
        return self

    def _ensure_no_partition_clause(self):
        """Raise ValueError if a PARTITIONED clause was already added."""
        if self._has_partition_clause:
            raise ValueError("A PARTITIONED clause has already been defined")

    def partition_by_(self, columns):
        """
        PARTITIONED BY <columns>
        """
        self._ensure_no_partition_clause()
        self._sql_stmt += " PARTITIONED BY " + str(columns)
        self._has_partition_clause = True
        return self

    def partition_objects_(self, num_objects):
        """
        PARTITIONED INTO <num> OBJECTS
        """
        self._ensure_no_partition_clause()
        self._sql_stmt += " PARTITIONED INTO " + str(num_objects) + " OBJECTS"
        self._has_partition_clause = True
        return self

    def partition_rows_(self, num_rows):
        """
        PARTITIONED INTO <num> ROWS
        """
        self._ensure_no_partition_clause()
        self._sql_stmt += " PARTITIONED INTO " + str(num_rows) + " ROWS"
        self._has_partition_clause = True
        return self

    @TimeSeriesTransformInput.transform_sql
    def format_(self):
        """Perform the string replacements needed so that the final result is
        a SQL statement accepted by IBM Spark SQL (done by the decorator)."""
        return self