From 14a551e57b53e9154ceda3ce16985d76249704f4 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Thu, 3 Mar 2022 16:08:24 -0600 Subject: [PATCH] Bug fix: sqlalchemy facilities added --- transport/__init__.py | 39 ++++++++++-- transport/sql.py | 135 +++++++++++++++++++++++++++++------------- 2 files changed, 130 insertions(+), 44 deletions(-) diff --git a/transport/__init__.py b/transport/__init__.py index 94b01eb..ce5090b 100644 --- a/transport/__init__.py +++ b/transport/__init__.py @@ -26,7 +26,7 @@ import numpy as np import json import importlib import sys - +import sqlalchemy if sys.version_info[0] > 2 : from transport.common import Reader, Writer #, factory from transport import disk @@ -59,8 +59,8 @@ class factory : "postgresql":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}}, "redshift":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}}, "bigquery":{"class":{"read":sql.BQReader,"write":sql.BQWriter}}, - "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}}, - "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}}, + "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my}, + "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my}, "mongo":{"port":27017,"host":"localhost","class":{"read":mongo.MongoReader,"write":mongo.MongoWriter}}, "couch":{"port":5984,"host":"localhost","class":{"read":couch.CouchReader,"write":couch.CouchWriter}}, "netezza":{"port":5480,"driver":nz,"default":{"type":"VARCHAR(256)"}}} @@ -137,7 +137,38 @@ def instance(**_args): pointer = factory.PROVIDERS[provider]['class'][_id] else: pointer = sql.SQLReader if _id == 'read' else sql.SQLWriter - + # + # Let us try to establish an sqlalchemy wrapper + try: + host = '' + if provider not in ['bigquery','mongodb','couchdb','sqlite'] : + # + # In these cases we are assuming RDBMS and thus would exclude NoSQL and BigQuery + username = args['username'] if 'username' in args else '' + password = args['password'] if 'password' in args else '' + if username == '' : + account = '' + else: + account = username + ':'+password+'@' + host = args['host'] + if 'port' in args : + host = host+":"+str(args['port']) + + database = args['database'] + elif provider == 'sqlite': + account = '' + host = '' + database = args['path'] if 'path' in args else args['database'] + if provider not in ['mongodb','couchdb','bigquery'] : + uri = ''.join([provider,"://",account,host,'/',database]) + + e = sqlalchemy.create_engine (uri) + args['sqlalchemy'] = e + # + # @TODO: Include handling of bigquery with SQLAlchemy + except Exception as e: + print (e) + return pointer(**args) return None diff --git a/transport/sql.py b/transport/sql.py index 48d7777..9ccccdb 100644 --- a/transport/sql.py +++ b/transport/sql.py @@ -12,6 +12,8 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI import psycopg2 as pg import mysql.connector as my import sys + +import sqlalchemy if sys.version_info[0] > 2 : from transport.common import Reader, Writer #, factory else: @@ -44,7 +46,8 @@ class SQLRW : _info['dbname'] = _args['db'] if 'db' in _args else _args['database'] self.table = _args['table'] if 'table' in _args else None self.fields = _args['fields'] if 'fields' in _args else [] - # _provider = _args['provider'] + + self._provider = _args['provider'] if 'provider' in _args else None # _info['host'] = 'localhost' if 'host' not in _args else _args['host'] # _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port'] @@ -59,7 +62,7 @@ class SQLRW : if 'username' in _args or 'user' in _args: key = 'username' if 'username' in _args else 'user' _info['user'] = _args[key] - _info['password'] = _args['password'] + _info['password'] = _args['password'] if 'password' in _args else '' # # We need to load the drivers here to see what we are dealing with ... @@ -74,17 +77,29 @@ class SQLRW : _info['database'] = _info['dbname'] _info['securityLevel'] = 0 del _info['dbname'] + if _handler == my : + _info['database'] = _info['dbname'] + del _info['dbname'] + self.conn = _handler.connect(**_info) + self._engine = _args['sqlalchemy'] if 'sqlalchemy' in _args else None def has(self,**_args): found = False try: table = _args['table'] sql = "SELECT * FROM :table LIMIT 1".replace(":table",table) - found = pd.read_sql(sql,self.conn).shape[0] + if self._engine : + _conn = self._engine.connect() + else: + _conn = self.conn + found = pd.read_sql(sql,_conn).shape[0] found = True except Exception as e: pass + finally: + if self._engine : + _conn.close() return found def isready(self): _sql = "SELECT * FROM :table LIMIT 1".replace(":table",self.table) @@ -104,7 +119,8 @@ class SQLRW : try: if "select" in _sql.lower() : cursor.close() - return pd.read_sql(_sql,self.conn) + _conn = self._engine.connect() if self._engine else self.conn + return pd.read_sql(_sql,_conn) else: # Executing a command i.e no expected return values ... cursor.execute(_sql) @@ -122,7 +138,8 @@ class SQLRW : pass class SQLReader(SQLRW,Reader) : def __init__(self,**_args): - super().__init__(**_args) + super().__init__(**_args) + def read(self,**_args): if 'sql' in _args : _sql = (_args['sql']) @@ -151,27 +168,47 @@ class SQLWriter(SQLRW,Writer): # NOTE: Proper data type should be set on the target system if their source is unclear. self._inspect = False if 'inspect' not in _args else _args['inspect'] self._cast = False if 'cast' not in _args else _args['cast'] + def init(self,fields=None): if not fields : try: - self.fields = pd.read_sql("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist() + self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist() finally: pass else: self.fields = fields; - def make(self,fields): - self.fields = fields - - sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"]) + def make(self,**_args): + + if 'fields' in _args : + fields = _args['fields'] + sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"]) + else: + schema = _args['schema'] + N = len(schema) + _map = _args['map'] if 'map' in _args else {} + sql = [] # ["CREATE TABLE ",_args['table'],"("] + for _item in schema : + _type = _item['type'] + if _type in _map : + _type = _map[_type] + sql = sql + [" " .join([_item['name'], ' ',_type])] + sql = ",".join(sql) + sql = ["CREATE TABLE ",_args['table'],"( ",sql," )"] + sql = " ".join(sql) + # sql = " ".join(["CREATE TABLE",_args['table']," (", ",".join([ schema[i]['name'] +' '+ (schema[i]['type'] if schema[i]['type'] not in _map else _map[schema[i]['type'] ]) for i in range(0,N)]),")"]) cursor = self.conn.cursor() try: + cursor.execute(sql) except Exception as e : print (e) + print (sql) pass finally: - cursor.close() + # cursor.close() + self.conn.commit() + pass def write(self,info): """ :param info writes a list of data to a given set of fields @@ -184,7 +221,7 @@ class SQLWriter(SQLRW,Writer): elif type(info) == dict : _fields = info.keys() elif type(info) == pd.DataFrame : - _fields = info.columns + _fields = info.columns.tolist() # _fields = info.keys() if type(info) == dict else info[0].keys() _fields = list (_fields) @@ -192,12 +229,13 @@ class SQLWriter(SQLRW,Writer): # # @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy # - if type(info) != list : - # - # We are assuming 2 cases i.e dict or pd.DataFrame - info = [info] if type(info) == dict else info.values.tolist() + # if type(info) != list : + # # + # # We are assuming 2 cases i.e dict or pd.DataFrame + # info = [info] if type(info) == dict else info.values.tolist() cursor = self.conn.cursor() try: + _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields) if self._inspect : for _row in info : @@ -223,34 +261,49 @@ class SQLWriter(SQLRW,Writer): pass else: - _fields = ",".join(self.fields) + # _sql = _sql.replace(":fields",_fields) # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields])) # _sql = _sql.replace("(:fields)","") - _sql = _sql.replace(":fields",_fields) - values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields]) - _sql = _sql.replace(":values",values) - if type(info) == pd.DataFrame : - _info = info[self.fields].values.tolist() - elif type(info) == dict : - _info = info.values() - else: - # _info = [] + + # _sql = _sql.replace(":values",values) + # if type(info) == pd.DataFrame : + # _info = info[self.fields].values.tolist() + + # elif type(info) == dict : + # _info = info.values() + # else: + # # _info = [] - _info = pd.DataFrame(info)[self.fields].values.tolist() - # for row in info : - - # if type(row) == dict : - # _info.append( list(row.values())) - cursor.executemany(_sql,_info) + # _info = pd.DataFrame(info)[self.fields].values.tolist() + # _info = pd.DataFrame(info).to_dict(orient='records') + if type(info) == list : + _info = pd.DataFrame(info) + elif type(info) == dict : + _info = pd.DataFrame([info]) + else: + _info = pd.DataFrame(info) + + + if self._engine : + # pd.to_sql(_info,self._engine) + _info.to_sql(self.table,self._engine,if_exists='append',index=False) + else: + _fields = ",".join(self.fields) + _sql = _sql.replace(":fields",_fields) + values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields]) + _sql = _sql.replace(":values",values) + + cursor.executemany(_sql,_info.values.tolist()) + # cursor.commit() # self.conn.commit() except Exception as e: print(e) pass finally: - self.conn.commit() - cursor.close() + self.conn.commit() + # cursor.close() pass def close(self): try: @@ -265,6 +318,7 @@ class BigQuery: self.path = path self.dtypes = _args['dtypes'] if 'dtypes' in _args else None self.table = _args['table'] if 'table' in _args else None + self.client = bq.Client.from_service_account_json(self.path) def meta(self,**_args): """ This function returns meta data for a given table or query with dataset/table properly formatted @@ -272,16 +326,16 @@ class BigQuery: :param sql sql query to be pulled, """ table = _args['table'] - client = bq.Client.from_service_account_json(self.path) - ref = client.dataset(self.dataset).table(table) - return client.get_table(ref).schema + + ref = self.client.dataset(self.dataset).table(table) + return self.client.get_table(ref).schema def has(self,**_args): found = False try: found = self.meta(**_args) is not None except Exception as e: pass - return found + return found class BQReader(BigQuery,Reader) : def __init__(self,**_args): @@ -304,8 +358,9 @@ class BQReader(BigQuery,Reader) : if (':dataset' in SQL or ':DATASET' in SQL) and self.dataset: SQL = SQL.replace(':dataset',self.dataset).replace(':DATASET',self.dataset) _info = {'credentials':self.credentials,'dialect':'standard'} - return pd.read_gbq(SQL,**_info) if SQL else None - # return pd.read_gbq(SQL,credentials=self.credentials,dialect='standard') if SQL else None + return pd.read_gbq(SQL,**_info) if SQL else None + # return self.client.query(SQL).to_dataframe() if SQL else None + class BQWriter(BigQuery,Writer): lock = Lock()