parent
							
								
									a51be50a86
								
							
						
					
					
						commit
						98a1062a30
					
				@ -0,0 +1,705 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					This code was originally written by Ziqi Zhang <ziqi.zhang@vanderbilt.edu> in order to generate synthetic data.
 | 
				
			||||||
 | 
					The code is an implementation of a Generative Adversarial Network that uses the Wasserstein Distance (WGAN).
 | 
				
			||||||
 | 
					It is intended to be used in 2 modes (embedded in code or using CLI)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					USAGE :
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The following parameters should be provided in a configuration file (JSON format)
 | 
				
			||||||
 | 
					python data/maker --config <path-to-config-file.json>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CONFIGURATION FILE STRUCTURE :
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																					context																																																																	what it is you are loading (stroke, hypertension, ...)
 | 
				
			||||||
 | 
																																																																					data																																																																																						path of the file to be loaded
 | 
				
			||||||
 | 
																																																																					logs																																																																																						folder to store training model and meta data about learning
 | 
				
			||||||
 | 
																																																																					max_epochs																																												number of iterations in learning 
 | 
				
			||||||
 | 
																																																																					num_gpu																																																																	number of gpus to be used (will still run if the GPUs are not available)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					EMBEDDED IN CODE :
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					from tensorflow.contrib.layers import l2_regularizer
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import pandas as pd
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					from data.params import SYS_ARGS
 | 
				
			||||||
 | 
					from data.bridge import Binary
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
 | 
					import pickle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 | 
				
			||||||
 | 
					os.environ['CUDA_VISIBLE_DEVICES'] = "0"
 | 
				
			||||||
 | 
					os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# STEPS_PER_EPOCH																																																																															= int(SYS_ARGS['epoch']) if 'epoch' in SYS_ARGS else 256
 | 
				
			||||||
 | 
					# NUM_GPUS																																																																																																																																																																																												= 1 if 'num_gpu' not in SYS_ARGS else int(SYS_ARGS['num_gpu'])
 | 
				
			||||||
 | 
					# BATCHSIZE_PER_GPU																											= 2000
 | 
				
			||||||
 | 
					# TOTAL_BATCHSIZE																																																																															= BATCHSIZE_PER_GPU * NUM_GPUS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class void :
 | 
				
			||||||
 | 
																																																																					pass
 | 
				
			||||||
 | 
					class GNet :
 | 
				
			||||||
 | 
																																																																					def log(self,**args):
 | 
				
			||||||
 | 
																																																																																																																																					self.logs = dict(args,**self.logs)
 | 
				
			||||||
 | 
																																																																																						
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																					"""
 | 
				
			||||||
 | 
																																																																					This is the base class of a generative network functions, the details will be implemented in the subclasses.
 | 
				
			||||||
 | 
																																																																					An instance of this class is accessed as follows 
 | 
				
			||||||
 | 
																																																																					object.layers.normalize applies batch normalization or otherwise
 | 
				
			||||||
 | 
    object.get.variables instantiates variables on cpu and returns a reference (tensor)
 | 
				
			||||||
 | 
																																																																					"""
 | 
				
			||||||
 | 
																																																																					def __init__(self,**args):
 | 
				
			||||||
 | 
																																																																																																																																					self.layers = void()
 | 
				
			||||||
 | 
																																																																																																																																					self.layers.normalize = self.normalize
 | 
				
			||||||
 | 
																																																																																																																																					self.logs = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																					self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
 | 
				
			||||||
 | 
																																																																																																																																					# if self.NUM_GPUS > 1 :
 | 
				
			||||||
 | 
																																																																																																																																					#																																															os.environ['CUDA_VISIBLE_DEVICES'] = "4"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																					self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
 | 
				
			||||||
 | 
																																																																																																																																					self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE]
 | 
				
			||||||
 | 
																																																																																																																																					self.D_STRUCTURE = [self.X_SPACE_SIZE,256,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE*2, self.X_SPACE_SIZE] #-- change 854 to number of diagnosis
 | 
				
			||||||
 | 
																																																																																																																																					# self.NUM_LABELS																																																																															= 8 if 'label' not in args elif len(args['label'].shape) args['label'].shape[1]
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					if 'label' in args and len(args['label'].shape) == 2 :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					self.NUM_LABELS = args['label'].shape[1]
 | 
				
			||||||
 | 
																																																																																																																																					elif 'label' in args and len(args['label']) == 1 :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					self.NUM_LABELS = args['label'].shape[0]
 | 
				
			||||||
 | 
																																																																																																																																					else:
 | 
				
			||||||
 | 
																																																																																																																																																																																																					self.NUM_LABELS = None
 | 
				
			||||||
 | 
																																																																																																																																					# self.Z_DIM = 128 #self.X_SPACE_SIZE																																					
 | 
				
			||||||
 | 
																																																																																																																																					self.Z_DIM = 128																#-- used as rows down stream
 | 
				
			||||||
 | 
																																																																																																																																					self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
 | 
				
			||||||
 | 
																																																																																																																																					PROPOSED_BATCH_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
 | 
				
			||||||
 | 
																																																																																																																																					self.BATCHSIZE_PER_GPU = PROPOSED_BATCH_PER_GPU
 | 
				
			||||||
 | 
																																																																																																																																					if 'real' in args : 
 | 
				
			||||||
 | 
																																																																																																																																																																																																					self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					if args['real'].shape[0]																< PROPOSED_BATCH_PER_GPU :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) 
 | 
				
			||||||
 | 
																																																																																																																																					# self.BATCHSIZE_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
 | 
				
			||||||
 | 
																																																																																																																																					self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
 | 
				
			||||||
 | 
																																																																																																																																					self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)																																													
 | 
				
			||||||
 | 
																																																																																																																																					self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
 | 
				
			||||||
 | 
																																																																																																																																					self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100
 | 
				
			||||||
 | 
																																																																																																																																					self.CONTEXT = args['context']
 | 
				
			||||||
 | 
																																																																																																																																					self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None}
 | 
				
			||||||
 | 
																																																																																																																																					self._REAL = args['real'] if 'real' in args else None
 | 
				
			||||||
 | 
																																																																																																																																					self._LABEL = args['label'] if 'label' in args else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																					self.get = void()
 | 
				
			||||||
 | 
																																																																																																																																					self.get.variables = self._variable_on_cpu
 | 
				
			||||||
 | 
																																																																																																																																					self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
 | 
				
			||||||
 | 
																																																																																																																																					self.logger = args['logger'] if 'logger' in args and args['logger'] else None
 | 
				
			||||||
 | 
																																																																																																																																					self.init_logs(**args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																					def init_logs(self,**args):
 | 
				
			||||||
 | 
																																																																																																																																					self.log_dir = args['logs'] if 'logs' in args else 'logs'
 | 
				
			||||||
 | 
																																																																																																																																					self.mkdir(self.log_dir)
 | 
				
			||||||
 | 
																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																					# 
 | 
				
			||||||
 | 
																																																																																																																																					for key in ['train','output'] :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					self.mkdir(os.sep.join([self.log_dir,key]))
 | 
				
			||||||
 | 
																																																																																																																																																																																																					self.mkdir (os.sep.join([self.log_dir,key,self.CONTEXT]))
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					self.train_dir		= os.sep.join([self.log_dir,'train',self.CONTEXT])																																																																																																																												
 | 
				
			||||||
 | 
																																																																																																																																					self.out_dir = os.sep.join([self.log_dir,'output',self.CONTEXT])
 | 
				
			||||||
 | 
																																																																																																																																					if self.logger :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# We will clear the logs from the data-store 
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																					column = self.ATTRIBUTES['synthetic']
 | 
				
			||||||
 | 
																																																																																																																																																																																																					db = self.logger.db
 | 
				
			||||||
 | 
																																																																																																																																																																																																					if db[column].count() > 0 :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					db.backup.insert({'name':column,'logs':list(db[column].find()) })
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					db[column].drop()
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																					def load_meta(self,column):
 | 
				
			||||||
 | 
																																																																																																																																					"""
 | 
				
			||||||
 | 
																																																																																																																																					This function is designed to accomodate the uses of the sub-classes outside of a strict dependency model.
 | 
				
			||||||
 | 
																																																																																																																																					Because prediction and training can happen independently
 | 
				
			||||||
 | 
																																																																																																																																					"""
 | 
				
			||||||
 | 
																																																																																																																																					# suffix = "-".join(column) if isinstance(column,list)else column
 | 
				
			||||||
 | 
																																																																																																																																					suffix = self.get.suffix()
 | 
				
			||||||
 | 
																																																																																																																																					_name = os.sep.join([self.out_dir,'meta-'+suffix+'.json'])
 | 
				
			||||||
 | 
																																																																																																																																					if os.path.exists(_name) :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					attr = json.loads((open(_name)).read())
 | 
				
			||||||
 | 
																																																																																																																																																																																																					for key in attr :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					value = attr[key]
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					setattr(self,key,value)
 | 
				
			||||||
 | 
																																																																																																																																					self.train_dir		= os.sep.join([self.log_dir,'train',self.CONTEXT])																																																																																																																												
 | 
				
			||||||
 | 
																																																																																																																																					self.out_dir = os.sep.join([self.log_dir,'output',self.CONTEXT])
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																					def log_meta(self,**args) :
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					_object = {
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# '_id':'meta',
 | 
				
			||||||
 | 
																																																																																																																																																																																																					'CONTEXT':self.CONTEXT,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					'ATTRIBUTES':self.ATTRIBUTES,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					'Z_DIM':self.Z_DIM,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					"X_SPACE_SIZE":self.X_SPACE_SIZE,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					"D_STRUCTURE":self.D_STRUCTURE,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					"G_STRUCTURE":self.G_STRUCTURE,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					"NUM_GPUS":self.NUM_GPUS,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					"NUM_LABELS":self.NUM_LABELS,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					"MAX_EPOCHS":self.MAX_EPOCHS,
 | 
				
			||||||
 | 
																																																																																																																																																																																																					"ROW_COUNT":self.ROW_COUNT
 | 
				
			||||||
 | 
																																																																																																																																					}
 | 
				
			||||||
 | 
																																																																																																																																					if args and 'key' in args and 'value' in args :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					key = args['key']
 | 
				
			||||||
 | 
																																																																																																																																																																																																					value= args['value']
 | 
				
			||||||
 | 
																																																																																																																																																																																																					object[key] = value
 | 
				
			||||||
 | 
																																																																																																																																					# suffix = "-".join(self.column) if isinstance(self.column,list) else self.column
 | 
				
			||||||
 | 
																																																																																																																																					suffix = self.get.suffix()
 | 
				
			||||||
 | 
																																																																																																																																					_name = os.sep.join([self.out_dir,'meta-'+suffix])
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					f = open(_name+'.json','w')
 | 
				
			||||||
 | 
																																																																																																																																					f.write(json.dumps(_object))
 | 
				
			||||||
 | 
																																																																																																																																					return _object
 | 
				
			||||||
 | 
																																																																					def mkdir (self,path):
 | 
				
			||||||
 | 
																																																																																																																																					if not os.path.exists(path) :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					os.mkdir(path)																																																																																		
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																					def normalize(self,**args):
																																																																					    """
																																																																					    Normalize the activations of a layer with their own moments and, when labels
																																																																					    are supplied, apply a label-dependent offset and scale.
																																																																					    inputs      tensor coming out of a network layer (required)
																																																																					    name        suffix appended to 'offset' and 'scale' to name the variables (required)
																																																																					    labels      labels (attributes not synthesized) used to look up offset/scale rows, by default None
																																																																					    n_labels    number of labels, used as the row count of the scale table, by default None
																																																																					    Any other keyword (callers pass shift=...) is accepted but never read here.
																																																																					    Returns the tensor produced by tf.nn.batch_normalization with a variance epsilon of 1e-8.
																																																																					    """
																																																																					    layer = args['inputs']
																																																																					    suffix = args['name']
																																																																					    labels = args.get('labels')
																																																																					    n_labels = args.get('n_labels')
																																																																					    # The axis used for the moments is picked from the class name:
																																																																					    # [0] when the instance is a 'generator', [1] for anything else.
																																																																					    is_generator = self.__class__.__name__.lower() == 'generator'
																																																																					    axes = [0] if is_generator else [1]
																																																																					    mean, var = tf.nn.moments(layer, axes, keep_dims=True)
																																																																					    width = layer.shape[1].value
																																																																					    # Without labels the normalization has no learned offset/scale.
																																																																					    offset = None
																																																																					    scale = None
																																																																					    if labels is not None:
																																																																					        # NOTE(review): the offset table has a single row while the scale table
																																																																					        # has n_labels rows -- confirm this asymmetry is intended.
																																																																					        offset_table = self.get.variables(shape=[1,width], name='offset'+suffix,
																																																																					                                          initializer=tf.zeros_initializer)
																																																																					        scale_table = self.get.variables(shape=[n_labels,width], name='scale'+suffix,
																																																																					                                         initializer=tf.ones_initializer)
																																																																					        offset = tf.nn.embedding_lookup(offset_table, labels)
																																																																					        scale = tf.nn.embedding_lookup(scale_table, labels)
																																																																					    return tf.nn.batch_normalization(layer, mean, var, offset, scale, 1e-8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																					def _variable_on_cpu(self,**args):
																																																																					    """
																																																																					    Get a variable through tf.compat.v1.get_variable while the '/cpu:0' device
																																																																					    scope is active, so that it is placed on the CPU rather than on a GPU.
																																																																					    name         name given to the variable (required)
																																																																					    shape        shape given to the variable (required)
																																																																					    initializer  initializer forwarded to get_variable, None when not provided
																																																																					    """
																																																																					    var_name = args['name']
																																																																					    var_shape = args['shape']
																																																																					    init = args.get('initializer')
																																																																					    with tf.device('/cpu:0') :
																																																																					        return tf.compat.v1.get_variable(var_name,var_shape,initializer=init)
 | 
				
			||||||
 | 
																																																																					def average_gradients(self,tower_grads):
																																																																					    """
																																																																					    Average the gradients of every variable across towers.
																																																																					    tower_grads  one sequence of (gradient, variable) pairs per tower; the
																																																																					                 sequences are walked in lockstep with zip
																																																																					    Returns a list of (mean gradient, variable) pairs where the variable is the
																																																																					    one found in the first tower.
																																																																					    """
																																																																					    averaged = []
																																																																					    for per_variable in zip(*tower_grads):
																																																																					        # Give every gradient a leading axis so they can be joined and
																																																																					        # averaged along it.
																																																																					        expanded = [tf.expand_dims(g, 0) for g, _ in per_variable]
																																																																					        mean_grad = tf.reduce_mean(tf.concat(axis=0, values=expanded), 0)
																																																																					        averaged.append((mean_grad, per_variable[0][1]))
																																																																					    return averaged
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Generator (GNet):
					    """
					    Generator of candidate datasets. It aggregates a Discriminator built from the
					    same arguments and uses its scores to compute its own loss, which keeps the
					    generation from being random.
					    """
					    def __init__(self,**args):
					        """
					        Initialize the base network, then build the aggregated discriminator with the same arguments.
					        """
					        GNet.__init__(self,**args)
					        self.discriminator = Discriminator(**args)
					    def loss(self,**args):
					        """
					        Compute the generator loss and register it in the 'glosses' collection.
					        :fake   candidate data to be scored by the discriminator
					        :label  labels passed along to the discriminator
					        Returns the same loss tensor twice, as a pair.
					        """
					        candidates = args['fake']
					        label = args['label']
					        scores = self.discriminator.network(inputs=candidates, label=label)
					        # Regularization terms currently registered in the graph collection.
					        penalties = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
					        total = -tf.reduce_mean(scores) + sum(penalties)
					        tf.compat.v1.add_to_collection('glosses', total)
					        return total, total
					    def load_meta(self, column):
					        """
					        Load the meta data for a column on this network and on the aggregated discriminator.
					        """
					        super().load_meta(column)
					        self.discriminator.load_meta(column)
					    def network(self,**args) :
					        """
					        This function will build the network that will generate the synthetic candidates
					        :inputs matrix of data fed to the first layer
					        :dim    input width of the first kernel, self.Z_DIM when not provided
					        :label  labels forwarded to the normalization of every hidden layer
					        Returns the sigmoid output of the last layer.
					        """
					        x = args['inputs']
					        in_dim = args['dim'] if 'dim' in args else self.Z_DIM
					        label = args['label']
					        with tf.compat.v1.variable_scope('G', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)):
					            # Hidden layers: matmul -> normalize -> relu, added back onto the layer input.
					            for idx, out_dim in enumerate(self.G_STRUCTURE[:-1]):
					                weights = self.get.variables(name='W_' + str(idx), shape=[in_dim, out_dim])
					                normed = self.normalize(inputs=tf.matmul(x, weights),shift=0, name='cbn' + str(idx), labels=label, n_labels=self.NUM_LABELS)
					                x = x + tf.nn.relu(normed)
					                in_dim = out_dim
					            last = len(self.G_STRUCTURE) - 1
					            #
					            # This seems to be an extra hidden layer (tanh instead of relu).
					            # Its goal appears to be mapping continuous values to discrete values -- TODO confirm
					            weights = self.get.variables(name='W_' + str(last), shape=[in_dim, self.G_STRUCTURE[-1]])
					            normed = self.normalize(inputs=tf.matmul(x, weights), name='cbn' + str(last),
					                                    labels=label, n_labels=self.NUM_LABELS)
					            x = x + tf.nn.tanh(normed)
					            #
					            # This seems to be the output layer: a dense projection to X_SPACE_SIZE squashed by a sigmoid
					            out_idx = last + 1
					            weights = self.get.variables(name='W_' + str(out_idx), shape=[self.Z_DIM, self.X_SPACE_SIZE])
					            bias = self.get.variables(name='b_' + str(out_idx), shape=[self.X_SPACE_SIZE])
					            x = tf.nn.sigmoid(tf.add(tf.matmul(x, weights), bias))
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Discriminator(GNet):
 | 
				
			||||||
 | 
																																																																					def __init__(self,**args):
																																																																					    """
																																																																					    Initialize the discriminator by delegating every argument to the base network class.
																																																																					    """
																																																																					    super().__init__(**args)
 | 
				
			||||||
 | 
																																																																					def network(self,**args):
																																																																					    """
																																																																					    This function will apply a computational graph on a dataset passed in with the
																																																																					    associated labels; the last layer has a single output (neuron).
																																																																					    :inputs data to be scored
																																																																					    :label  labels forwarded to the normalization of every hidden layer
																																																																					    Returns the output of the final single-unit layer (no activation applied).
																																																																					    """
																																																																					    x = args['inputs']
																																																																					    label = args['label']
																																																																					    with tf.compat.v1.variable_scope('D', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)):
																																																																					        # Consecutive entries of D_STRUCTURE give the (input, output) width of each hidden layer.
																																																																					        widths = zip(self.D_STRUCTURE, self.D_STRUCTURE[1:])
																																																																					        for idx, (in_dim, out_dim) in enumerate(widths):
																																																																					            weights = self.get.variables(name='W_' + str(idx), shape=[in_dim, out_dim])
																																																																					            bias = self.get.variables(name='b_' + str(idx), shape=[out_dim])
																																																																					            activated = tf.nn.relu(tf.add(tf.matmul(x, weights), bias))
																																																																					            x = self.normalize(inputs=activated, name='cln' + str(idx), shift=1,labels=label, n_labels=self.NUM_LABELS)
																																																																					        # The output layer variables are named after the length of D_STRUCTURE.
																																																																					        out_idx = len(self.D_STRUCTURE)
																																																																					        weights = self.get.variables(name='W_' + str(out_idx), shape=[self.D_STRUCTURE[-1], 1])
																																																																					        bias = self.get.variables(name='b_' + str(out_idx), shape=[1])
																																																																					        score = tf.add(tf.matmul(x, weights), bias)
																																																																					    return score
 | 
				
			||||||
 | 
																																																																					
 | 
				
			||||||
 | 
																																																																					def loss(self,**args) :
																																																																					    """
																																																																					    This function computes the loss of the discriminator: the distance between the
																																																																					    scores of fake and real data, plus 10 times a gradient penalty evaluated on
																																																																					    random interpolations between real and fake, plus the regularization losses.
																																																																					    The loss is also added to the 'dlosses' collection.
																																																																					    :real   real data
																																																																					    :fake   generated data
																																																																					    :label  labels given to the network with every input
																																																																					    Returns the pair (w_distance, loss).
																																																																					    """
																																																																					    real = args['real']
																																																																					    fake = args['fake']
																																																																					    label = args['label']
																																																																					    # One uniform mixing coefficient in [0, 1) per row of the batch.
																																																																					    mix = tf.random.uniform(shape=[self.BATCHSIZE_PER_GPU,1],minval=0,maxval=1)
																																																																					    interpolated = real + mix * (fake - real)
																																																																					    score_fake = self.network(inputs=fake, label=label)
																																																																					    score_real = self.network(inputs=real, label=label)
																																																																					    score_mixed = self.network(inputs=interpolated, label=label)
																																																																					    # Gradient penalty: squared deviation from 1 of the norm of the
																																																																					    # gradient of the score with respect to the interpolated input.
																																																																					    grad = tf.gradients(score_mixed, [interpolated])[0]
																																																																					    norms = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
																																																																					    gradient_penalty = tf.reduce_mean((norms - 1.) ** 2)
																																																																					    penalties = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
																																																																					    w_distance = -tf.reduce_mean(score_real) + tf.reduce_mean(score_fake)
																																																																					    total = w_distance + 10 * gradient_penalty + sum(penalties)
																																																																					    tf.compat.v1.add_to_collection('dlosses', total)
																																																																					    return w_distance, total
 | 
				
			||||||
 | 
					class Train (GNet):
 | 
				
			||||||
 | 
																																																																					def __init__(self,**args):
																																																																						"""
																																																																						Set up the trainer : a Generator and a Discriminator are built from the very same arguments
																																																																						:real		required, kept as self._REAL
																																																																						:label		optional, kept as self._LABEL (None when it is not provided)
																																																																						:column		required, kept as self.column
																																																																						"""
																																																																						GNet.__init__(self,**args)
																																																																						self.generator = Generator(**args)
																																																																						self.discriminator = Discriminator(**args)
																																																																						self._REAL = args['real']
																																																																						self._LABEL = args.get('label')
																																																																						self.column = args['column']
																																																																						#
																																																																						# The meta data is computed once, kept on the instance and written to the logger (if any)
																																																																						self.meta = self.log_meta()
																																																																						if self.logger:
																																																																							self.logger.write(self.meta)
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					# self.log (real_shape=list(self._REAL.shape),label_shape = self._LABEL.shape,meta_data=self.meta)
 | 
				
			||||||
 | 
																																																																					def load_meta(self, column):
																																																																						"""
																																																																						Load the meta data of a column for this object first, then delegate the same call
																																																																						to the generator and to the discriminator (in that order)
																																																																						:column		column name
																																																																						"""
																																																																						super().load_meta(column)
																																																																						for _dependent in (self.generator, self.discriminator):
																																																																							_dependent.load_meta(column)
 | 
				
			||||||
 | 
																																																																					def loss(self,**args):
																																																																						"""
																																																																						Compute the "tower" loss of generated candidates against real data for a given stage
																																																																						:scope		scope used to filter the collection of losses
																																																																						:stage		'D' for the discriminator, any other value is handled as the generator
																																																																						:real		batch of real data (only handed to the discriminator)
																																																																						:label		batch of labels, or None
																																																																						:return		total_loss, w (w being the first value returned by the loss of the selected network)
																																																																						"""
																																																																						scope = args['scope']
																																																																						stage = args['stage']
																																																																						real = args['real']
																																																																						label = args['label']
																																																																						if label is not None :
																																																																							label = tf.cast(label, tf.int32)
																																																																							#
																																																																							# @TODO: Ziqi needs to explain what's going on here
																																																																							# NOTE(review): the label columns are folded into one integer per row, i.e
																																																																							#	label[:,1] * n + sum( i * label[:,2+i] )	with n = number of label columns - 2
																																																																							# e.g with n = 4 : label[:,1] * 4 + label[:,2:] x [[0],[1],[2],[3]]
																																																																							weights = [[i] for i in np.arange(self._LABEL.shape[1]-2)]
																																																																							major = label[:, 1] * len(weights)
																																																																							minor = tf.squeeze(tf.matmul(label[:, 2:], tf.constant(weights, dtype=tf.int32)))
																																																																							label = major + minor
																																																																						#
																																																																						# A candidate is generated from random noise regardless of the stage
																																																																						z = tf.random.normal(shape=[self.BATCHSIZE_PER_GPU, self.Z_DIM])
																																																																						fake = self.generator.network(inputs=z, label=label)
																																																																						if stage == 'D':
																																																																							key = 'dlosses'
																																																																							w, _ = self.discriminator.loss(real=real, fake=fake, label=label)
																																																																						else:
																																																																							key = 'glosses'
																																																																							w, _ = self.generator.loss(fake=fake, label=label)
																																																																						#
																																																																						# The total is built from the collection of losses of the stage, restricted to the scope
																																																																						# (the discriminator adds its loss to 'dlosses', presumably the generator does so with 'glosses')
																																																																						total_loss = tf.add_n(tf.compat.v1.get_collection(key, scope), name='total_loss')
																																																																						return total_loss, w
 | 
				
			||||||
 | 
																																																																					def input_fn(self):
																																																																						"""
																																																																						Build the input pipeline of the training graph
																																																																						The data is read through placeholders shaped like self._REAL (and self._LABEL when there are labels),
																																																																						sliced row-wise, repeated 10000 times and served in batches of BATCHSIZE_PER_GPU
																																																																						NOTE(review): the placeholders are presumably fed by the caller when the iterator is initialized
																																																																						:return		iterator, features_placeholder, labels_placeholder
																																																																						"""
																																																																						has_label = self._LABEL is not None
																																																																						features_placeholder = tf.compat.v1.placeholder(shape=self._REAL.shape, dtype=tf.float32)
																																																																						#
																																																																						# The placeholder for labels is always created (unknown shape without labels) so the returned triple keeps its form
																																																																						label_shape = self._LABEL.shape if has_label else [None,None]
																																																																						labels_placeholder = tf.compat.v1.placeholder(shape=label_shape, dtype=tf.float32)
																																																																						source = (features_placeholder, labels_placeholder) if has_label else features_placeholder
																																																																						dataset = tf.data.Dataset.from_tensor_slices(source)
																																																																						dataset = dataset.repeat(10000).batch(batch_size=self.BATCHSIZE_PER_GPU).prefetch(1)
																																																																						iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
																																																																						return iterator, features_placeholder, labels_placeholder
 | 
				
			||||||
 | 
																																																																					
 | 
				
			||||||
 | 
																																																																					def network(self,**args):
																																																																						"""
																																																																						Build the training operation of a stage, with one "tower" per GPU
																																																																						Each tower reads its own batch, computes its loss and gradients; the gradients are then averaged and applied once
																																																																						:stage		stage to train, also used as the scope that selects the trainable variables
																																																																						:opt		optimizer used to compute and apply the gradients
																																																																						:return		train_op, mean_w, iterator, features_placeholder, labels_placeholder
																																																																						"""
																																																																						stage = args['stage']
																																																																						opt = args['opt']
																																																																						tower_grads, per_gpu_w = [], []
																																																																						iterator, features_placeholder, labels_placeholder = self.input_fn()
																																																																						with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
																																																																							for i in range(self.NUM_GPUS):
																																																																								with tf.device('/gpu:%d' % i), tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
																																																																									batch = iterator.get_next()
																																																																									real, label = batch if self._LABEL is not None else (batch, None)
																																																																									loss, w = self.loss(scope=scope, stage=stage, real=real, label=label)
																																																																									#
																																																																									# From here on the variables of the current scope are reused (towers share their weights)
																																																																									tf.compat.v1.get_variable_scope().reuse_variables()
																																																																									vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
																																																																									tower_grads.append(opt.compute_gradients(loss, vars_))
																																																																									per_gpu_w.append(w)
																																																																						train_op = opt.apply_gradients(self.average_gradients(tower_grads))
																																																																						mean_w = tf.reduce_mean(per_gpu_w)
																																																																						return train_op, mean_w, iterator, features_placeholder, labels_placeholder
 | 
				
			||||||
 | 
																																																																					def apply(self,**args):
 | 
				
			||||||
 | 
																																																																																																																																					# max_epochs = args['max_epochs'] if 'max_epochs' in args else 10
 | 
				
			||||||
 | 
																																																																																																																																					REAL = self._REAL
 | 
				
			||||||
 | 
																																																																																																																																					LABEL= self._LABEL																																													
 | 
				
			||||||
 | 
																																																																																																																																					if (self.logger):
 | 
				
			||||||
 | 
																																																																																																																																																																																																					pass
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					with tf.device('/cpu:0'):
 | 
				
			||||||
 | 
																																																																																																																																																																																																					opt_d = tf.compat.v1.train.AdamOptimizer(1e-4)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					opt_g = tf.compat.v1.train.AdamOptimizer(1e-4)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# saver = tf.train.Saver()
 | 
				
			||||||
 | 
																																																																																																																																																																																																					saver																					= tf.compat.v1.train.Saver()
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# init																		= tf.global_variables_initializer()
 | 
				
			||||||
 | 
																																																																																																																																																																																																					init																						= tf.compat.v1.global_variables_initializer()
 | 
				
			||||||
 | 
																																																																																																																																																																																																					logs = []
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
 | 
				
			||||||
 | 
																																																																																																																																																																																																					with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					sess.run(init)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					sess.run(iterator_d.initializer,
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					feed_dict={features_placeholder_d: REAL})
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					sess.run(iterator_g.initializer,
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					feed_dict={features_placeholder_g: REAL})
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					for epoch in range(1, self.MAX_EPOCHS + 1):
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					start_time = time.time()
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					w_sum = 0
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					for i in range(self.STEPS_PER_EPOCH):
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					for _ in range(2):
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					_, w = sess.run([train_d, w_distance])
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					w_sum += w
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					sess.run(train_g)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					duration = time.time() - start_time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					assert not np.isnan(w_sum), 'Model diverged with loss = NaN'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					format_str = 'epoch: %d, w_distance = %f (%.1f)'
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					# print (dir (w_distance))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					if epoch % self.MAX_EPOCHS == 0:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					suffix = self.get.suffix()
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					_name  = os.sep.join([self.train_dir,suffix])
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					# saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					if self.logger :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					row = {"logs":logs} #,"model":pickle.dump(sess)}																																																																																																																																																																																																																																																																																																																																																																																																																																																																
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					self.logger.write(row)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					# @TODO:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					# We should upload the files in the checkpoint 
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					# This would allow the learnt model to be portable to another system
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																					tf.compat.v1.reset_default_graph()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Predict(GNet):
 | 
				
			||||||
 | 
																																																																					"""
 | 
				
			||||||
 | 
																																																																					This class generates synthetic data given a learned model
 | 
				
			||||||
 | 
																																																																					"""
 | 
				
			||||||
 | 
																																																																					def __init__(self,**args):
																																																																																																																																					"""
																																																																																																																																					Set up the predictor on top of the GNet base class.

																																																																																																																																					The keyword arguments are forwarded unchanged to the base initialiser
																																																																																																																																					and to a freshly built Generator, which is kept on ``self.generator``.

																																																																																																																																					:param args: keyword arguments; must contain the key 'values'
																																																																																																																																						(a KeyError is raised otherwise), stored on ``self.values``.
																																																																																																																																					"""
																																																																																																																																					# Base-class initialisation first, then the generator built from the same arguments.
																																																																																																																																					super().__init__(**args)
																																																																																																																																					self.generator = Generator(**args)
																																																																																																																																					# Keep the caller-supplied 'values' entry for later use.
																																																																																																																																					self.values = args['values']
 | 
				
			||||||
 | 
																																																																					def load_meta(self, column):
																																																																																																																																					"""
																																																																																																																																					Load the meta data for ``column`` on this object (through the parent
																																																																																																																																					class implementation) and then on the embedded generator.

																																																																																																																																					:param column: passed as-is to both ``load_meta`` calls
																																																																																																																																					"""
																																																																																																																																					# Parent implementation runs first, exactly as before.
																																																																																																																																					super().load_meta(column)
																																																																																																																																					# The generator is given the same column.
																																																																																																																																					generator = self.generator
																																																																																																																																					generator.load_meta(column)
 | 
				
			||||||
 | 
																																																																					def apply(self,**args):
 | 
				
			||||||
 | 
																																																																																																																																					# print (self.train_dir)
 | 
				
			||||||
 | 
																																																																																																																																					# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
 | 
				
			||||||
 | 
																																																																																																																																					suffix = self.get.suffix()
 | 
				
			||||||
 | 
																																																																																																																																					model_dir = os.sep.join([self.train_dir,suffix+'-'+str(self.MAX_EPOCHS)])
 | 
				
			||||||
 | 
																																																																																																																																					demo = self._LABEL #np.zeros([self.ROW_COUNT,self.NUM_LABELS]) #args['de"shape":{"LABEL":list(self._LABEL.shape)} mo']
 | 
				
			||||||
 | 
																																																																																																																																					tf.compat.v1.reset_default_graph()
 | 
				
			||||||
 | 
																																																																																																																																					#z = tf.random.normal(shape=[self.BATCHSIZE_PER_GPU, self.Z_DIM])
 | 
				
			||||||
 | 
																																																																																																																																					z = tf.random.normal(shape=[self._REAL.shape[0], self.Z_DIM])
 | 
				
			||||||
 | 
																																																																																																																																					y = tf.compat.v1.placeholder(shape=[self._REAL.shape[0], self.NUM_LABELS], dtype=tf.int32)
 | 
				
			||||||
 | 
																																																																																																																																					#y = tf.compat.v1.placeholder(shape=[self.BATCHSIZE_PER_GPU, self.NUM_LABELS], dtype=tf.int32)
 | 
				
			||||||
 | 
																																																																																																																																					if self._LABEL is not None :
 | 
				
			||||||
 | 
																																																																																																																																																																																																					ma = [[i] for i in np.arange(self.NUM_LABELS - 2)]
 | 
				
			||||||
 | 
																																																																																																																																																																																																					label = y[:, 1] * len(ma) + tf.squeeze(tf.matmul(y[:, 2:], tf.constant(ma, dtype=tf.int32)))
 | 
				
			||||||
 | 
																																																																																																																																					else:
 | 
				
			||||||
 | 
																																																																																																																																																																																																					label = None
 | 
				
			||||||
 | 
																																																																																																																																					fake																						= self.generator.network(inputs=z, label=label)
 | 
				
			||||||
 | 
																																																																																																																																					init																						= tf.compat.v1.global_variables_initializer()
 | 
				
			||||||
 | 
																																																																																																																																					saver																					= tf.compat.v1.train.Saver()
 | 
				
			||||||
 | 
																																																																																																																																					df																																																																																																												= pd.DataFrame()
 | 
				
			||||||
 | 
																																																																																																																																					CANDIDATE_COUNT = 10000
 | 
				
			||||||
 | 
																																																																																																																																					NTH_VALID_CANDIDATE = count = np.random.choice(np.arange(2,60),2)[0]
 | 
				
			||||||
 | 
																																																																																																																																					with tf.compat.v1.Session() as sess:
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# sess.run(init)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					saver.restore(sess, model_dir)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					if self._LABEL is not None :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					labels = np.zeros((self.ROW_COUNT,self.NUM_LABELS) )
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					labels= demo
 | 
				
			||||||
 | 
																																																																																																																																																																																																					else:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					labels = None
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					found = []
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					for i in np.arange(CANDIDATE_COUNT) :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					if labels :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					f = sess.run(fake,feed_dict={y:labels})
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					else:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					f = sess.run(fake)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					# if we are dealing with numeric values only we can perform a simple marginal sum against the indexes
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					# The code below will insure we have some acceptable cardinal relationships between id and synthetic values
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					df =						( pd.DataFrame(np.round(f).astype(np.int32)))
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					p = 0 not in df.sum(axis=1).values
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					x = df.sum(axis=1).values
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					if np.divide( np.sum(x), x.size) > .9:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					found.append(df)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					if len(found) == NTH_VALID_CANDIDATE or i == CANDIDATE_COUNT:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																																																					break
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					else:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					continue
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# i = df.T.index.astype(np.int32) #-- These are numeric pseudonyms
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# df = (i * df).sum(axis=1)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# In case we are dealing with actual values like diagnosis codes we can perform 
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																					INDEX =np.random.choice(np.arange(len(found)),1)[0]
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#df = found[np.random.choice(np.arange(len(found)),1)[0]]
 | 
				
			||||||
 | 
																																																																																																																																																																																																					df = found[INDEX]
 | 
				
			||||||
 | 
																																																																																																																																																																																																					columns = self.ATTRIBUTES['synthetic'] if isinstance(self.ATTRIBUTES['synthetic'],list)else [self.ATTRIBUTES['synthetic']]
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# r = np.zeros((self.ROW_COUNT,len(columns)))
 | 
				
			||||||
 | 
																																																																																																																																																																																																					r = np.zeros(self.ROW_COUNT)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					df.columns = self.values
 | 
				
			||||||
 | 
																																																																																																																																																																																																					if len(found):
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					print (len(found),NTH_VALID_CANDIDATE)			
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					# x = df * self.values 
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					# let's get the rows with no values synthesized (for whatever reason)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					#
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					ii = df.apply(lambda row: np.sum(row) == 0,axis=1)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					if np.sum(ii) > 0 :
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																					missing = np.repeat(np.nan, np.where(ii==1)[0].size)
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					else:
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																					missing = []
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					print (len (missing), df.shape)	
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					i = np.where(ii == 0)[0]
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					df =						pd.DataFrame( df.iloc[i].apply(lambda row: self.values[np.random.choice(np.where(row == 1)[0],1)[0]] ,axis=1))
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					df.columns = columns
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																					df = df[columns[0]].append(pd.Series(missing))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																						
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																					tf.compat.v1.reset_default_graph()
 | 
				
			||||||
 | 
																																																																																																																																					df = pd.DataFrame(df)
 | 
				
			||||||
 | 
																																																																																																																																					df.columns = columns
 | 
				
			||||||
 | 
																																																																																																																																					print (df.head())
 | 
				
			||||||
 | 
																																																																																																																																					print (df.shape)
 | 
				
			||||||
 | 
																																																																																																																																					return df.to_dict(orient='list')
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# return df.to_dict(orient='list')
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# count = str(len(os.listdir(self.out_dir)))
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# _name = os.sep.join([self.out_dir,self.CONTEXT+'-'+count+'.csv'])
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# df.to_csv(_name,index=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# output.extend(np.round(f))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
																																																																																																																																																																																																					# for m in range(2):
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																															for n in range(2, self.NUM_LABELS):
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															idx1 = (demo[:, m] == 1)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															idx2 = (demo[:, n] == 1)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															idx = [idx1[j] and idx2[j] for j in range(len(idx1))]
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															num = np.sum(idx)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															print ("___________________list__")
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															print (idx1)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															print (idx2)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															print (idx)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															print (num)
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															print ("_____________________")
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															nbatch = int(np.ceil(num / self.BATCHSIZE_PER_GPU))
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															label_input = np.zeros((nbatch*self.BATCHSIZE_PER_GPU, self.NUM_LABELS))
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															label_input[:, n] = 1
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															label_input[:, m] = 1
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															output = []
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															for i in range(nbatch):
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																																																																																															f = sess.run(fake,feed_dict={y: label_input[i* self.BATCHSIZE_PER_GPU:(i+1)* self.BATCHSIZE_PER_GPU]})
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																																																																																															output.extend(np.round(f))
 | 
				
			||||||
 | 
																																																																																																																																																																																																					#																																																																																																																																															output = np.array(output)[:num]
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					# print ([m,n,output])
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					
 | 
				
			||||||
 | 
																																																																																																																																																																																																																																																																																																																																					# np.save(self.out_dir + str(m) + str(n), output)
 | 
				
			||||||
 | 
																																																																					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__' :
						#
						# Command line entry point. The behaviour is driven by the keys found in SYS_ARGS :
						#	raw-data	path of the csv file to load (required)
						#	column		name of the column to learn from / synthesize (required)
						#	id		column (or comma separated columns) used to build the labels, defaults to person_id
						#	max_epochs	number of training epochs (train/learn only), defaults to 10
						#	train|learn	when present a Train object is built and applied
						#	generate	when present a Predict object is built and applied
						# Any other combination prints the keys that were received and the module documentation
						#
						missing = [key for key in ('column','raw-data') if key not in SYS_ARGS]
						if missing :
							#
							# These keys are read unconditionally below; without this guard the script
							# would die on a bare KeyError and the usage text would never be displayed
							print (SYS_ARGS.keys())
							print (__doc__)
							raise SystemExit('Missing required argument(s): ' + ', '.join(missing))

						column = SYS_ARGS['column']
						column_id = SYS_ARGS['id'] if 'id' in SYS_ARGS else 'person_id'
						column_id = column_id.split(',') if ',' in column_id else column_id
						#
						# The file is read and the labels are encoded only once, both modes share the result
						df = pd.read_csv(SYS_ARGS['raw-data'])
						LABEL = pd.get_dummies(df[column_id]).astype(np.float32).values
						#
						# The context is the name of the input file without its extension (whatever its length)
						context = os.path.splitext(os.path.basename(SYS_ARGS['raw-data']))[0]
						if set(['train','learn']) & set(SYS_ARGS.keys()):
							#
							# We should train upon this data
							# -- the column is converted to a binary (one-hot) matrix before it is handed to the trainer
							#
							max_epochs = np.int32(SYS_ARGS['max_epochs']) if 'max_epochs' in SYS_ARGS else 10
							REAL = pd.get_dummies(df[column]).astype(np.float32).values
							trainer = Train(context=context,max_epochs=max_epochs,real=REAL,label=LABEL,column=column,column_id=column_id)
							trainer.apply()
						elif 'generate' in SYS_ARGS:
							#
							# The sorted distinct values of the column are handed to the predictor
							values = sorted(df[column].unique().tolist())
							p = Predict(context=context,label=LABEL,values=values,column=column)
							p.load_meta(column)
							r = p.apply()
							#
							# The data-frame is displayed before and after the column is overwritten with the synthetic values
							print (df)
							print ()
							df[column] = r[column]
							print (df)
						else:
							print (SYS_ARGS.keys())
							print (__doc__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					Loading…
					
					
				
		Reference in new issue