From 2a72de4cd6a9acc40f66ac16557c4eac9094d048 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Tue, 31 Dec 2024 12:20:22 -0600 Subject: [PATCH] bug fixes: registry and handling cli parameters as well as adding warehousing --- bin/transport | 114 ++++++++++++++++++------ setup.py | 2 +- transport/__init__.py | 91 +++++++++++++------ transport/iowrapper.py | 20 +++-- transport/plugins/__init__.py | 109 ++++++++++++++--------- transport/providers/__init__.py | 8 +- transport/registry.py | 2 + transport/warehouse/__init__.py | 7 ++ transport/warehouse/drill.py | 55 ++++++++++++ transport/warehouse/iceberg.py | 151 ++++++++++++++++++++++++++++++++ 10 files changed, 458 insertions(+), 101 deletions(-) create mode 100644 transport/warehouse/__init__.py create mode 100644 transport/warehouse/drill.py create mode 100644 transport/warehouse/iceberg.py diff --git a/bin/transport b/bin/transport index 4053c4e..d2072f7 100755 --- a/bin/transport +++ b/bin/transport @@ -24,19 +24,25 @@ from multiprocessing import Process import os import transport -from transport import etl +# from transport import etl +from transport.iowrapper import IETL # from transport import providers import typer from typing_extensions import Annotated from typing import Optional import time from termcolor import colored +from enum import Enum +from rich import print app = typer.Typer() +app_x = typer.Typer() +app_i = typer.Typer() +app_r = typer.Typer() REGISTRY_PATH=os.sep.join([os.environ['HOME'],'.data-transport']) REGISTRY_FILE= 'transport-registry.json' -CHECK_MARK = ' '.join(['[',colored(u'\u2713', 'green'),']']) -TIMES_MARK= ' '.join(['[',colored(u'\u2717','red'),']']) +CHECK_MARK = '[ [green]\u2713[/green] ]' #' '.join(['[',colored(u'\u2713', 'green'),']']) +TIMES_MARK= '[ [red]\u2717[/red] ]' #' '.join(['[',colored(u'\u2717','red'),']']) # @app.command() def help() : print (__doc__) @@ -44,10 +50,15 @@ def wait(jobs): while jobs : jobs = [thread for thread in jobs if thread.is_alive()] time.sleep(1) +def wait (jobs): + while jobs : + jobs = [pthread for pthread in jobs if pthread.is_alive()] -@app.command(name="apply") +@app.command(name="etl") def apply (path:Annotated[str,typer.Argument(help="path of the configuration file")], - index:int = typer.Option(default= None, help="index of the item of interest, otherwise everything in the file will be processed")): + index:int = typer.Option(default= None, help="index of the item of interest, otherwise everything in the file will be processed"), + batch:int = typer.Option(default=5, help="The number of parallel processes to run at once") + ): """ This function applies data transport ETL feature to read data from one source to write it one or several others """ @@ -56,23 +67,34 @@ def apply (path:Annotated[str,typer.Argument(help="path of the configuration fil file = open(path) _config = json.loads (file.read() ) file.close() - if index : + if index is not None: _config = [_config[ int(index)]] - jobs = [] + jobs = [] for _args in _config : - pthread = etl.instance(**_args) #-- automatically starts the process + # pthread = etl.instance(**_args) #-- automatically starts the process + def bootup (): + _worker = IETL(**_args) + _worker.run() + pthread = Process(target=bootup) + pthread.start() jobs.append(pthread) + if len(jobs) == batch : + wait(jobs) + jobs = [] + + if jobs : + wait (jobs) # - # @TODO: Log the number of processes started and estimated time - while jobs : - jobs = [pthread for pthread in jobs if pthread.is_alive()] - time.sleep(1) + # @TODO: Log the number of processes started and estfrom transport impfrom transport impimated time + # while jobs : + # jobs = [pthread for pthread in jobs if pthread.is_alive()] + # time.sleep(1) # # @TODO: Log the job termination here ... -@app.command(name="providers") +@app_i.command(name="supported") def supported (format:Annotated[str,typer.Argument(help="format of the output, supported formats are (list,table,json)")]="table") : """ - This function will print supported providers/vendors and their associated classifications + This function will print supported database technologies """ _df = (transport.supported()) if format in ['list','json'] : @@ -81,13 +103,14 @@ def supported (format:Annotated[str,typer.Argument(help="format of the output, s print (_df) print () -@app.command() -def version(): +@app_i.command(name="license") +def info(): """ This function will display version and license information """ - print (transport.__app_name__,'version ',transport.__version__) + print (f'[bold] {transport.__app_name__} ,version {transport.__version__}[/bold]') + print () print (transport.__license__) @app.command() @@ -99,18 +122,18 @@ def generate (path:Annotated[str,typer.Argument(help="path of the ETL configurat { "source":{"provider":"http","url":"https://raw.githubusercontent.com/codeforamerica/ohana-api/master/data/sample-csv/addresses.csv"}, "target": - [{"provider":"files","path":"addresses.csv","delimiter":","},{"provider":"sqlite","database":"sample.db3","table":"addresses"}] + [{"provider":"files","path":"addresses.csv","delimiter":","},{"provider":"sqlite3","database":"sample.db3","table":"addresses"}] } ] file = open(path,'w') file.write(json.dumps(_config)) file.close() - print (f"""{CHECK_MARK} Successfully generated a template ETL file at {path}""" ) + print (f"""{CHECK_MARK} Successfully generated a template ETL file at [bold]{path}[/bold]""" ) print ("""NOTE: Each line (source or target) is the content of an auth-file""") -@app.command(name="init") +@app_r.command(name="reset") def initregistry (email:Annotated[str,typer.Argument(help="email")], path:str=typer.Option(default=REGISTRY_PATH,help="path or location of the configuration file"), override:bool=typer.Option(default=False,help="override existing configuration or not")): @@ -120,24 +143,24 @@ def initregistry (email:Annotated[str,typer.Argument(help="email")], """ try: transport.registry.init(email=email, path=path, override=override) - _msg = f"""{CHECK_MARK} Successfully wrote configuration to {path} from {email}""" + _msg = f"""{CHECK_MARK} Successfully wrote configuration to [bold]{path}[/bold] from [bold]{email}[/bold]""" except Exception as e: _msg = f"{TIMES_MARK} {e}" print (_msg) print () -@app.command(name="register") +@app_r.command(name="add") def register (label:Annotated[str,typer.Argument(help="unique label that will be used to load the parameters of the database")], auth_file:Annotated[str,typer.Argument(help="path of the auth_file")], default:bool=typer.Option(default=False,help="set the auth_file as default"), path:str=typer.Option(default=REGISTRY_PATH,help="path of the data-transport registry file")): """ - This function will register an auth-file i.e database connection and assign it a label, - Learn more about auth-file at https://healthcareio.the-phi.com/data-transport + This function add a database label for a given auth-file. which allows access to the database using a label of your choice. + """ try: if transport.registry.exists(path) : transport.registry.set(label=label,auth_file=auth_file, default=default, path=path) - _msg = f"""{CHECK_MARK} Successfully added label "{label}" to data-transport registry""" + _msg = f"""{CHECK_MARK} Successfully added label [bold]"{label}"[/bold] to data-transport registry""" else: _msg = f"""{TIMES_MARK} Registry is not initialized, please initialize the registry (check help)""" except Exception as e: @@ -145,6 +168,47 @@ def register (label:Annotated[str,typer.Argument(help="unique label that will be print (_msg) pass +@app_x.command(name='add') +def register_plugs ( + alias:Annotated[str,typer.Argument(help="unique alias fo the file being registered")], + path:Annotated[str,typer.Argument(help="path of the python file, that contains functions")] + ): + """ + This function will register a file and the functions within will be refrences . in a configuration file + """ + transport.registry.plugins.init() + _log = transport.registry.plugins.add(alias,path) + _mark = TIMES_MARK if not _log else CHECK_MARK + _msg = f"""Could NOT add the [bold]{alias}[/bold]to the registry""" if not _log else f""" successfully added {alias}, {len(_log)} functions added""" + print (f"""{_mark} {_msg}""") +@app_x.command(name="list") +def registry_list (): + + transport.registry.plugins.init() + _d = [] + for _alias in transport.registry.plugins._data : + _data = transport.registry.plugins._data[_alias] + _d += [{'alias':_alias,"plugin-count":len(_data['content']),'e.g':'@'.join([_alias,_data['content'][0]]),'plugins':json.dumps(_data['content'])}] + if _d: + print (pd.DataFrame(_d)) + else: + print (f"""{TIMES_MARK}, Plugin registry is not available or needs initialization""") + +@app_x.command(name="test") +def registry_test (key): + """ + This function allows to test syntax for a plugin i.e in terms of alias@function + """ + _item = transport.registry.plugins.has(key=key) + if _item : + del _item['pointer'] + print (f"""{CHECK_MARK} successfully loaded \033[1m{key}\033[0m found, version {_item['version']}""") + print (pd.DataFrame([_item])) + else: + print (f"{TIMES_MARK} unable to load \033[1m{key}\033[0m. Make sure it is registered") +app.add_typer(app_r,name='registry',help='This function allows labeling database access information') +app.add_typer(app_i,name="info",help="This function will print either license or supported database technologies") +app.add_typer(app_x, name="plugins",help="This function enables add/list/test of plugins in the registry") if __name__ == '__main__' : app() diff --git a/setup.py b/setup.py index 7bb44e8..f11a6ca 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ args = { "packages": find_packages(include=['info','transport', 'transport.*'])} args["keywords"]=['mongodb','duckdb','couchdb','rabbitmq','file','read','write','s3','sqlite'] -args["install_requires"] = ['pyncclient','duckdb-engine','pymongo','sqlalchemy','pandas','typer','pandas-gbq','numpy','cloudant','pika','nzpy','termcolor','boto3','boto','pyarrow','google-cloud-bigquery','google-cloud-bigquery-storage','flask-session','smart_open','botocore','psycopg2-binary','mysql-connector-python','numpy','pymssql'] +args["install_requires"] = ['pyncclient','duckdb-engine','pymongo','sqlalchemy','pandas','typer','pandas-gbq','numpy','cloudant','pika','nzpy','termcolor','boto3','boto','pyarrow','google-cloud-bigquery','google-cloud-bigquery-storage','flask-session','smart_open','botocore','psycopg2-binary','mysql-connector-python','numpy','pymssql','pyspark','pydrill','sqlalchemy_drill'] args["url"] = "https://healthcareio.the-phi.com/git/code/transport.git" args['scripts'] = ['bin/transport'] # if sys.version_info[0] == 2 : diff --git a/transport/__init__.py b/transport/__init__.py index b934760..33a3261 100644 --- a/transport/__init__.py +++ b/transport/__init__.py @@ -18,7 +18,7 @@ Source Code is available under MIT License: """ import numpy as np -from transport import sql, nosql, cloud, other +from transport import sql, nosql, cloud, other, warehouse import pandas as pd import json import os @@ -28,21 +28,26 @@ from transport.plugins import PluginLoader from transport import providers import copy from transport import registry - +from transport.plugins import Plugin PROVIDERS = {} def init(): global PROVIDERS - for _module in [cloud,sql,nosql,other] : + for _module in [cloud,sql,nosql,other,warehouse] : for _provider_name in dir(_module) : if _provider_name.startswith('__') or _provider_name == 'common': continue PROVIDERS[_provider_name] = {'module':getattr(_module,_provider_name),'type':_module.__name__} -def _getauthfile (path) : - f = open(path) - _object = json.loads(f.read()) - f.close() - return _object + # + # loading the registry + if not registry.isloaded() : + registry.load() + +# def _getauthfile (path) : +# f = open(path) +# _object = json.loads(f.read()) +# f.close() +# return _object def instance (**_args): """ This function returns an object of to read or write from a supported database provider/vendor @@ -52,16 +57,7 @@ def instance (**_args): kwargs These are arguments that are provider/vendor specific """ global PROVIDERS - # if not registry.isloaded () : - # if ('path' in _args and registry.exists(_args['path'] )) or registry.exists(): - # registry.load() if 'path' not in _args else registry.load(_args['path']) - # print ([' GOT IT']) - # if 'label' in _args and registry.isloaded(): - # _info = registry.get(_args['label']) - # if _info : - # # - # _args = dict(_args,**_info) - + if 'auth_file' in _args: if os.path.exists(_args['auth_file']) : # @@ -78,7 +74,7 @@ def instance (**_args): filename = _args['auth_file'] raise Exception(f" {filename} was not found or is invalid") if 'provider' not in _args and 'auth_file' not in _args : - if not registry.isloaded () : + if not registry.isloaded () : if ('path' in _args and registry.exists(_args['path'] )) or registry.exists(): registry.load() if 'path' not in _args else registry.load(_args['path']) _info = {} @@ -87,8 +83,6 @@ def instance (**_args): else: _info = registry.get() if _info : - # - # _args = dict(_args,**_info) _args = dict(_info,**_args) #-- we can override the registry parameters with our own arguments if 'provider' in _args and _args['provider'] in PROVIDERS : @@ -119,8 +113,32 @@ def instance (**_args): # for _delegate in _params : # loader.set(_delegate) - loader = None if 'plugins' not in _args else _args['plugins'] - return IReader(_agent,loader) if _context == 'read' else IWriter(_agent,loader) + _plugins = None if 'plugins' not in _args else _args['plugins'] + + # if registry.has('logger') : + # _kwa = registry.get('logger') + # _lmodule = getPROVIDERS[_kwa['provider']] + + if ( ('label' in _args and _args['label'] != 'logger') and registry.has('logger')): + # + # We did not request label called logger, so we are setting up a logger if it is specified in the registry + # + _kwargs = registry.get('logger') + _kwargs['context'] = 'write' + _kwargs['table'] =_module.__name__.split('.')[-1]+'_logs' + # _logger = instance(**_kwargs) + _module = PROVIDERS[_kwargs['provider']]['module'] + _logger = getattr(_module,'Writer') + _logger = _logger(**_kwargs) + else: + _logger = None + + _kwargs = {'agent':_agent,'plugins':_plugins,'logger':_logger} + if 'args' in _args : + _kwargs['args'] = _args['args'] + # _datatransport = IReader(_agent,_plugins,_logger) if _context == 'read' else IWriter(_agent,_plugins,_logger) + _datatransport = IReader(**_kwargs) if _context == 'read' else IWriter(**_kwargs) + return _datatransport else: # @@ -137,7 +155,14 @@ class get : if not _args or ('provider' not in _args and 'label' not in _args): _args['label'] = 'default' _args['context'] = 'read' - return instance(**_args) + # return instance(**_args) + # _args['logger'] = instance(**{'label':'logger','context':'write','table':'logs'}) + + _handler = instance(**_args) + # _handler.setLogger(get.logger()) + return _handler + + @staticmethod def writer(**_args): """ @@ -146,10 +171,26 @@ class get : if not _args or ('provider' not in _args and 'label' not in _args): _args['label'] = 'default' _args['context'] = 'write' - return instance(**_args) + # _args['logger'] = instance(**{'label':'logger','context':'write','table':'logs'}) + + _handler = instance(**_args) + # + # Implementing logging with the 'eat-your-own-dog-food' approach + # Using dependency injection to set the logger (problem with imports) + # + # _handler.setLogger(get.logger()) + return _handler + @staticmethod + def logger (): + if registry.has('logger') : + _args = registry.get('logger') + _args['context'] = 'write' + return instance(**_args) + return None @staticmethod def etl (**_args): if 'source' in _args and 'target' in _args : + return IETL(**_args) else: raise Exception ("Malformed input found, object must have both 'source' and 'target' attributes") diff --git a/transport/iowrapper.py b/transport/iowrapper.py index e3abf6c..e532e7d 100644 --- a/transport/iowrapper.py +++ b/transport/iowrapper.py @@ -5,7 +5,7 @@ NOTE: Plugins are converted to a pipeline, so we apply a pipeline when reading o - upon initialization we will load plugins - on read/write we apply a pipeline (if passed as an argument) """ -from transport.plugins import plugin, PluginLoader +from transport.plugins import Plugin, PluginLoader import transport from transport import providers from multiprocessing import Process @@ -16,7 +16,10 @@ class IO: """ Base wrapper class for read/write and support for logs """ - def __init__(self,_agent,plugins): + def __init__(self,**_args): + _agent = _args['agent'] + plugins = _args['plugins'] if 'plugins' not in _args else None + self._agent = _agent if plugins : self._init_plugins(plugins) @@ -63,8 +66,9 @@ class IReader(IO): """ This is a wrapper for read functionalities """ - def __init__(self,_agent,pipeline=None): - super().__init__(_agent,pipeline) + def __init__(self,**_args): + super().__init__(**_args) + def read(self,**_args): if 'plugins' in _args : self._init_plugins(_args['plugins']) @@ -75,8 +79,8 @@ class IReader(IO): # output data return _data class IWriter(IO): - def __init__(self,_agent,pipeline=None): - super().__init__(_agent,pipeline) + def __init__(self,**_args): #_agent,pipeline=None): + super().__init__(**_args) #_agent,pipeline) def write(self,_data,**_args): if 'plugins' in _args : self._init_plugins(_args['plugins']) @@ -94,7 +98,7 @@ class IETL(IReader) : This class performs an ETL operation by ineriting a read and adding writes as pipeline functions """ def __init__(self,**_args): - super().__init__(transport.get.reader(**_args['source'])) + super().__init__(agent=transport.get.reader(**_args['source']),plugins=None) if 'target' in _args: self._targets = _args['target'] if type(_args['target']) == list else [_args['target']] else: @@ -110,6 +114,8 @@ class IETL(IReader) : self.post(_data,**_kwargs) return _data + def run(self) : + return self.read() def post (self,_data,**_args) : """ This function returns an instance of a process that will perform the write operation diff --git a/transport/plugins/__init__.py b/transport/plugins/__init__.py index 26e5782..760b66c 100644 --- a/transport/plugins/__init__.py +++ b/transport/plugins/__init__.py @@ -11,8 +11,10 @@ import importlib as IL import importlib.util import sys import os +import pandas as pd +import time -class plugin : +class Plugin : """ Implementing function decorator for data-transport plugins (post-pre)-processing """ @@ -22,8 +24,9 @@ class plugin : :mode restrict to reader/writer :about tell what the function is about """ - self._name = _args['name'] - self._about = _args['about'] + self._name = _args['name'] if 'name' in _args else None + self._version = _args['version'] if 'version' in _args else '0.1' + self._doc = _args['doc'] if 'doc' in _args else "N/A" self._mode = _args['mode'] if 'mode' in _args else 'rw' def __call__(self,pointer,**kwargs): def wrapper(_args,**kwargs): @@ -32,57 +35,64 @@ class plugin : # @TODO: # add attributes to the wrapper object # + self._name = pointer.__name__ if not self._name else self._name setattr(wrapper,'transport',True) setattr(wrapper,'name',self._name) - setattr(wrapper,'mode',self._mode) - setattr(wrapper,'about',self._about) + setattr(wrapper,'version',self._version) + setattr(wrapper,'doc',self._doc) return wrapper - class PluginLoader : """ This class is intended to load a plugin and make it available and assess the quality of the developed plugin """ + def __init__(self,**_args): """ - :path location of the plugin (should be a single file) - :_names of functions to load """ - _names = _args['names'] if 'names' in _args else None - path = _args['path'] if 'path' in _args else None - self._names = _names if type(_names) == list else [_names] + # _names = _args['names'] if 'names' in _args else None + # path = _args['path'] if 'path' in _args else None + # self._names = _names if type(_names) == list else [_names] self._modules = {} self._names = [] - if path and os.path.exists(path) and _names: - for _name in self._names : - - spec = importlib.util.spec_from_file_location('private', path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) #--loads it into sys.modules - if hasattr(module,_name) : - if self.isplugin(module,_name) : - self._modules[_name] = getattr(module,_name) - else: - print ([f'Found {_name}', 'not plugin']) - else: - # - # @TODO: We should log this somewhere some how - print (['skipping ',_name, hasattr(module,_name)]) - pass - else: - # - # Initialization is empty - self._names = [] + self._registry = _args['registry'] + pass - def set(self,_pointer) : + def load (self,**_args): + self._modules = {} + self._names = [] + path = _args ['path'] + if os.path.exists(path) : + _alias = path.split(os.sep)[-1] + spec = importlib.util.spec_from_file_location(_alias, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) #--loads it into sys.modules + for _name in dir(module) : + if self.isplugin(module,_name) : + self._module[_name] = getattr(module,_name) + # self._names [_name] + def format (self,**_args): + uri = _args['alias'],_args['name'] + # def set(self,_pointer) : + def set(self,_key) : """ This function will set a pointer to the list of modules to be called This should be used within the context of using the framework as a library """ - _name = _pointer.__name__ + if type(_key).__name__ == 'function': + # + # The pointer is in the code provided by the user and loaded in memory + # + _pointer = _key + _key = 'inline@'+_key.__name__ + # self._names.append(_key.__name__) + else: + _pointer = self._registry.get(key=_key) + + if _pointer : + self._modules[_key] = _pointer + self._names.append(_key) - self._modules[_name] = _pointer - self._names.append(_name) def isplugin(self,module,name): """ This function determines if a module is a recognized plugin @@ -107,12 +117,31 @@ class PluginLoader : _n = len(self._names) return len(set(self._modules.keys()) & set (self._names)) / _n - def apply(self,_data): + def apply(self,_data,_logger=[]): + _input= {} + for _name in self._modules : - _pointer = self._modules[_name] - # - # @TODO: add exception handling - _data = _pointer(_data) + try: + _input = {'action':'plugin','object':_name,'input':{'status':'PASS'}} + _pointer = self._modules[_name] + if type(_data) == list : + _data = pd.DataFrame(_data) + _brow,_bcol = list(_data.shape) + + # + # @TODO: add exception handling + _data = _pointer(_data) + + _input['input']['shape'] = {'rows-dropped':_brow - _data.shape[0]} + except Exception as e: + _input['input']['status'] = 'FAILED' + print (e) + time.sleep(1) + if _logger: + try: + _logger(**_input) + except Exception as e: + pass return _data # def apply(self,_data,_name): # """ diff --git a/transport/providers/__init__.py b/transport/providers/__init__.py index 6422d74..b4cf37a 100644 --- a/transport/providers/__init__.py +++ b/transport/providers/__init__.py @@ -11,7 +11,7 @@ BIGQUERY ='bigquery' FILE = 'file' ETL = 'etl' -SQLITE = 'sqlite' +SQLITE = 'sqlite3' SQLITE3= 'sqlite3' DUCKDB = 'duckdb' @@ -44,7 +44,9 @@ PGSQL = POSTGRESQL AWS_S3 = 's3' RABBIT = RABBITMQ - - +ICEBERG='iceberg' +APACHE_ICEBERG = 'iceberg' +DRILL = 'drill' +APACHE_DRILL = 'drill' # QLISTENER = 'qlistener' \ No newline at end of file diff --git a/transport/registry.py b/transport/registry.py index f3dc8ac..1f612dc 100644 --- a/transport/registry.py +++ b/transport/registry.py @@ -220,6 +220,8 @@ def init (email,path=REGISTRY_PATH,override=False,_file=REGISTRY_FILE): def lookup (label): global DATA return label in DATA +has = lookup + def get (label='default') : global DATA return copy.copy(DATA[label]) if label in DATA else {} diff --git a/transport/warehouse/__init__.py b/transport/warehouse/__init__.py new file mode 100644 index 0000000..bcd76fd --- /dev/null +++ b/transport/warehouse/__init__.py @@ -0,0 +1,7 @@ +""" +This namespace/package is intended to handle read/writes against data warehouse solutions like : + - apache iceberg + - clickhouse (...) +""" + +from . import iceberg, drill \ No newline at end of file diff --git a/transport/warehouse/drill.py b/transport/warehouse/drill.py new file mode 100644 index 0000000..71f0e64 --- /dev/null +++ b/transport/warehouse/drill.py @@ -0,0 +1,55 @@ +import sqlalchemy +import pandas as pd +from .. sql.common import BaseReader , BaseWriter +import sqlalchemy as sqa + +class Drill : + __template = {'host':None,'port':None,'ssl':None,'table':None,'database':None} + def __init__(self,**_args): + + self._host = _args['host'] if 'host' in _args else 'localhost' + self._port = _args['port'] if 'port' in _args else self.get_default_port() + self._ssl = False if 'ssl' not in _args else _args['ssl'] + + self._table = _args['table'] if 'table' in _args else None + if self._table and '.' in self._table : + _seg = self._table.split('.') + if len(_seg) > 2 : + self._schema,self._database = _seg[:2] + else: + + self._database=_args['database'] + self._schema = self._database.split('.')[0] + + def _get_uri(self,**_args): + return f'drill+sadrill://{self._host}:{self._port}/{self._database}?use_ssl={self._ssl}' + def get_provider(self): + return "drill+sadrill" + def get_default_port(self): + return "8047" + def meta(self,**_args): + _table = _args['table'] if 'table' in _args else self._table + if '.' in _table : + _schema = _table.split('.')[:2] + _schema = '.'.join(_schema) + _table = _table.split('.')[-1] + else: + _schema = self._schema + + # _sql = f"select COLUMN_NAME AS name, CASE WHEN DATA_TYPE ='CHARACTER VARYING' THEN 'CHAR ( 125 )' ELSE DATA_TYPE END AS type from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{_schema}' and TABLE_NAME='{_table}'" + _sql = f"select COLUMN_NAME AS name, CASE WHEN DATA_TYPE ='CHARACTER VARYING' THEN 'CHAR ( '||COLUMN_SIZE||' )' ELSE DATA_TYPE END AS type from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{_schema}' and TABLE_NAME='{_table}'" + try: + _df = pd.read_sql(_sql,self._engine) + return _df.to_dict(orient='records') + except Exception as e: + print (e) + pass + return [] +class Reader (Drill,BaseReader) : + def __init__(self,**_args): + super().__init__(**_args) + self._chunksize = 0 if 'chunksize' not in _args else _args['chunksize'] + self._engine= sqa.create_engine(self._get_uri(),future=True) +class Writer(Drill,BaseWriter): + def __init__(self,**_args): + super().__init__(self,**_args) \ No newline at end of file diff --git a/transport/warehouse/iceberg.py b/transport/warehouse/iceberg.py new file mode 100644 index 0000000..4e73c62 --- /dev/null +++ b/transport/warehouse/iceberg.py @@ -0,0 +1,151 @@ +""" +dependency: + - spark and SPARK_HOME environment variable must be set +NOTE: + When using streaming option, insure that it is inline with default (1000 rows) or increase it in spark-defaults.conf + +""" +from pyspark.sql import SparkSession +from pyspark import SparkContext +from pyspark.sql.types import * +from pyspark.sql.functions import col, to_date, to_timestamp +import copy + +class Iceberg : + def __init__(self,**_args): + """ + providing catalog meta information (you must get this from apache iceberg) + """ + # + # Turning off logging (it's annoying & un-professional) + # + # _spconf = SparkContext() + # _spconf.setLogLevel("ERROR") + # + # @TODO: + # Make arrangements for additional configuration elements + # + self._session = SparkSession.builder.appName("data-transport").getOrCreate() + self._session.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + # self._session.sparkContext.setLogLevel("ERROR") + self._catalog = self._session.catalog + self._table = _args['table'] if 'table' in _args else None + + if 'catalog' in _args : + # + # Let us set the default catalog + self._catalog.setCurrentCatalog(_args['catalog']) + + else: + # No current catalog has been set ... + pass + if 'database' in _args : + self._database = _args['database'] + self._catalog.setCurrentDatabase(self._database) + else: + # + # Should we set the default as the first one if available ? + # + pass + self._catalogName = self._catalog.currentCatalog() + self._databaseName = self._catalog.currentDatabase() + def meta (self,**_args) : + """ + This function should return the schema of a table (only) + """ + _schema = [] + try: + _table = _args['table'] if 'table' in _args else self._table + _tableName = self._getPrefix(**_args) + f".{_table}" + _tmp = self._session.table(_tableName).schema + _schema = _tmp.jsonValue()['fields'] + for _item in _schema : + del _item['nullable'],_item['metadata'] + except Exception as e: + + pass + return _schema + def _getPrefix (self,**_args): + _catName = self._catalogName if 'catalog' not in _args else _args['catalog'] + _datName = self._databaseName if 'database' not in _args else _args['database'] + + return '.'.join([_catName,_datName]) + def apply(self,_query): + """ + sql query/command to run against apache iceberg + """ + return self._session.sql(_query) + def has (self,**_args): + try: + _prefix = self._getPrefix(**_args) + if _prefix.endswith('.') : + return False + return _args['table'] in [_item.name for _item in self._catalog.listTables(_prefix)] + except Exception as e: + print (e) + return False + + def close(self): + self._session.stop() +class Reader(Iceberg) : + def __init__(self,**_args): + super().__init__(**_args) + def read(self,**_args): + _table = self._table + _prefix = self._getPrefix(**_args) + if 'table' in _args or _table: + _table = _args['table'] if 'table' in _args else _table + _table = _prefix + f'.{_table}' + return self._session.table(_table).toPandas() + else: + sql = _args['sql'] + return self._session.sql(sql).toPandas() + pass +class Writer (Iceberg): + """ + Writing data to an Apache Iceberg data warehouse (using pyspark) + """ + def __init__(self,**_args): + super().__init__(**_args) + self._mode = 'append' if 'mode' not in _args else _args['mode'] + self._table = None if 'table' not in _args else _args['table'] + def format (self,_schema) : + _iceSchema = StructType([]) + _map = {'integer':IntegerType(),'float':DoubleType(),'double':DoubleType(),'date':DateType(), + 'timestamp':TimestampType(),'datetime':TimestampType(),'string':StringType(),'varchar':StringType()} + for _item in _schema : + _name = _item['name'] + _type = _item['type'].lower() + if _type not in _map : + _iceType = StringType() + else: + _iceType = _map[_type] + + _iceSchema.add (StructField(_name,_iceType,True)) + return _iceSchema if len(_iceSchema) else [] + def write(self,_data,**_args): + _prefix = self._getPrefix(**_args) + if 'table' not in _args and not self._table : + raise Exception (f"Table Name should be specified for catalog/database {_prefix}") + _schema = self.format(_args['schema']) if 'schema' in _args else [] + if not _schema : + rdd = self._session.createDataFrame(_data,verifySchema=False) + else : + rdd = self._session.createDataFrame(_data,schema=_schema,verifySchema=True) + _mode = self._mode if 'mode' not in _args else _args['mode'] + _table = self._table if 'table' not in _args else _args['table'] + + # print (_data.shape,_mode,_table) + + if not self._session.catalog.tableExists(_table): + # # @TODO: + # # add partitioning information here + rdd.writeTo(_table).using('iceberg').create() + + # # _mode = 'overwrite' + # # rdd.write.format('iceberg').mode(_mode).saveAsTable(_table) + else: + # rdd.writeTo(_table).append() + # # _table = f'{_prefix}.{_table}' + + rdd.coalesce(10).write.format('iceberg').mode('append').save(_table)