diff --git a/transport/sql.py b/transport/sql.py
index 32fc356..47ca33a 100644
--- a/transport/sql.py
+++ b/transport/sql.py
@@ -20,7 +20,7 @@ import json
 # from threading import Lock
 
 import pandas as pd
-
+import numpy as np
 class SQLRW :
     PROVIDERS = {"postgresql":"5432","redshift":"5432","mysql":"3306","mariadb":"3306"}
     DRIVERS = {"postgresql":pg,"redshift":pg,"mysql":my,"mariadb":my}
@@ -95,6 +95,12 @@ class SQLReader(SQLRW,Reader) :
 class SQLWriter(SQLRW,Writer):
     def __init__(self,**_args) :
         super().__init__(**_args)
+        #
+        # In the event that data types are difficult to determine, rows can be inspected and written with a default cast
+        # This slows down the process but improves reliability of the data
+        # NOTE: Proper data types should be set on the target system if their source is unclear.
+        self._inspect = _args.get('inspect',False)
+        self._cast = _args.get('cast',False)
     def init(self,fields):
         if not fields :
             try:
@@ -118,6 +124,8 @@ class SQLWriter(SQLRW,Writer):
         """
         :param info writes a list of data to a given set of fields
         """
+        # Row inspection and casting are configured once on the instance
+        # (self._inspect, self._cast) rather than on each call
         if not self.fields :
             _fields = info.keys() if type(info) == dict else info[0].keys()
             _fields = list (_fields)
@@ -127,14 +135,25 @@ class SQLWriter(SQLRW,Writer):
             info = [info]
         cursor = self.conn.cursor()
         try:
-
-            _fields = ",".join(self.fields)
-            _sql = "INSERT INTO :table (:fields) values (:values)".replace(":table",self.table).replace(":fields",_fields)
-            _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
-
-            # for row in info :
-            # values = ["'".join(["",value,""]) if not str(value).isnumeric() else value for value in row.values()]
-            cursor.executemany(_sql,info)
+            _sql = "INSERT INTO :table (:fields) values (:values)".replace(":table",self.table)
+            if self._inspect :
+                # Each row supplies its own field list, so a statement is built and executed per row.
+                # NOTE: values are written into the SQL text; use the default path for untrusted data.
+                for _row in info :
+                    fields = list(_row.keys())
+                    if self._cast == False :
+                        # str() allows non-string values (int, float ...) to be joined
+                        values = ",".join([str(value) for value in _row.values()])
+                    else:
+                        # quote every value and double embedded quotes so the literal stays well formed
+                        values = ",".join(["'"+str(value).replace("'","''")+"'" for value in _row.values()])
+                    query = _sql.replace(":fields",",".join(fields)).replace(":values",values)
+                    cursor.execute(query)
+            else:
+                _fields = ",".join(self.fields)
+                _sql = _sql.replace(":fields",_fields)
+                _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
+                cursor.executemany(_sql,info)
             self.conn.commit()
         except Exception as e:
             print (e)