@@ -12,6 +12,8 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
import psycopg2 as pg
import mysql . connector as my
import sys
import sqlalchemy
if sys . version_info [ 0 ] > 2 :
from transport . common import Reader , Writer #, factory
else :
@@ -44,7 +46,8 @@ class SQLRW :
_info [ ' dbname ' ] = _args [ ' db ' ] if ' db ' in _args else _args [ ' database ' ]
self . table = _args [ ' table ' ] if ' table ' in _args else None
self . fields = _args [ ' fields ' ] if ' fields ' in _args else [ ]
# _provider = _args['provider']
self . _provider = _args [ ' provider ' ] if ' provider ' in _args else None
# _info['host'] = 'localhost' if 'host' not in _args else _args['host']
# _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
@@ -59,7 +62,7 @@ class SQLRW :
if ' username ' in _args or ' user ' in _args :
key = ' username ' if ' username ' in _args else ' user '
_info [ ' user ' ] = _args [ key ]
_info [ ' password ' ] = _args [ ' password ' ]
_info [ ' password ' ] = _args [ ' password ' ] if ' password ' in _args else ' '
#
# We need to load the drivers here to see what we are dealing with ...
@@ -74,17 +77,29 @@ class SQLRW :
_info [ ' database ' ] = _info [ ' dbname ' ]
_info [ ' securityLevel ' ] = 0
del _info [ ' dbname ' ]
if _handler == my :
_info [ ' database ' ] = _info [ ' dbname ' ]
del _info [ ' dbname ' ]
self . conn = _handler . connect ( * * _info )
self . _engine = _args [ ' sqlalchemy ' ] if ' sqlalchemy ' in _args else None
def has(self, **_args):
    """
    Return True if the table named in _args['table'] exists and can be read,
    False otherwise (best-effort probe; all errors are swallowed).

    :param table:   name of the table to probe (required)
    """
    found = False
    _conn = None  # pre-bind so the finally clause is safe even if connect() fails
    try:
        table = _args['table']
        sql = "SELECT * FROM :table LIMIT 1".replace(":table", table)
        if self._engine:
            # prefer the sqlalchemy engine when one was configured
            _conn = self._engine.connect()
        else:
            _conn = self.conn
        # NOTE: a stale duplicate read against self.conn was removed here —
        # the probe must go through the engine connection when one exists
        pd.read_sql(sql, _conn)
        found = True
    except Exception as e:
        # best-effort: any failure (missing table, bad connection) means "not found"
        pass
    finally:
        if self._engine and _conn is not None:
            _conn.close()
    return found
def isready ( self ) :
_sql = " SELECT * FROM :table LIMIT 1 " . replace ( " :table " , self . table )
@@ -104,7 +119,8 @@ class SQLRW :
try :
if " select " in _sql . lower ( ) :
cursor . close ( )
return pd . read_sql ( _sql , self . conn )
_conn = self . _engine . connect ( ) if self . _engine else self . conn
return pd . read_sql ( _sql , _conn )
else :
# Executing a command i.e no expected return values ...
cursor . execute ( _sql )
@@ -122,7 +138,8 @@ class SQLRW :
pass
class SQLReader ( SQLRW , Reader ) :
def __init__(self, **_args):
    """
    Forward all keyword configuration to the SQLRW base class.
    A duplicated super().__init__ call (merge residue) was removed so the
    base initializer — which opens a database connection — runs exactly once.
    """
    super().__init__(**_args)
def read ( self , * * _args ) :
if ' sql ' in _args :
_sql = ( _args [ ' sql ' ] )
@@ -151,27 +168,47 @@ class SQLWriter(SQLRW,Writer):
# NOTE: Proper data type should be set on the target system if their source is unclear.
self . _inspect = False if ' inspect ' not in _args else _args [ ' inspect ' ]
self . _cast = False if ' cast ' not in _args else _args [ ' cast ' ]
def init(self, fields=None):
    """
    Initialize self.fields: either from the caller-provided list, or by
    probing one row of self.table to discover its column names.

    :param fields:  optional explicit list of field names; when falsy the
                    field names are read from the target table instead
    """
    if not fields:
        # Probe the table for its columns; the original wrapped this in a
        # no-op try/finally, so errors still propagate to the caller.
        _sql = "SELECT * FROM :table LIMIT 1".replace(":table", self.table)
        self.fields = pd.read_sql_query(_sql, self.conn).columns.tolist()
    else:
        self.fields = fields
def make(self, **_args):
    """
    Create a table on the current connection.

    Two calling conventions are supported:
      - fields: list of column names; each column gets the writer's default
        dtype (self._dtype) and the table name is self.table
      - schema: list of {'name':..., 'type':...} dicts with an explicit
        'table' name; an optional 'map' dict translates source types to
        target-database types

    Errors are reported (printed) but not raised; the cursor is always
    closed and the connection committed.

    NOTE: the stale pre-refactor `def make(self, fields)` overload was
    removed — it shadowed this version and its generated SQL was discarded.
    """
    if 'fields' in _args:
        fields = _args['fields']
        sql = " ".join(["CREATE TABLE", self.table, "(",
                        ",".join([name + ' ' + self._dtype for name in fields]), ")"])
    else:
        schema = _args['schema']
        _map = _args['map'] if 'map' in _args else {}
        _columns = []
        for _item in schema:
            _type = _item['type']
            if _type in _map:
                # translate the declared type to the target database's type
                _type = _map[_type]
            _columns.append(" ".join([_item['name'], _type]))
        sql = " ".join(["CREATE TABLE", _args['table'], "(", ",".join(_columns), ")"])
    cursor = self.conn.cursor()
    try:
        cursor.execute(sql)
    except Exception as e:
        # best-effort: surface the failure and the offending statement
        print(e)
        print(sql)
    finally:
        cursor.close()
    self.conn.commit()
def write ( self , info ) :
"""
: param info writes a list of data to a given set of fields
@@ -184,7 +221,7 @@ class SQLWriter(SQLRW,Writer):
elif type ( info ) == dict :
_fields = info . keys ( )
elif type ( info ) == pd . DataFrame :
_fields = info . columns
_fields = info . columns . tolist ( )
# _fields = info.keys() if type(info) == dict else info[0].keys()
_fields = list ( _fields )
@@ -192,12 +229,13 @@ class SQLWriter(SQLRW,Writer):
#
# @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
#
if type ( info ) != list :
#
# We are assuming 2 cases i.e dict or pd.DataFrame
info = [ info ] if type ( info ) == dict else info . values . tolist ( )
# if type(info) != list :
# #
# # We are assuming 2 cases i.e dict or pd.DataFrame
# info = [info] if type(info) == dict else info.values.tolist()
cursor = self . conn . cursor ( )
try :
_sql = " INSERT INTO :table (:fields) VALUES (:values) " . replace ( " :table " , self . table ) #.replace(":table",self.table).replace(":fields",_fields)
if self . _inspect :
for _row in info :
@@ -223,34 +261,49 @@ class SQLWriter(SQLRW,Writer):
pass
else :
_fields = " , " . join ( self . fields )
# _sql = _sql.replace(":fields",_fields)
# _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
# _sql = _sql.replace("(:fields)","")
_sql = _sql . replace ( " :fields " , _fields )
values = " , " . join ( " ? " * len ( self . fields ) ) if self . _provider == ' netezza ' else " , " . join ( [ " %s " for name in self . fields ] )
_sql = _sql . replace ( " :values " , values )
if type ( info ) == pd . DataFrame :
_info = info [ self . fields ] . values . tolist ( )
elif type ( info ) == dict :
_info = info . values ( )
else :
# _info = []
# _sql = _sql.replace(":values",values )
# if type(info) == pd.DataFrame :
# _info = info[self.fields].values.tolist()
# elif type(info) == dict :
# _info = info.values()
# else :
# # _info = []
_info = pd . DataFrame ( info ) [ self . fields ] . values . tolist ( )
# for row in info :
# if type(row) == dict :
# _info.append( list(row.values()))
cursor . executemany ( _sql , _info )
# _info = pd.DataFrame(info)[self.fields].values.tolist()
# _info = pd.DataFrame(info).to_dict(orient='records')
if type ( info ) == list :
_info = pd . DataFrame ( info )
elif type ( info ) == dict :
_info = pd . DataFrame ( [ info ] )
else :
_info = pd . DataFrame ( info )
if self . _engine :
# pd.to_sql(_info,self._engine)
_info . to_sql ( self . table , self . _engine , if_exists = ' append ' , index = False )
else :
_fields = " , " . join ( self . fields )
_sql = _sql . replace ( " :fields " , _fields )
values = " , " . join ( " ? " * len ( self . fields ) ) if self . _provider == ' netezza ' else " , " . join ( [ " %s " for name in self . fields ] )
_sql = _sql . replace ( " :values " , values )
cursor . executemany ( _sql , _info . values . tolist ( ) )
# cursor.commit()
# self.conn.commit()
except Exception as e :
print ( e )
pass
finally :
self . conn . commit ( )
cursor . close ( )
self . conn . commit ( )
# cursor.close( )
pass
def close ( self ) :
try :
@@ -265,6 +318,7 @@ class BigQuery:
self . path = path
self . dtypes = _args [ ' dtypes ' ] if ' dtypes ' in _args else None
self . table = _args [ ' table ' ] if ' table ' in _args else None
self . client = bq . Client . from_service_account_json ( self . path )
def meta ( self , * * _args ) :
"""
This function returns meta data for a given table or query with dataset / table properly formatted
@@ -272,16 +326,16 @@ class BigQuery:
: param sql sql query to be pulled ,
"""
table = _args [ ' table ' ]
client = bq . Client . from_service_account_json ( self . path )
ref = client . dataset ( self . dataset ) . table ( table )
return client . get_table ( ref ) . schema
ref = self . client . dataset ( self . dataset ) . table ( table )
return self . client . get_table ( ref ) . schema
def has(self, **_args):
    """
    Return True if metadata for the table described by _args can be fetched
    (i.e. the table exists), False otherwise. Errors from the metadata
    lookup are treated as "not found".
    A duplicated `return found` line (merge residue) was removed.
    """
    found = False
    try:
        found = self.meta(**_args) is not None
    except Exception as e:
        # missing table / credential failure → report as not found
        pass
    return found
class BQReader ( BigQuery , Reader ) :
def __init__ ( self , * * _args ) :
@@ -304,8 +358,9 @@ class BQReader(BigQuery,Reader) :
if ( ' :dataset ' in SQL or ' :DATASET ' in SQL ) and self . dataset :
SQL = SQL . replace ( ' :dataset ' , self . dataset ) . replace ( ' :DATASET ' , self . dataset )
_info = { ' credentials ' : self . credentials , ' dialect ' : ' standard ' }
return pd . read_gbq ( SQL , * * _info ) if SQL else None
# return pd.read_gbq(SQL,credentials=self.credentials,dialect='standard') if SQL else None
return pd . read_gbq ( SQL , * * _info ) if SQL else None
# return self.client.query(SQL).to_dataframe() if SQL else None
class BQWriter ( BigQuery , Writer ) :
lock = Lock ( )