parent
							
								
									3fbd68309f
								
							
						
					
					
						commit
						a2988a5972
					
				@ -0,0 +1,126 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					from transport import factory
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					from multiprocessing import Process
 | 
				
			||||||
 | 
					import pandas as pd
 | 
				
			||||||
 | 
					from google.oauth2 import service_account
 | 
				
			||||||
 | 
					import data.maker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from data.params import SYS_ARGS 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					f = open ('config.json')
 | 
				
			||||||
 | 
					PIPELINE = json.loads(f.read())
 | 
				
			||||||
 | 
					f.close()
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# The configuration array is now loaded and we will execute the pipe line as follows
 | 
				
			||||||
 | 
					DATASET='combined20190510_deid'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Components :
 | 
				
			||||||
 | 
					        @staticmethod
 | 
				
			||||||
 | 
					        def get(args):
 | 
				
			||||||
 | 
					                SQL = args['sql']
 | 
				
			||||||
 | 
					                if 'condition' in args :
 | 
				
			||||||
 | 
					                        condition = ' '.join([args['condition']['field'],args['condition']['qualifier'],'(',args['condition']['value'],')'])
 | 
				
			||||||
 | 
					                        SQL = " ".join([SQL,'WHERE',condition])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                SQL = SQL.replace(':dataset',args['dataset']) #+ " LIMIT 1000 "
 | 
				
			||||||
 | 
					                return SQL #+ " LIMIT 10000 "
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        @staticmethod
 | 
				
			||||||
 | 
					        def train(args):
 | 
				
			||||||
 | 
					                """
 | 
				
			||||||
 | 
					                This function will instanciate a worker that will train given a message that is provided to it
 | 
				
			||||||
 | 
					                This is/will be a separate process that will
 | 
				
			||||||
 | 
					                """
 | 
				
			||||||
 | 
					                print (['starting .... ',args['notify'],args['context']] )
 | 
				
			||||||
 | 
					                #SQL = args['sql']
 | 
				
			||||||
 | 
					                #if 'condition' in args :
 | 
				
			||||||
 | 
					                #       condition = ' '.join([args['condition']['field'],args['condition']['qualifier'],'(',args['condition']['value'],')'])
 | 
				
			||||||
 | 
					                #       SQL = " ".join([SQL,'WHERE',condition])
 | 
				
			||||||
 | 
					                print ( args['context'])
 | 
				
			||||||
 | 
					                logger = factory.instance(type='mongo.MongoWriter',args={'dbname':'aou','doc':args['context']})
 | 
				
			||||||
 | 
					                log_folder = os.sep.join(["logs",args['context']])
 | 
				
			||||||
 | 
					                _args = {"batch_size":2000,"logs":log_folder,"context":args['context'],"max_epochs":250,"num_gpus":2,"column":args['columns'],"id":"person_id","logger":logger}
 | 
				
			||||||
 | 
					                os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']
 | 
				
			||||||
 | 
					                #SQL = SQL.replace(':dataset',args['dataset']) #+ " LIMIT 1000 "
 | 
				
			||||||
 | 
					                SQL = Components.get(args)
 | 
				
			||||||
 | 
					                if 'limit' in args :
 | 
				
			||||||
 | 
					                        SQL = ' '.join([SQL,'limit',args['limit'] ])
 | 
				
			||||||
 | 
					                _args['max_epochs'] = 250 if 'max_epochs' not in args else args['max_epochs']
 | 
				
			||||||
 | 
					                credentials = service_account.Credentials.from_service_account_file('/home/steve/dev/aou/accounts/curation-prod.json')
 | 
				
			||||||
 | 
					                _args['data'] = pd.read_gbq(SQL,credentials=credentials,dialect='standard')
 | 
				
			||||||
 | 
					                #_args['data'] = _args['data'].astype(object)
 | 
				
			||||||
 | 
					                _args['num_gpu'] = int(args['num_gpu']) if 'num_gpu' in args else 1 
 | 
				
			||||||
 | 
					                data.maker.train(**_args) 
 | 
				
			||||||
 | 
					        @staticmethod
 | 
				
			||||||
 | 
					        def generate(args):
 | 
				
			||||||
 | 
					                """
 | 
				
			||||||
 | 
					                This function will generate data and store it to a given,
 | 
				
			||||||
 | 
					                """
 | 
				
			||||||
 | 
					                logger = factory.instance(type='mongo.MongoWriter',args={'dbname':'aou','doc':args['context']})
 | 
				
			||||||
 | 
					                log_folder = os.sep.join(["logs",args['context']])
 | 
				
			||||||
 | 
					                _args = {"batch_size":2000,"logs":log_folder,"context":args['context'],"max_epochs":250,"num_gpus":2,"column":args['columns'],"id":"person_id","logger":logger}
 | 
				
			||||||
 | 
					                os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']
 | 
				
			||||||
 | 
					                SQL = Components.get(args) 
 | 
				
			||||||
 | 
					                if 'limit' in args :
 | 
				
			||||||
 | 
					                        SQL = " ".join([SQL ,'limit', args['limit'] ])
 | 
				
			||||||
 | 
					                credentials = service_account.Credentials.from_service_account_file('/home/steve/dev/aou/accounts/curation-prod.json')
 | 
				
			||||||
 | 
					                _args['data'] = pd.read_gbq(SQL,credentials=credentials,dialect='standard').fillna('')
 | 
				
			||||||
 | 
					                #_args['data'] = _args['data'].astype(object)
 | 
				
			||||||
 | 
					                _args['num_gpu'] = int(args['num_gpu']) if 'num_gpu' in args else 1 
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					                _args['max_epochs'] = 250 if 'max_epochs' not in args else args['max_epochs']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                _args['no_value'] = args['no_value'] if 'no_value' in args else '' 
 | 
				
			||||||
 | 
					                #credentials = service_account.Credentials.from_service_account_file('/home/steve/dev/aou/accounts/curation-prod.json')
 | 
				
			||||||
 | 
					                #_args['data'] = pd.read_gbq(SQL,credentials=credentials,dialect='standard')
 | 
				
			||||||
 | 
					                #_args['data'] = _args['data'].astype(object)
 | 
				
			||||||
 | 
					                _dc = data.maker.generate(**_args) 
 | 
				
			||||||
 | 
					                #
 | 
				
			||||||
 | 
					                # We need to post the generate the data in order to :
 | 
				
			||||||
 | 
					                #       1. compare immediately
 | 
				
			||||||
 | 
					                #       2. synthetic copy
 | 
				
			||||||
 | 
					                #
 | 
				
			||||||
 | 
					                cols = _dc.columns.tolist()
 | 
				
			||||||
 | 
					                print (args['columns']) 
 | 
				
			||||||
 | 
					                data_comp = _args['data'][args['columns']].join(_dc[args['columns']],rsuffix='_io')     #-- will be used for comparison (store this in big query)
 | 
				
			||||||
 | 
					                base_cols = list(set(_args['data'].columns) - set(args['columns']))     #-- rebuilt the dataset (and store it)
 | 
				
			||||||
 | 
					                print (_args['data'].shape) 
 | 
				
			||||||
 | 
					                print (_args['data'].shape) 
 | 
				
			||||||
 | 
					                for name in cols :
 | 
				
			||||||
 | 
					                        _args['data'][name] = _dc[name]
 | 
				
			||||||
 | 
					                        # filename = os.sep.join([log_folder,'output',name+'.csv'])
 | 
				
			||||||
 | 
					                        # data_comp[[name]].to_csv(filename,index=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                #
 | 
				
			||||||
 | 
					                #-- Let us store all of this into bigquery
 | 
				
			||||||
 | 
					                prefix = args['notify']+'.'+_args['context']
 | 
				
			||||||
 | 
					                table = '_'.join([prefix,'compare','io'])
 | 
				
			||||||
 | 
					                data_comp.to_gbq(if_exists='replace',destination_table=table,credentials=credentials,chunksize=50000)           
 | 
				
			||||||
 | 
					                _args['data'].to_gbq(if_exists='replace',destination_table=table.replace('compare','full'),credentials=credentials,chunksize=50000)
 | 
				
			||||||
 | 
					                data_comp.to_csv(os.sep.join([log_folder,table+'.csv']),index=False)
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__' :
 | 
				
			||||||
 | 
					        index = int(SYS_ARGS['index'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        args =  (PIPELINE[index])
 | 
				
			||||||
 | 
					        #if 'limit' in SYS_ARGS :
 | 
				
			||||||
 | 
					        #       args['limit'] = SYS_ARGS['limit']
 | 
				
			||||||
 | 
					        #args['dataset'] = 'combined20190510'
 | 
				
			||||||
 | 
					        SYS_ARGS['dataset'] = 'combined20190510_deid' if 'dataset' not in SYS_ARGS else SYS_ARGS['dataset']
 | 
				
			||||||
 | 
					        #if 'max_epochs' in SYS_ARGS :
 | 
				
			||||||
 | 
					        #       args['max_epochs'] = SYS_ARGS['max_epochs']
 | 
				
			||||||
 | 
					        args = dict(args,**SYS_ARGS)
 | 
				
			||||||
 | 
					        if 'generate' in SYS_ARGS :
 | 
				
			||||||
 | 
					                Components.generate(args)
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					                Components.train(args)
 | 
				
			||||||
 | 
					#for args in PIPELINE :
 | 
				
			||||||
 | 
					        #args['dataset'] = 'combined20190510'
 | 
				
			||||||
 | 
					        #process = Process(target=Components.train,args=(args,))
 | 
				
			||||||
 | 
					        #process.name = args['context']
 | 
				
			||||||
 | 
					        #process.start()
 | 
				
			||||||
 | 
					#       Components.train(args)
 | 
				
			||||||
					Loading…
					
					
				
		Reference in new issue