Compare commits

...

2 Commits

@ -104,7 +104,7 @@ def generate (path:Annotated[str,typer.Argument(help="path of the ETL configurat
{ {
"source":{"provider":"http","url":"https://raw.githubusercontent.com/codeforamerica/ohana-api/master/data/sample-csv/addresses.csv"}, "source":{"provider":"http","url":"https://raw.githubusercontent.com/codeforamerica/ohana-api/master/data/sample-csv/addresses.csv"},
"target": "target":
[{"provider":"files","path":"addresses.csv","delimiter":","},{"provider":"sqlite","database":"sample.db3","table":"addresses"}] [{"provider":"files","path":"addresses.csv","delimiter":","},{"provider":"sqlite3","database":"sample.db3","table":"addresses"}]
} }
] ]
file = open(path,'w') file = open(path,'w')

@ -1,6 +1,6 @@
__app_name__ = 'data-transport' __app_name__ = 'data-transport'
__author__ = 'The Phi Technology' __author__ = 'The Phi Technology'
__version__= '2.4.0' __version__= '2.4.4'
__email__ = "info@the-phi.com" __email__ = "info@the-phi.com"
__license__=f""" __license__=f"""
Copyright 2010 - 2024, Steve L. Nyemba Copyright 2010 - 2024, Steve L. Nyemba

@ -38,11 +38,16 @@ def init():
if _provider_name.startswith('__') or _provider_name == 'common': if _provider_name.startswith('__') or _provider_name == 'common':
continue continue
PROVIDERS[_provider_name] = {'module':getattr(_module,_provider_name),'type':_module.__name__} PROVIDERS[_provider_name] = {'module':getattr(_module,_provider_name),'type':_module.__name__}
def _getauthfile (path) : #
f = open(path) # loading the registry
_object = json.loads(f.read()) if not registry.isloaded() :
f.close() registry.load()
return _object
# def _getauthfile (path) :
# f = open(path)
# _object = json.loads(f.read())
# f.close()
# return _object
def instance (**_args): def instance (**_args):
""" """
This function returns an object of to read or write from a supported database provider/vendor This function returns an object of to read or write from a supported database provider/vendor
@ -52,15 +57,6 @@ def instance (**_args):
kwargs These are arguments that are provider/vendor specific kwargs These are arguments that are provider/vendor specific
""" """
global PROVIDERS global PROVIDERS
# if not registry.isloaded () :
# if ('path' in _args and registry.exists(_args['path'] )) or registry.exists():
# registry.load() if 'path' not in _args else registry.load(_args['path'])
# print ([' GOT IT'])
# if 'label' in _args and registry.isloaded():
# _info = registry.get(_args['label'])
# if _info :
# #
# _args = dict(_args,**_info)
if 'auth_file' in _args: if 'auth_file' in _args:
if os.path.exists(_args['auth_file']) : if os.path.exists(_args['auth_file']) :
@ -87,8 +83,6 @@ def instance (**_args):
else: else:
_info = registry.get() _info = registry.get()
if _info : if _info :
#
# _args = dict(_args,**_info)
_args = dict(_info,**_args) #-- we can override the registry parameters with our own arguments _args = dict(_info,**_args) #-- we can override the registry parameters with our own arguments
if 'provider' in _args and _args['provider'] in PROVIDERS : if 'provider' in _args and _args['provider'] in PROVIDERS :
@ -119,9 +113,27 @@ def instance (**_args):
# for _delegate in _params : # for _delegate in _params :
# loader.set(_delegate) # loader.set(_delegate)
loader = None if 'plugins' not in _args else _args['plugins'] _plugins = None if 'plugins' not in _args else _args['plugins']
return IReader(_agent,loader) if _context == 'read' else IWriter(_agent,loader) # if registry.has('logger') :
# _kwa = registry.get('logger')
# _lmodule = getPROVIDERS[_kwa['provider']]
if ('label' not in _args and registry.has('logger')):
#
# We did not request label called logger, so we are setting up a logger if it is specified in the registry
#
_kwargs = registry.get('logger')
_kwargs['context'] = 'write'
_kwargs['table'] =_module.__name__.split('.')[-1]+'_logs'
# _logger = instance(**_kwargs)
_module = PROVIDERS[_kwargs['provider']]['module']
_logger = getattr(_module,'Writer')
_logger = _logger(**_kwargs)
else:
_logger = None
_datatransport = IReader(_agent,_plugins,_logger) if _context == 'read' else IWriter(_agent,_plugins,_logger)
return _datatransport
else: else:
# #
@ -138,7 +150,14 @@ class get :
if not _args or ('provider' not in _args and 'label' not in _args): if not _args or ('provider' not in _args and 'label' not in _args):
_args['label'] = 'default' _args['label'] = 'default'
_args['context'] = 'read' _args['context'] = 'read'
return instance(**_args) # return instance(**_args)
# _args['logger'] = instance(**{'label':'logger','context':'write','table':'logs'})
_handler = instance(**_args)
# _handler.setLogger(get.logger())
return _handler
@staticmethod @staticmethod
def writer(**_args): def writer(**_args):
""" """
@ -147,10 +166,26 @@ class get :
if not _args or ('provider' not in _args and 'label' not in _args): if not _args or ('provider' not in _args and 'label' not in _args):
_args['label'] = 'default' _args['label'] = 'default'
_args['context'] = 'write' _args['context'] = 'write'
return instance(**_args) # _args['logger'] = instance(**{'label':'logger','context':'write','table':'logs'})
_handler = instance(**_args)
#
# Implementing logging with the 'eat-your-own-dog-food' approach
# Using dependency injection to set the logger (problem with imports)
#
# _handler.setLogger(get.logger())
return _handler
@staticmethod
def logger ():
if registry.has('logger') :
_args = registry.get('logger')
_args['context'] = 'write'
return instance(**_args)
return None
@staticmethod @staticmethod
def etl (**_args): def etl (**_args):
if 'source' in _args and 'target' in _args : if 'source' in _args and 'target' in _args :
return IETL(**_args) return IETL(**_args)
else: else:
raise Exception ("Malformed input found, object must have both 'source' and 'target' attributes") raise Exception ("Malformed input found, object must have both 'source' and 'target' attributes")

@ -15,6 +15,7 @@ import time
MAX_CHUNK = 2000000 MAX_CHUNK = 2000000
class BigQuery: class BigQuery:
__template__= {"private_key":None,"dataset":None,"table":None}
def __init__(self,**_args): def __init__(self,**_args):
path = _args['service_key'] if 'service_key' in _args else _args['private_key'] path = _args['service_key'] if 'service_key' in _args else _args['private_key']
self.credentials = service_account.Credentials.from_service_account_file(path) self.credentials = service_account.Credentials.from_service_account_file(path)

@ -26,6 +26,7 @@ class Bricks:
:cluster_path :cluster_path
:table :table
""" """
__template__ = {"host":None,"token":None,"cluster_path":None,"catalog":None,"schema":None}
def __init__(self,**_args): def __init__(self,**_args):
_host = _args['host'] _host = _args['host']
_token= _args['token'] _token= _args['token']

@ -10,6 +10,7 @@ import json
import nextcloud_client as nextcloud import nextcloud_client as nextcloud
class Nextcloud : class Nextcloud :
__template__={"url":None,"token":None,"uid":None,"file":None}
def __init__(self,**_args): def __init__(self,**_args):
pass pass
self._delimiter = None self._delimiter = None

@ -24,6 +24,7 @@ class s3 :
""" """
@TODO: Implement a search function for a file given a bucket?? @TODO: Implement a search function for a file given a bucket??
""" """
__template__={"access_key":None,"secret_key":None,"bucket":None,"file":None,"region":None}
def __init__(self,**args) : def __init__(self,**args) :
""" """
This function will extract a file or set of files from s3 bucket provided This function will extract a file or set of files from s3 bucket provided

@ -8,24 +8,51 @@ NOTE: Plugins are converted to a pipeline, so we apply a pipeline when reading o
from transport.plugins import PluginLoader from transport.plugins import PluginLoader
import transport import transport
from transport import providers from transport import providers
from multiprocessing import Process from multiprocessing import Process, RLock
import time import time
import types import types
from . import registry from . import registry
from datetime import datetime
import pandas as pd
import os
import sys
import itertools
import json
class IO: class IO:
""" """
Base wrapper class for read/write and support for logs Base wrapper class for read/write and support for logs
""" """
def __init__(self,_agent,plugins): def __init__(self,_agent,plugins,_logger=None):
#
# We need to initialize the logger here ...
#
# registry.init()
self._logger = _logger if not type(_agent) in [IReader,IWriter] else _agent._logger #transport.get.writer(label='logger') #if registry.has('logger') else None
# if not _logger and hasattr(_agent,'_logger') :
# self._logger = getattr(_agent,'_logger')
self._agent = _agent self._agent = _agent
_date = _date = str(datetime.now())
self._logTable = 'logs' #'_'.join(['logs',_date[:10]+_date[11:19]]).replace(':','').replace('-','_')
if plugins : if plugins :
self._init_plugins(plugins) self._init_plugins(plugins)
else: else:
self._plugins = None self._plugins = None
def setLogger(self,_logger):
self._logger = _logger
def log (self,**_args):
if self._logger :
_date = str(datetime.now())
_data = dict({'pid':os.getpid(),'date':_date[:10],'time':_date[11:19]},**_args)
for key in _data :
_data[key] = str(_data[key]) if type(_data[key]) not in [list,dict] else json.dumps(_data[key])
self._logger.write(pd.DataFrame([_data])) #,table=self._logTable)
else:
print ([' ********** '])
print (_args)
def _init_plugins(self,_items): def _init_plugins(self,_items):
""" """
This function will load pipelined functions as a plugin loader This function will load pipelined functions as a plugin loader
@ -33,6 +60,7 @@ class IO:
registry.plugins.init() registry.plugins.init()
self._plugins = PluginLoader(registry=registry.plugins) self._plugins = PluginLoader(registry=registry.plugins)
[self._plugins.set(_name) for _name in _items] [self._plugins.set(_name) for _name in _items]
self.log(action='init-plugins',caller='read', input =[_name for _name in _items])
# if 'path' in _args and 'names' in _args : # if 'path' in _args and 'names' in _args :
# self._plugins = PluginLoader(**_args) # self._plugins = PluginLoader(**_args)
# else: # else:
@ -69,38 +97,74 @@ class IReader(IO):
""" """
This is a wrapper for read functionalities This is a wrapper for read functionalities
""" """
def __init__(self,_agent,pipeline=None): def __init__(self,_agent,_plugins=None,_logger=None):
super().__init__(_agent,pipeline) super().__init__(_agent,_plugins,_logger)
def _stream (self,_data ): def _stream (self,_data ):
# self.log(action='streaming',object=self._agent._engine.name, input= type(_data).__name__)
_shape = []
for _segment in _data : for _segment in _data :
_shape.append(list(_segment.shape))
yield self._plugins.apply(_segment,self.log)
self.log(action='streaming',object=self._agent._engine.name, input= {'shape':_shape})
yield self._plugins.apply(_segment)
def read(self,**_args): def read(self,**_args):
if 'plugins' in _args :
self._init_plugins(_args['plugins'])
_data = self._agent.read(**_args)
if self._plugins and self._plugins.ratio() > 0 : if 'plugins' in _args :
if types.GeneratorType == type(_data): self._init_plugins(_args['plugins'])
return self._stream(_data) _data = self._agent.read(**_args)
_objectName = '.'.join([self._agent.__class__.__module__,self._agent.__class__.__name__])
if types.GeneratorType == type(_data):
if self._plugins :
return self._stream(_data)
else: else:
_data = self._plugins.apply(_data) _count = 0
return _data for _segment in _data :
_count += 1
yield _segment
self.log(action='streaming',object=_objectName, input= {'segments':_count})
# return _data
else: else:
self.log(action='read',object=_objectName, input=_data.shape)
if self._plugins :
_logs = []
_data = self._plugins.apply(_data,self.log)
return _data return _data
# if self._plugins and self._plugins.ratio() > 0 :
# if types.GeneratorType == type(_data):
# return self._stream(_data)
# else:
# _data = self._plugins.apply(_data)
# return _data
# else:
# self.log(action='read',object=self._agent._engine.name, input=_data.shape)
# return _data
class IWriter(IO): class IWriter(IO):
def __init__(self,_agent,pipeline=None): lock = RLock()
super().__init__(_agent,pipeline) def __init__(self,_agent,pipeline=None,_logger=None):
super().__init__(_agent,pipeline,_logger)
def write(self,_data,**_args): def write(self,_data,**_args):
if 'plugins' in _args : if 'plugins' in _args :
self._init_plugins(_args['plugins']) self._init_plugins(_args['plugins'])
if self._plugins and self._plugins.ratio() > 0 : if self._plugins and self._plugins.ratio() > 0 :
_data = self._plugins.apply(_data) _logs = []
_data = self._plugins.apply(_data,_logs,self.log)
self._agent.write(_data,**_args) # [self.log(**_item) for _item in _logs]
try:
# IWriter.lock.acquire()
self._agent.write(_data,**_args)
finally:
# IWriter.lock.release()
pass
# #
# The ETL object in its simplest form is an aggregation of read/write objects # The ETL object in its simplest form is an aggregation of read/write objects
@ -111,8 +175,13 @@ class IETL(IReader) :
This class performs an ETL operation by ineriting a read and adding writes as pipeline functions This class performs an ETL operation by ineriting a read and adding writes as pipeline functions
""" """
def __init__(self,**_args): def __init__(self,**_args):
super().__init__(transport.get.reader(**_args['source'])) _source = _args['source']
_plugins = _source['plugins'] if 'plugins' in _source else None
# super().__init__(transport.get.reader(**_args['source']))
super().__init__(transport.get.reader(**_source),_plugins)
# _logger =
if 'target' in _args: if 'target' in _args:
self._targets = _args['target'] if type(_args['target']) == list else [_args['target']] self._targets = _args['target'] if type(_args['target']) == list else [_args['target']]
else: else:
@ -121,25 +190,25 @@ class IETL(IReader) :
# #
# If the parent is already multiprocessing # If the parent is already multiprocessing
self._hasParentProcess = False if 'hasParentProcess' not in _args else _args['hasParentProcess'] self._hasParentProcess = False if 'hasParentProcess' not in _args else _args['hasParentProcess']
def run(self) : # def run(self) :
""" # """
We should apply the etl here, if we are in multiprocessing mode # We should apply the etl here, if we are in multiprocessing mode
""" # """
_data = super().read() # return self.read()
for _kwargs in self._targets : def run(self,**_args):
self.post(_data,**_kwargs)
def read(self,**_args):
_data = super().read(**_args) _data = super().read(**_args)
self._targets = [transport.get.writer(**_kwargs) for _kwargs in self._targets]
if types.GeneratorType == type(_data): if types.GeneratorType == type(_data):
_index = 0
for _segment in _data : for _segment in _data :
for _kwars in self._targets : _index += 1
self.post(_segment,**_kwargs) for _writer in self._targets :
self.post(_segment,writer=_writer,index=_index)
time.sleep(1)
else: else:
for _writer in self._targets :
for _kwargs in self._targets : self.post(_data,writer=_writer)
self.post(_data,**_kwargs)
return _data return _data
# return _data # return _data
@ -148,6 +217,19 @@ class IETL(IReader) :
This function returns an instance of a process that will perform the write operation This function returns an instance of a process that will perform the write operation
:_args parameters associated with writer object :_args parameters associated with writer object
""" """
writer = transport.get.writer(**_args) #writer = transport.get.writer(**_args)
writer.write(_data)
writer.close() try:
_action = 'post'
_shape = dict(zip(['rows','columns'],_data.shape))
_index = _args['index'] if 'index' in _args else 0
writer = _args['writer']
writer.write(_data)
except Exception as e:
_action = 'post-error'
print (e)
pass
self.log(action=_action,object=writer._agent.__module__, input= {'shape':_shape,'segment':_index})

@ -19,6 +19,7 @@ class Couch:
@param doc user id involved @param doc user id involved
@param dbname database name (target) @param dbname database name (target)
""" """
__template__={"url":None,"doc":None,"dbname":None,"username":None,"password":None}
def __init__(self,**args): def __init__(self,**args):
url = args['url'] if 'url' in args else 'http://localhost:5984' url = args['url'] if 'url' in args else 'http://localhost:5984'
self._id = args['doc'] self._id = args['doc']

@ -25,6 +25,7 @@ class Mongo :
""" """
Basic mongodb functions are captured here Basic mongodb functions are captured here
""" """
__template__={"db":None,"collection":None,"host":None,"port":None,"username":None,"password":None}
def __init__(self,**args): def __init__(self,**args):
""" """
:dbname database name/identifier :dbname database name/identifier

@ -27,14 +27,14 @@ class Reader (File):
def __init__(self,**_args): def __init__(self,**_args):
super().__init__(**_args) super().__init__(**_args)
def _stream(self,path) : def _stream(self,path) :
reader = pd.read_csv(path,delimiter=self.delimiter,chunksize=self._chunksize) reader = pd.read_csv(path,sep=self.delimiter,chunksize=self._chunksize,low_memory=False)
for segment in reader : for segment in reader :
yield segment yield segment
def read(self,**args): def read(self,**args):
_path = self.path if 'path' not in args else args['path'] _path = self.path if 'path' not in args else args['path']
_delimiter = self.delimiter if 'delimiter' not in args else args['delimiter'] _delimiter = self.delimiter if 'delimiter' not in args else args['delimiter']
return pd.read_csv(_path,delimiter=self.delimiter) if not self._chunksize else self._stream(_path) return pd.read_csv(_path,sep=self.delimiter) if not self._chunksize else self._stream(_path)
def stream(self,**args): def stream(self,**args):
raise Exception ("streaming needs to be implemented") raise Exception ("streaming needs to be implemented")
class Writer (File): class Writer (File):

@ -11,6 +11,8 @@ import importlib as IL
import importlib.util import importlib.util
import sys import sys
import os import os
import pandas as pd
class Plugin : class Plugin :
""" """
@ -54,26 +56,7 @@ class PluginLoader :
self._modules = {} self._modules = {}
self._names = [] self._names = []
self._registry = _args['registry'] self._registry = _args['registry']
# if path and os.path.exists(path) and _names:
# for _name in self._names :
# spec = importlib.util.spec_from_file_location('private', path)
# module = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(module) #--loads it into sys.modules
# if hasattr(module,_name) :
# if self.isplugin(module,_name) :
# self._modules[_name] = getattr(module,_name)
# else:
# print ([f'Found {_name}', 'not plugin'])
# else:
# #
# # @TODO: We should log this somewhere some how
# print (['skipping ',_name, hasattr(module,_name)])
# pass
# else:
# #
# # Initialization is empty
# self._names = []
pass pass
def load (self,**_args): def load (self,**_args):
self._modules = {} self._modules = {}
@ -84,7 +67,6 @@ class PluginLoader :
spec = importlib.util.spec_from_file_location(_alias, path) spec = importlib.util.spec_from_file_location(_alias, path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) #--loads it into sys.modules spec.loader.exec_module(module) #--loads it into sys.modules
# self._names = [_name for _name in dir(module) if type(getattr(module,_name)).__name__ == 'function']
for _name in dir(module) : for _name in dir(module) :
if self.isplugin(module,_name) : if self.isplugin(module,_name) :
self._module[_name] = getattr(module,_name) self._module[_name] = getattr(module,_name)
@ -97,11 +79,6 @@ class PluginLoader :
This function will set a pointer to the list of modules to be called This function will set a pointer to the list of modules to be called
This should be used within the context of using the framework as a library This should be used within the context of using the framework as a library
""" """
# _name = _pointer.__name__ if type(_pointer).__name__ == 'function' else {}
# self._modules[_name] = _pointer
# self._names.append(_name)
_pointer = self._registry.get(key=_key) _pointer = self._registry.get(key=_key)
if _pointer : if _pointer :
self._modules[_key] = _pointer self._modules[_key] = _pointer
@ -137,12 +114,28 @@ class PluginLoader :
_n = len(self._names) _n = len(self._names)
return len(set(self._modules.keys()) & set (self._names)) / _n return len(set(self._modules.keys()) & set (self._names)) / _n
def apply(self,_data): def apply(self,_data,_logger=[]):
_input= {}
for _name in self._modules : for _name in self._modules :
_pointer = self._modules[_name] try:
# _input = {'action':'plugin','object':_name,'input':{'status':'PASS'}}
# @TODO: add exception handling _pointer = self._modules[_name]
_data = _pointer(_data) if type(_data) == list :
_data = pd.DataFrame(_data)
_brow,_bcol = list(_data.shape)
#
# @TODO: add exception handling
_data = _pointer(_data)
_input['input']['shape'] = {'dropped':{'rows':_brow - _data.shape[0],'cols':_bcol-_data.shape[1]}}
except Exception as e:
_input['input']['status'] = 'FAILED'
print (e)
if _logger:
_logger(_input)
return _data return _data
# def apply(self,_data,_name): # def apply(self,_data,_name):
# """ # """

@ -220,6 +220,8 @@ def init (email,path=REGISTRY_PATH,override=False,_file=REGISTRY_FILE):
def lookup (label): def lookup (label):
global DATA global DATA
return label in DATA return label in DATA
has = lookup
def get (label='default') : def get (label='default') :
global DATA global DATA
return copy.copy(DATA[label]) if label in DATA else {} return copy.copy(DATA[label]) if label in DATA else {}

@ -7,7 +7,9 @@ from sqlalchemy import text
import pandas as pd import pandas as pd
class Base: class Base:
__template__={"host":None,"port":1,"database":None,"table":None,"username":None,"password":None}
def __init__(self,**_args): def __init__(self,**_args):
# print ([' ## ',_args]) # print ([' ## ',_args])
self._host = _args['host'] if 'host' in _args else 'localhost' self._host = _args['host'] if 'host' in _args else 'localhost'
@ -122,23 +124,21 @@ class BaseWriter (SQLBase):
def __init__(self,**_args): def __init__(self,**_args):
super().__init__(**_args) super().__init__(**_args)
def write(self,_data,**_args): def write(self,_data,**_args):
if type(_data) == dict : if type(_data) in [list,dict] :
_df = pd.DataFrame(_data)
elif type(_data) == list :
_df = pd.DataFrame(_data) _df = pd.DataFrame(_data)
# elif type(_data) == list :
# _df = pd.DataFrame(_data)
else: else:
_df = _data.copy() _df = _data.copy()
# #
# We are assuming we have a data-frame at this point # We are assuming we have a data-frame at this point
# #
_table = _args['table'] if 'table' in _args else self._table _table = _args['table'] if 'table' in _args else self._table
_mode = {'chunksize':2000000,'if_exists':'append','index':False} _mode = {'if_exists':'append','index':False}
if self._chunksize :
_mode['chunksize'] = self._chunksize
for key in ['if_exists','index','chunksize'] : for key in ['if_exists','index','chunksize'] :
if key in _args : if key in _args :
_mode[key] = _args[key] _mode[key] = _args[key]
# if 'schema' in _args :
# _mode['schema'] = _args['schema']
# if 'if_exists' in _args :
# _mode['if_exists'] = _args['if_exists']
_df.to_sql(_table,self._engine,**_mode) _df.to_sql(_table,self._engine,**_mode)

@ -2,8 +2,9 @@
This module implements the handler for duckdb (in memory or not) This module implements the handler for duckdb (in memory or not)
""" """
from transport.sql.common import Base, BaseReader, BaseWriter from transport.sql.common import Base, BaseReader, BaseWriter
from multiprocessing import RLock
class Duck : class Duck :
lock = RLock()
def __init__(self,**_args): def __init__(self,**_args):
# #
# duckdb with none as database will operate as an in-memory database # duckdb with none as database will operate as an in-memory database
@ -22,3 +23,4 @@ class Writer(Duck,BaseWriter):
def __init__(self,**_args): def __init__(self,**_args):
Duck.__init__(self,**_args) Duck.__init__(self,**_args)
BaseWriter.__init__(self,**_args) BaseWriter.__init__(self,**_args)

@ -1,7 +1,9 @@
import sqlalchemy import sqlalchemy
import pandas as pd import pandas as pd
from transport.sql.common import Base, BaseReader, BaseWriter from transport.sql.common import Base, BaseReader, BaseWriter
class SQLite3 (BaseReader): from multiprocessing import RLock
class SQLite3 :
lock = RLock()
def __init__(self,**_args): def __init__(self,**_args):
super().__init__(**_args) super().__init__(**_args)
if 'path' in _args : if 'path' in _args :
@ -23,3 +25,9 @@ class Reader(SQLite3,BaseReader):
class Writer (SQLite3,BaseWriter): class Writer (SQLite3,BaseWriter):
def __init__(self,**_args): def __init__(self,**_args):
super().__init__(**_args) super().__init__(**_args)
def write(self,_data,**_kwargs):
try:
SQLite3.lock.acquire()
super().write(_data,**_kwargs)
finally:
SQLite3.lock.release()

@ -7,6 +7,7 @@ from transport.sql.common import Base, BaseReader, BaseWriter
class MsSQLServer: class MsSQLServer:
def __init__(self,**_args) : def __init__(self,**_args) :
super().__init__(**_args) super().__init__(**_args)
pass pass

@ -51,7 +51,6 @@ class Iceberg :
_schema = [] _schema = []
try: try:
_tableName = self._getPrefix(**_args) + f".{_args['table']}" _tableName = self._getPrefix(**_args) + f".{_args['table']}"
print (_tableName)
_tmp = self._session.table(_tableName).schema _tmp = self._session.table(_tableName).schema
_schema = _tmp.jsonValue()['fields'] _schema = _tmp.jsonValue()['fields']
for _item in _schema : for _item in _schema :
@ -77,6 +76,8 @@ class Iceberg :
return False return False
def apply(self,sql): def apply(self,sql):
pass pass
def close(self):
self._session.stop()
class Reader(Iceberg) : class Reader(Iceberg) :
def __init__(self,**_args): def __init__(self,**_args):
super().__init__(**_args) super().__init__(**_args)
@ -103,13 +104,20 @@ class Writer (Iceberg):
_prefix = self._getPrefix(**_args) _prefix = self._getPrefix(**_args)
if 'table' not in _args and not self._table : if 'table' not in _args and not self._table :
raise Exception (f"Table Name should be specified for catalog/database {_prefix}") raise Exception (f"Table Name should be specified for catalog/database {_prefix}")
rdd = self._session.createDataFrame(_data)
rdd = self._session.createDataFrame(_data,verifySchema=False)
_mode = self._mode if 'mode' not in _args else _args['mode'] _mode = self._mode if 'mode' not in _args else _args['mode']
_table = self._table if 'table' not in _args else _args['table'] _table = self._table if 'table' not in _args else _args['table']
# print (_data.shape,_mode,_table)
if not self._session.catalog.tableExists(_table):
# # @TODO:
# # add partitioning information here
rdd.writeTo(_table).using('iceberg').create()
if not self.has(table=_table) : # # _mode = 'overwrite'
_mode = 'overwrite' # # rdd.write.format('iceberg').mode(_mode).saveAsTable(_table)
rdd.write.format('iceberg').mode(_mode).saveAsTable(_table)
else: else:
_table = f'{_prefix}.{_table}' # rdd.writeTo(_table).append()
rdd.write.format('iceberg').mode(_mode).save(_table) # # _table = f'{_prefix}.{_table}'
rdd.write.format('iceberg').mode('append').save(_table)

Loading…
Cancel
Save