bug fixes: registry and CLI parameter handling, plus new warehousing support

v2.2.0
Steve Nyemba 3 weeks ago
parent d0e655e7e3
commit 2a72de4cd6

@@ -24,19 +24,25 @@ from multiprocessing import Process
import os
import transport
from transport import etl
# from transport import etl
from transport.iowrapper import IETL
# from transport import providers
import typer
from typing_extensions import Annotated
from typing import Optional
import time
from termcolor import colored
from enum import Enum
from rich import print
app = typer.Typer()
app_x = typer.Typer()
app_i = typer.Typer()
app_r = typer.Typer()
REGISTRY_PATH=os.sep.join([os.environ['HOME'],'.data-transport'])
REGISTRY_FILE= 'transport-registry.json'
CHECK_MARK = ' '.join(['[',colored(u'\u2713', 'green'),']'])
TIMES_MARK= ' '.join(['[',colored(u'\u2717','red'),']'])
CHECK_MARK = '[ [green]\u2713[/green] ]' #' '.join(['[',colored(u'\u2713', 'green'),']'])
TIMES_MARK= '[ [red]\u2717[/red] ]' #' '.join(['[',colored(u'\u2717','red'),']'])
# @app.command()
def help() :
print (__doc__)
@@ -44,10 +50,15 @@ def wait(jobs):
while jobs :
jobs = [thread for thread in jobs if thread.is_alive()]
time.sleep(1)
def wait (jobs):
while jobs :
jobs = [pthread for pthread in jobs if pthread.is_alive()]
time.sleep(1)
@app.command(name="apply")
@app.command(name="etl")
def apply (path:Annotated[str,typer.Argument(help="path of the configuration file")],
index:int = typer.Option(default= None, help="index of the item of interest, otherwise everything in the file will be processed")):
index:int = typer.Option(default= None, help="index of the item of interest, otherwise everything in the file will be processed"),
batch:int = typer.Option(default=5, help="The number of parallel processes to run at once")
):
"""
This function applies the data-transport ETL feature to read data from one source and write it to one or several others
"""
@@ -56,23 +67,34 @@ def apply (path:Annotated[str,typer.Argument(help="path of the configuration fil
file = open(path)
_config = json.loads (file.read() )
file.close()
if index :
if index is not None:
_config = [_config[ int(index)]]
jobs = []
jobs = []
for _args in _config :
pthread = etl.instance(**_args) #-- automatically starts the process
# pthread = etl.instance(**_args) #-- automatically starts the process
def bootup ():
_worker = IETL(**_args)
_worker.run()
pthread = Process(target=bootup)
pthread.start()
jobs.append(pthread)
if len(jobs) == batch :
wait(jobs)
jobs = []
if jobs :
wait (jobs)
#
# @TODO: Log the number of processes started and estimated time
while jobs :
jobs = [pthread for pthread in jobs if pthread.is_alive()]
time.sleep(1)
# @TODO: Log the number of processes started and estimated time
# while jobs :
# jobs = [pthread for pthread in jobs if pthread.is_alive()]
# time.sleep(1)
#
# @TODO: Log the job termination here ...
@app.command(name="providers")
@app_i.command(name="supported")
def supported (format:Annotated[str,typer.Argument(help="format of the output, supported formats are (list,table,json)")]="table") :
"""
This function will print supported providers/vendors and their associated classifications
This function will print supported database technologies
"""
_df = (transport.supported())
if format in ['list','json'] :
@@ -81,13 +103,14 @@ def supported (format:Annotated[str,typer.Argument(help="format of the output, s
print (_df)
print ()
@app.command()
def version():
@app_i.command(name="license")
def info():
"""
This function will display version and license information
"""
print (transport.__app_name__,'version ',transport.__version__)
print (f'[bold] {transport.__app_name__} ,version {transport.__version__}[/bold]')
print ()
print (transport.__license__)
@app.command()
@@ -99,18 +122,18 @@ def generate (path:Annotated[str,typer.Argument(help="path of the ETL configurat
{
"source":{"provider":"http","url":"https://raw.githubusercontent.com/codeforamerica/ohana-api/master/data/sample-csv/addresses.csv"},
"target":
[{"provider":"files","path":"addresses.csv","delimiter":","},{"provider":"sqlite","database":"sample.db3","table":"addresses"}]
[{"provider":"files","path":"addresses.csv","delimiter":","},{"provider":"sqlite3","database":"sample.db3","table":"addresses"}]
}
]
file = open(path,'w')
file.write(json.dumps(_config))
file.close()
print (f"""{CHECK_MARK} Successfully generated a template ETL file at {path}""" )
print (f"""{CHECK_MARK} Successfully generated a template ETL file at [bold]{path}[/bold]""" )
print ("""NOTE: Each line (source or target) is the content of an auth-file""")
@app.command(name="init")
@app_r.command(name="reset")
def initregistry (email:Annotated[str,typer.Argument(help="email")],
path:str=typer.Option(default=REGISTRY_PATH,help="path or location of the configuration file"),
override:bool=typer.Option(default=False,help="override existing configuration or not")):
@@ -120,24 +143,24 @@ def initregistry (email:Annotated[str,typer.Argument(help="email")],
"""
try:
transport.registry.init(email=email, path=path, override=override)
_msg = f"""{CHECK_MARK} Successfully wrote configuration to {path} from {email}"""
_msg = f"""{CHECK_MARK} Successfully wrote configuration to [bold]{path}[/bold] from [bold]{email}[/bold]"""
except Exception as e:
_msg = f"{TIMES_MARK} {e}"
print (_msg)
print ()
@app.command(name="register")
@app_r.command(name="add")
def register (label:Annotated[str,typer.Argument(help="unique label that will be used to load the parameters of the database")],
auth_file:Annotated[str,typer.Argument(help="path of the auth_file")],
default:bool=typer.Option(default=False,help="set the auth_file as default"),
path:str=typer.Option(default=REGISTRY_PATH,help="path of the data-transport registry file")):
"""
This function will register an auth-file i.e database connection and assign it a label,
Learn more about auth-file at https://healthcareio.the-phi.com/data-transport
This function adds a database label for a given auth-file, which allows access to the database using a label of your choice.
"""
try:
if transport.registry.exists(path) :
transport.registry.set(label=label,auth_file=auth_file, default=default, path=path)
_msg = f"""{CHECK_MARK} Successfully added label "{label}" to data-transport registry"""
_msg = f"""{CHECK_MARK} Successfully added label [bold]"{label}"[/bold] to data-transport registry"""
else:
_msg = f"""{TIMES_MARK} Registry is not initialized, please initialize the registry (check help)"""
except Exception as e:
@@ -145,6 +168,47 @@ def register (label:Annotated[str,typer.Argument(help="unique label that will be
print (_msg)
pass
@app_x.command(name='add')
def register_plugs (
alias:Annotated[str,typer.Argument(help="unique alias for the file being registered")],
path:Annotated[str,typer.Argument(help="path of the python file, that contains functions")]
):
"""
This function will register a file; the functions within can then be referenced as <alias>@<function> in a configuration file
"""
transport.registry.plugins.init()
_log = transport.registry.plugins.add(alias,path)
_mark = TIMES_MARK if not _log else CHECK_MARK
_msg = f"""Could NOT add the [bold]{alias}[/bold]to the registry""" if not _log else f""" successfully added {alias}, {len(_log)} functions added"""
print (f"""{_mark} {_msg}""")
@app_x.command(name="list")
def registry_list ():
transport.registry.plugins.init()
_d = []
for _alias in transport.registry.plugins._data :
_data = transport.registry.plugins._data[_alias]
_d += [{'alias':_alias,"plugin-count":len(_data['content']),'e.g':'@'.join([_alias,_data['content'][0]]),'plugins':json.dumps(_data['content'])}]
if _d:
print (pd.DataFrame(_d))
else:
print (f"""{TIMES_MARK}, Plugin registry is not available or needs initialization""")
@app_x.command(name="test")
def registry_test (key):
"""
This function tests the syntax of a plugin reference, i.e. alias@function
"""
_item = transport.registry.plugins.has(key=key)
if _item :
del _item['pointer']
print (f"""{CHECK_MARK} successfully loaded \033[1m{key}\033[0m found, version {_item['version']}""")
print (pd.DataFrame([_item]))
else:
print (f"{TIMES_MARK} unable to load \033[1m{key}\033[0m. Make sure it is registered")
app.add_typer(app_r,name='registry',help='This function allows labeling database access information')
app.add_typer(app_i,name="info",help="This function will print either license or supported database technologies")
app.add_typer(app_x, name="plugins",help="This function enables add/list/test of plugins in the registry")
if __name__ == '__main__' :
app()

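For orientation, the commit above reshuffles the CLI: `apply` is now exposed as `etl` with a `--batch` option, `init`/`register` move under a `registry` group, and new `plugins` and `info` groups are added. A minimal sketch of the resulting surface, assuming the `transport` script from setup.py is installed; the file names and the `mydb` label are hypothetical:

import subprocess
# Generate a template ETL file, then run it with the new --batch throttle (formerly "apply").
subprocess.run(["transport", "generate", "etl.json"], check=True)
subprocess.run(["transport", "etl", "etl.json", "--batch", "2"], check=True)
# Registry management (formerly "init"/"register") now lives under the "registry" group.
subprocess.run(["transport", "registry", "add", "mydb", "auth.json"], check=True)
# New plugin and info groups.
subprocess.run(["transport", "plugins", "list"], check=True)
subprocess.run(["transport", "info", "supported", "table"], check=True)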
@@ -19,7 +19,7 @@ args = {
"packages": find_packages(include=['info','transport', 'transport.*'])}
args["keywords"]=['mongodb','duckdb','couchdb','rabbitmq','file','read','write','s3','sqlite']
args["install_requires"] = ['pyncclient','duckdb-engine','pymongo','sqlalchemy','pandas','typer','pandas-gbq','numpy','cloudant','pika','nzpy','termcolor','boto3','boto','pyarrow','google-cloud-bigquery','google-cloud-bigquery-storage','flask-session','smart_open','botocore','psycopg2-binary','mysql-connector-python','numpy','pymssql']
args["install_requires"] = ['pyncclient','duckdb-engine','pymongo','sqlalchemy','pandas','typer','pandas-gbq','numpy','cloudant','pika','nzpy','termcolor','boto3','boto','pyarrow','google-cloud-bigquery','google-cloud-bigquery-storage','flask-session','smart_open','botocore','psycopg2-binary','mysql-connector-python','numpy','pymssql','pyspark','pydrill','sqlalchemy_drill']
args["url"] = "https://healthcareio.the-phi.com/git/code/transport.git"
args['scripts'] = ['bin/transport']
# if sys.version_info[0] == 2 :

@@ -18,7 +18,7 @@ Source Code is available under MIT License:
"""
import numpy as np
from transport import sql, nosql, cloud, other
from transport import sql, nosql, cloud, other, warehouse
import pandas as pd
import json
import os
@@ -28,21 +28,26 @@ from transport.plugins import PluginLoader
from transport import providers
import copy
from transport import registry
from transport.plugins import Plugin
PROVIDERS = {}
def init():
global PROVIDERS
for _module in [cloud,sql,nosql,other] :
for _module in [cloud,sql,nosql,other,warehouse] :
for _provider_name in dir(_module) :
if _provider_name.startswith('__') or _provider_name == 'common':
continue
PROVIDERS[_provider_name] = {'module':getattr(_module,_provider_name),'type':_module.__name__}
def _getauthfile (path) :
f = open(path)
_object = json.loads(f.read())
f.close()
return _object
#
# loading the registry
if not registry.isloaded() :
registry.load()
# def _getauthfile (path) :
# f = open(path)
# _object = json.loads(f.read())
# f.close()
# return _object
def instance (**_args):
"""
This function returns an object to read from or write to a supported database provider/vendor
@@ -52,16 +57,7 @@ def instance (**_args):
kwargs These are arguments that are provider/vendor specific
"""
global PROVIDERS
# if not registry.isloaded () :
# if ('path' in _args and registry.exists(_args['path'] )) or registry.exists():
# registry.load() if 'path' not in _args else registry.load(_args['path'])
# print ([' GOT IT'])
# if 'label' in _args and registry.isloaded():
# _info = registry.get(_args['label'])
# if _info :
# #
# _args = dict(_args,**_info)
if 'auth_file' in _args:
if os.path.exists(_args['auth_file']) :
#
@@ -78,7 +74,7 @@ def instance (**_args):
filename = _args['auth_file']
raise Exception(f" {filename} was not found or is invalid")
if 'provider' not in _args and 'auth_file' not in _args :
if not registry.isloaded () :
if not registry.isloaded () :
if ('path' in _args and registry.exists(_args['path'] )) or registry.exists():
registry.load() if 'path' not in _args else registry.load(_args['path'])
_info = {}
@@ -87,8 +83,6 @@ def instance (**_args):
else:
_info = registry.get()
if _info :
#
# _args = dict(_args,**_info)
_args = dict(_info,**_args) #-- we can override the registry parameters with our own arguments
if 'provider' in _args and _args['provider'] in PROVIDERS :
@@ -119,8 +113,32 @@ def instance (**_args):
# for _delegate in _params :
# loader.set(_delegate)
loader = None if 'plugins' not in _args else _args['plugins']
return IReader(_agent,loader) if _context == 'read' else IWriter(_agent,loader)
_plugins = None if 'plugins' not in _args else _args['plugins']
# if registry.has('logger') :
# _kwa = registry.get('logger')
# _lmodule = getPROVIDERS[_kwa['provider']]
if ( ('label' in _args and _args['label'] != 'logger') and registry.has('logger')):
#
# The requested label is not 'logger', so we set up a logger if one is specified in the registry
#
_kwargs = registry.get('logger')
_kwargs['context'] = 'write'
_kwargs['table'] =_module.__name__.split('.')[-1]+'_logs'
# _logger = instance(**_kwargs)
_module = PROVIDERS[_kwargs['provider']]['module']
_logger = getattr(_module,'Writer')
_logger = _logger(**_kwargs)
else:
_logger = None
_kwargs = {'agent':_agent,'plugins':_plugins,'logger':_logger}
if 'args' in _args :
_kwargs['args'] = _args['args']
# _datatransport = IReader(_agent,_plugins,_logger) if _context == 'read' else IWriter(_agent,_plugins,_logger)
_datatransport = IReader(**_kwargs) if _context == 'read' else IWriter(**_kwargs)
return _datatransport
else:
#
@@ -137,7 +155,14 @@ class get :
if not _args or ('provider' not in _args and 'label' not in _args):
_args['label'] = 'default'
_args['context'] = 'read'
return instance(**_args)
# return instance(**_args)
# _args['logger'] = instance(**{'label':'logger','context':'write','table':'logs'})
_handler = instance(**_args)
# _handler.setLogger(get.logger())
return _handler
@staticmethod
def writer(**_args):
"""
@@ -146,10 +171,26 @@ class get :
if not _args or ('provider' not in _args and 'label' not in _args):
_args['label'] = 'default'
_args['context'] = 'write'
return instance(**_args)
# _args['logger'] = instance(**{'label':'logger','context':'write','table':'logs'})
_handler = instance(**_args)
#
# Implementing logging with the 'eat-your-own-dog-food' approach
# Using dependency injection to set the logger (problem with imports)
#
# _handler.setLogger(get.logger())
return _handler
@staticmethod
def logger ():
if registry.has('logger') :
_args = registry.get('logger')
_args['context'] = 'write'
return instance(**_args)
return None
@staticmethod
def etl (**_args):
if 'source' in _args and 'target' in _args :
return IETL(**_args)
else:
raise Exception ("Malformed input found, object must have both 'source' and 'target' attributes")

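To illustrate the registry-backed lookup and the automatic logger attachment introduced above, a minimal sketch; it assumes an initialized registry in which the labels 'mydb' and 'logger' are hypothetical entries:

import transport

# Parameters come from the registry entry and can be overridden by kwargs.
reader = transport.get.reader(label='mydb')
df = reader.read()   # if a 'logger' entry exists, the wrapper logs through it

writer = transport.get.writer(label='mydb')
writer.write(df)

# get.etl requires both 'source' and 'target', as enforced above.
job = transport.get.etl(source={'label': 'mydb'},
                        target=[{'provider': 'sqlite3', 'database': 'sample.db3', 'table': 'copy'}])
job.run()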
@@ -5,7 +5,7 @@ NOTE: Plugins are converted to a pipeline, so we apply a pipeline when reading o
- upon initialization we will load plugins
- on read/write we apply a pipeline (if passed as an argument)
"""
from transport.plugins import plugin, PluginLoader
from transport.plugins import Plugin, PluginLoader
import transport
from transport import providers
from multiprocessing import Process
@@ -16,7 +16,10 @@ class IO:
"""
Base wrapper class for read/write and support for logs
"""
def __init__(self,_agent,plugins):
def __init__(self,**_args):
_agent = _args['agent']
plugins = _args['plugins'] if 'plugins' in _args else None
self._agent = _agent
if plugins :
self._init_plugins(plugins)
@@ -63,8 +66,9 @@ class IReader(IO):
"""
This is a wrapper for read functionalities
"""
def __init__(self,_agent,pipeline=None):
super().__init__(_agent,pipeline)
def __init__(self,**_args):
super().__init__(**_args)
def read(self,**_args):
if 'plugins' in _args :
self._init_plugins(_args['plugins'])
@@ -75,8 +79,8 @@ class IReader(IO):
# output data
return _data
class IWriter(IO):
def __init__(self,_agent,pipeline=None):
super().__init__(_agent,pipeline)
def __init__(self,**_args): #_agent,pipeline=None):
super().__init__(**_args) #_agent,pipeline)
def write(self,_data,**_args):
if 'plugins' in _args :
self._init_plugins(_args['plugins'])
@@ -94,7 +98,7 @@ class IETL(IReader) :
This class performs an ETL operation by inheriting a read and adding writes as pipeline functions
"""
def __init__(self,**_args):
super().__init__(transport.get.reader(**_args['source']))
super().__init__(agent=transport.get.reader(**_args['source']),plugins=None)
if 'target' in _args:
self._targets = _args['target'] if type(_args['target']) == list else [_args['target']]
else:
@@ -110,6 +114,8 @@ class IETL(IReader) :
self.post(_data,**_kwargs)
return _data
def run(self) :
return self.read()
def post (self,_data,**_args) :
"""
This function returns an instance of a process that will perform the write operation

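Since the wrappers now take keyword arguments, IETL can also be driven directly. A sketch mirroring the template that the CLI's generate command writes; the target path is illustrative:

from transport.iowrapper import IETL

job = IETL(source={'provider': 'http',
                   'url': 'https://raw.githubusercontent.com/codeforamerica/ohana-api/master/data/sample-csv/addresses.csv'},
           target=[{'provider': 'files', 'path': 'addresses.csv', 'delimiter': ','}])
data = job.run()   # run() delegates to read(); post() performs each write in a child Process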
@@ -11,8 +11,10 @@ import importlib as IL
import importlib.util
import sys
import os
import pandas as pd
import time
class plugin :
class Plugin :
"""
Implementing a function decorator for data-transport plugins (pre/post-processing)
"""
@@ -22,8 +24,9 @@ class plugin :
:mode restrict to reader/writer
:about tell what the function is about
"""
self._name = _args['name']
self._about = _args['about']
self._name = _args['name'] if 'name' in _args else None
self._version = _args['version'] if 'version' in _args else '0.1'
self._doc = _args['doc'] if 'doc' in _args else "N/A"
self._mode = _args['mode'] if 'mode' in _args else 'rw'
def __call__(self,pointer,**kwargs):
def wrapper(_args,**kwargs):
@@ -32,57 +35,64 @@ class plugin :
# @TODO:
# add attributes to the wrapper object
#
self._name = pointer.__name__ if not self._name else self._name
setattr(wrapper,'transport',True)
setattr(wrapper,'name',self._name)
setattr(wrapper,'mode',self._mode)
setattr(wrapper,'about',self._about)
setattr(wrapper,'version',self._version)
setattr(wrapper,'doc',self._doc)
return wrapper
class PluginLoader :
"""
This class is intended to load a plugin, make it available, and assess the quality of the developed plugin
"""
def __init__(self,**_args):
"""
:path location of the plugin (should be a single file)
:_names of functions to load
"""
_names = _args['names'] if 'names' in _args else None
path = _args['path'] if 'path' in _args else None
self._names = _names if type(_names) == list else [_names]
# _names = _args['names'] if 'names' in _args else None
# path = _args['path'] if 'path' in _args else None
# self._names = _names if type(_names) == list else [_names]
self._modules = {}
self._names = []
if path and os.path.exists(path) and _names:
for _name in self._names :
spec = importlib.util.spec_from_file_location('private', path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) #--loads it into sys.modules
if hasattr(module,_name) :
if self.isplugin(module,_name) :
self._modules[_name] = getattr(module,_name)
else:
print ([f'Found {_name}', 'not plugin'])
else:
#
# @TODO: We should log this somewhere some how
print (['skipping ',_name, hasattr(module,_name)])
pass
else:
#
# Initialization is empty
self._names = []
self._registry = _args['registry']
pass
def set(self,_pointer) :
def load (self,**_args):
self._modules = {}
self._names = []
path = _args ['path']
if os.path.exists(path) :
_alias = path.split(os.sep)[-1]
spec = importlib.util.spec_from_file_location(_alias, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) #--loads it into sys.modules
for _name in dir(module) :
if self.isplugin(module,_name) :
self._modules[_name] = getattr(module,_name)
# self._names [_name]
def format (self,**_args):
# @TODO: incomplete stub, intended to build a plugin reference from an alias and a function name
uri = _args['alias'],_args['name']
# def set(self,_pointer) :
def set(self,_key) :
"""
This function will set a pointer to the list of modules to be called
This should be used within the context of using the framework as a library
"""
_name = _pointer.__name__
if type(_key).__name__ == 'function':
#
# The pointer is in the code provided by the user and loaded in memory
#
_pointer = _key
_key = 'inline@'+_key.__name__
# self._names.append(_key.__name__)
else:
_pointer = self._registry.get(key=_key)
if _pointer :
self._modules[_key] = _pointer
self._names.append(_key)
self._modules[_name] = _pointer
self._names.append(_name)
def isplugin(self,module,name):
"""
This function determines if a module is a recognized plugin
@@ -107,12 +117,31 @@ class PluginLoader :
_n = len(self._names)
return len(set(self._modules.keys()) & set (self._names)) / _n
def apply(self,_data):
def apply(self,_data,_logger=[]):
_input= {}
for _name in self._modules :
_pointer = self._modules[_name]
#
# @TODO: add exception handling
_data = _pointer(_data)
try:
_input = {'action':'plugin','object':_name,'input':{'status':'PASS'}}
_pointer = self._modules[_name]
if type(_data) == list :
_data = pd.DataFrame(_data)
_brow,_bcol = list(_data.shape)
#
# @TODO: add exception handling
_data = _pointer(_data)
_input['input']['shape'] = {'rows-dropped':_brow - _data.shape[0]}
except Exception as e:
_input['input']['status'] = 'FAILED'
print (e)
time.sleep(1)
if _logger:
try:
_logger(**_input)
except Exception as e:
pass
return _data
# def apply(self,_data,_name):
# """

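A sketch of the renamed Plugin decorator and the reworked loader. The dedupe function is hypothetical, and registry=None assumes inline use where no on-disk plugin registry is needed:

import pandas as pd
from transport.plugins import Plugin, PluginLoader

@Plugin(name='dedupe', mode='rw', version='0.1')
def dedupe(_data):
    # plugins receive and return a DataFrame (apply() coerces lists)
    return _data.drop_duplicates()

loader = PluginLoader(registry=None)
loader.set(dedupe)    # a raw function is filed under an "inline@..." key
clean = loader.apply(pd.DataFrame({'x': [1, 1, 2]}))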
@@ -11,7 +11,7 @@ BIGQUERY ='bigquery'
FILE = 'file'
ETL = 'etl'
SQLITE = 'sqlite'
SQLITE = 'sqlite3'
SQLITE3= 'sqlite3'
DUCKDB = 'duckdb'
@@ -44,7 +44,9 @@ PGSQL = POSTGRESQL
AWS_S3 = 's3'
RABBIT = RABBITMQ
ICEBERG='iceberg'
APACHE_ICEBERG = 'iceberg'
DRILL = 'drill'
APACHE_DRILL = 'drill'
# QLISTENER = 'qlistener'

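With SQLITE now mapped to 'sqlite3' and DRILL/ICEBERG added, provider selection might look like the sketch below; it assumes the sqlite3 provider resolves as before, and the database file mirrors the CLI template:

import transport
from transport import providers

reader = transport.instance(provider=providers.SQLITE, context='read',
                            database='sample.db3', table='addresses')
df = reader.read()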
@@ -220,6 +220,8 @@ def init (email,path=REGISTRY_PATH,override=False,_file=REGISTRY_FILE):
def lookup (label):
global DATA
return label in DATA
has = lookup
def get (label='default') :
global DATA
return copy.copy(DATA[label]) if label in DATA else {}

@@ -0,0 +1,7 @@
"""
This namespace/package is intended to handle reads/writes against data warehouse solutions such as:
- apache iceberg
- clickhouse (...)
"""
from . import iceberg, drill

@@ -0,0 +1,55 @@
import sqlalchemy
import pandas as pd
from ..sql.common import BaseReader, BaseWriter
import sqlalchemy as sqa
class Drill :
__template = {'host':None,'port':None,'ssl':None,'table':None,'database':None}
def __init__(self,**_args):
self._host = _args['host'] if 'host' in _args else 'localhost'
self._port = _args['port'] if 'port' in _args else self.get_default_port()
self._ssl = False if 'ssl' not in _args else _args['ssl']
self._table = _args['table'] if 'table' in _args else None
if self._table and '.' in self._table :
_seg = self._table.split('.')
if len(_seg) > 2 :
self._schema,self._database = _seg[:2]
else:
self._database=_args['database']
self._schema = self._database.split('.')[0]
def _get_uri(self,**_args):
return f'drill+sadrill://{self._host}:{self._port}/{self._database}?use_ssl={self._ssl}'
def get_provider(self):
return "drill+sadrill"
def get_default_port(self):
return "8047"
def meta(self,**_args):
_table = _args['table'] if 'table' in _args else self._table
if '.' in _table :
_schema = _table.split('.')[:2]
_schema = '.'.join(_schema)
_table = _table.split('.')[-1]
else:
_schema = self._schema
# _sql = f"select COLUMN_NAME AS name, CASE WHEN DATA_TYPE ='CHARACTER VARYING' THEN 'CHAR ( 125 )' ELSE DATA_TYPE END AS type from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{_schema}' and TABLE_NAME='{_table}'"
_sql = f"select COLUMN_NAME AS name, CASE WHEN DATA_TYPE ='CHARACTER VARYING' THEN 'CHAR ( '||COLUMN_SIZE||' )' ELSE DATA_TYPE END AS type from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA='{_schema}' and TABLE_NAME='{_table}'"
try:
_df = pd.read_sql(_sql,self._engine)
return _df.to_dict(orient='records')
except Exception as e:
print (e)
pass
return []
class Reader (Drill,BaseReader) :
def __init__(self,**_args):
super().__init__(**_args)
self._chunksize = 0 if 'chunksize' not in _args else _args['chunksize']
self._engine= sqa.create_engine(self._get_uri(),future=True)
class Writer(Drill,BaseWriter):
def __init__(self,**_args):
super().__init__(**_args)

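A usage sketch for the Drill wrapper: the dialect is drill+sadrill on default port 8047 and requires the sqlalchemy_drill dependency added in setup.py. The dfs.tmp workspace and table name are hypothetical:

import transport

reader = transport.instance(provider='drill', context='read',
                            host='localhost', port=8047, ssl=False,
                            database='dfs.tmp', table='addresses')
df = reader.read()   # meta() on the underlying Drill class returns column name/type pairs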
@@ -0,0 +1,151 @@
"""
dependency:
- Spark must be installed and the SPARK_HOME environment variable must be set
NOTE:
When using the streaming option, ensure that it is in line with the default (1000 rows) or increase it in spark-defaults.conf
"""
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql.types import *
from pyspark.sql.functions import col, to_date, to_timestamp
import copy
class Iceberg :
def __init__(self,**_args):
"""
providing catalog meta information (you must get this from apache iceberg)
"""
#
# Turning off logging (it's annoying & un-professional)
#
# _spconf = SparkContext()
# _spconf.setLogLevel("ERROR")
#
# @TODO:
# Make arrangements for additional configuration elements
#
self._session = SparkSession.builder.appName("data-transport").getOrCreate()
self._session.conf.set("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS")
# self._session.sparkContext.setLogLevel("ERROR")
self._catalog = self._session.catalog
self._table = _args['table'] if 'table' in _args else None
if 'catalog' in _args :
#
# Let us set the default catalog
self._catalog.setCurrentCatalog(_args['catalog'])
else:
# No current catalog has been set ...
pass
if 'database' in _args :
self._database = _args['database']
self._catalog.setCurrentDatabase(self._database)
else:
#
# Should we set the default as the first one if available ?
#
pass
self._catalogName = self._catalog.currentCatalog()
self._databaseName = self._catalog.currentDatabase()
def meta (self,**_args) :
"""
This function should return the schema of a table (only)
"""
_schema = []
try:
_table = _args['table'] if 'table' in _args else self._table
_tableName = self._getPrefix(**_args) + f".{_table}"
_tmp = self._session.table(_tableName).schema
_schema = _tmp.jsonValue()['fields']
for _item in _schema :
del _item['nullable'],_item['metadata']
except Exception as e:
pass
return _schema
def _getPrefix (self,**_args):
_catName = self._catalogName if 'catalog' not in _args else _args['catalog']
_datName = self._databaseName if 'database' not in _args else _args['database']
return '.'.join([_catName,_datName])
def apply(self,_query):
"""
sql query/command to run against apache iceberg
"""
return self._session.sql(_query)
def has (self,**_args):
try:
_prefix = self._getPrefix(**_args)
if _prefix.endswith('.') :
return False
return _args['table'] in [_item.name for _item in self._catalog.listTables(_prefix)]
except Exception as e:
print (e)
return False
def close(self):
self._session.stop()
class Reader(Iceberg) :
def __init__(self,**_args):
super().__init__(**_args)
def read(self,**_args):
_table = self._table
_prefix = self._getPrefix(**_args)
if 'table' in _args or _table:
_table = _args['table'] if 'table' in _args else _table
_table = _prefix + f'.{_table}'
return self._session.table(_table).toPandas()
else:
sql = _args['sql']
return self._session.sql(sql).toPandas()
pass
class Writer (Iceberg):
"""
Writing data to an Apache Iceberg data warehouse (using pyspark)
"""
def __init__(self,**_args):
super().__init__(**_args)
self._mode = 'append' if 'mode' not in _args else _args['mode']
self._table = None if 'table' not in _args else _args['table']
def format (self,_schema) :
_iceSchema = StructType([])
_map = {'integer':IntegerType(),'float':DoubleType(),'double':DoubleType(),'date':DateType(),
'timestamp':TimestampType(),'datetime':TimestampType(),'string':StringType(),'varchar':StringType()}
for _item in _schema :
_name = _item['name']
_type = _item['type'].lower()
if _type not in _map :
_iceType = StringType()
else:
_iceType = _map[_type]
_iceSchema.add (StructField(_name,_iceType,True))
return _iceSchema if len(_iceSchema) else []
def write(self,_data,**_args):
_prefix = self._getPrefix(**_args)
if 'table' not in _args and not self._table :
raise Exception (f"Table Name should be specified for catalog/database {_prefix}")
_schema = self.format(_args['schema']) if 'schema' in _args else []
if not _schema :
rdd = self._session.createDataFrame(_data,verifySchema=False)
else :
rdd = self._session.createDataFrame(_data,schema=_schema,verifySchema=True)
_mode = self._mode if 'mode' not in _args else _args['mode']
_table = self._table if 'table' not in _args else _args['table']
# print (_data.shape,_mode,_table)
if not self._session.catalog.tableExists(_table):
# # @TODO:
# # add partitioning information here
rdd.writeTo(_table).using('iceberg').create()
# # _mode = 'overwrite'
# # rdd.write.format('iceberg').mode(_mode).saveAsTable(_table)
else:
# rdd.writeTo(_table).append()
# # _table = f'{_prefix}.{_table}'
rdd.coalesce(10).write.format('iceberg').mode('append').save(_table)
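Finally, a sketch of writing through the Iceberg wrapper with PySpark; the catalog 'demo' and database 'db' are hypothetical and must already exist in your Spark/Iceberg configuration:

import pandas as pd
from transport.warehouse.iceberg import Writer

w = Writer(catalog='demo', database='db', table='addresses', mode='append')
df = pd.DataFrame([{'name': 'main st', 'city': 'nashville'}])
# Passing a schema forces typed verification; omit it to let Spark infer types.
w.write(df, schema=[{'name': 'name', 'type': 'string'},
                    {'name': 'city', 'type': 'string'}])
w.close()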