|  |  | @ -100,13 +100,12 @@ class GNet : | 
			
		
	
		
		
			
				
					
					|  |  |  |                 self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS |  |  |  |                 self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS | 
			
		
	
		
		
			
				
					
					|  |  |  |                 self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)        |  |  |  |                 self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)        | 
			
		
	
		
		
			
				
					
					|  |  |  |                 self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) |  |  |  |                 self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) | 
			
		
	
		
		
			
				
					
					|  |  |  |                 CHECKPOINT_SKIPS = 10 |  |  |  |                 CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10) | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                 if self.MAX_EPOCHS  < 2*CHECKPOINT_SKIPS : |  |  |  |                 # if self.MAX_EPOCHS  < 2*CHECKPOINT_SKIPS : | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                         CHECKPOINT_SKIPS = 2 |  |  |  |                 #         CHECKPOINT_SKIPS = 2 | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                 self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()  |  |  |  |                 # self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()  | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                  |  |  |  |                 self.CHECKPOINTS = np.repeat(CHECKPOINT_SKIPS, self.MAX_EPOCHS/ CHECKPOINT_SKIPS).cumsum().astype(int).tolist() | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                  |  |  |  |                 | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  |  | 
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |                 self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100 |  |  |  |                 self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100 | 
			
		
	
		
		
			
				
					
					|  |  |  |                 self.CONTEXT = args['context'] |  |  |  |                 self.CONTEXT = args['context'] | 
			
		
	
		
		
			
				
					
					|  |  |  |                 self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None} |  |  |  |                 self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None} | 
			
		
	
	
		
		
			
				
					|  |  | @ -469,7 +468,7 @@ class Train (GNet): | 
			
		
	
		
		
			
				
					
					|  |  |  |                 else : |  |  |  |                 else : | 
			
		
	
		
		
			
				
					
					|  |  |  |                         dataset = tf.data.Dataset.from_tensor_slices(features_placeholder) |  |  |  |                         dataset = tf.data.Dataset.from_tensor_slices(features_placeholder) | 
			
		
	
		
		
			
				
					
					|  |  |  |                 # labels_placeholder = None |  |  |  |                 # labels_placeholder = None | 
			
		
	
		
		
			
				
					
					|  |  |  |                 dataset = dataset.repeat(80000) |  |  |  |                 dataset = dataset.repeat(800000) | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |                  |  |  |  |                  | 
			
		
	
		
		
			
				
					
					|  |  |  |                 dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU) |  |  |  |                 dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU) | 
			
		
	
		
		
			
				
					
					|  |  |  |                 dataset = dataset.prefetch(1) |  |  |  |                 dataset = dataset.prefetch(1) | 
			
		
	
	
		
		
			
				
					|  |  | @ -560,39 +559,43 @@ class Train (GNet): | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) |  |  |  |                                         print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         # print (dir (w_distance)) |  |  |  |                                         # print (dir (w_distance)) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) }) |  |  |  |                                         # logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) }) | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  |                                          | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                                         suffix = str(self.CONTEXT) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                                         _name  = os.sep.join([self.train_dir,str(epoch),suffix]) if epoch in self.CHECKPOINTS else '' | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                                         _logentry = {"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))} | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         # if epoch % self.MAX_EPOCHS == 0: |  |  |  |                                         # if epoch % self.MAX_EPOCHS == 0: | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         # if epoch in [5,10,20,50,75, self.MAX_EPOCHS] : |  |  |  |                                         # if epoch in [5,10,20,50,75, self.MAX_EPOCHS] : | 
			
		
	
		
		
			
				
					
					|  |  |  |                                         if epoch in self.CHECKPOINTS : |  |  |  |                                         if epoch in self.CHECKPOINTS : | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                 # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] |  |  |  |                                                 # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                 suffix = self.CONTEXT #self.get.suffix() |  |  |  |                                                 # suffix = self.CONTEXT #self.get.suffix() | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                 _name  = os.sep.join([self.train_dir,str(epoch),suffix]) |  |  |  |                                                 # _name  = os.sep.join([self.train_dir,str(epoch),suffix]) | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |                                                 # saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch) |  |  |  |                                                 # saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch) | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                 saver.save(sess, _name, write_meta_graph=False, global_step=np.int64(epoch)) |  |  |  |                                                 saver.save(sess, _name, write_meta_graph=False, global_step=np.int64(epoch)) | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                  |  |  |  |                                                  | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                 # |  |  |  |                                                 # | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                 # |  |  |  |                                                 # | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                 |  |  |  |                                                 | 
			
		
	
		
		
			
				
					
					|  |  |  |                                                 logs = [{"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))}] |  |  |  |                                                 # logs = [] | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                 if self.logger : |  |  |  |                                                 # if self.logger : | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                         # row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)}                                                         |  |  |  |                                                 #         # row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)}                                                         | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                         # self.logger.write(row) |  |  |  |                                                 #         # self.logger.write(row) | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                         self.logs['epochs'] += logs |  |  |  |                                                 #         self.logs['epochs'] += logs | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                         # |  |  |  |                                                 #         # | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                         # @TODO: |  |  |  |                                                 #         # @TODO: | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                         # We should upload the files in the checkpoint  |  |  |  |                                                 #         # We should upload the files in the checkpoint  | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                                                         # This would allow the learnt model to be portable to another system |  |  |  |                                                 #         # This would allow the learnt model to be portable to another system | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |                                                         # |  |  |  |                                                         # | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                                         self.logs['epochs'].append(_logentry) | 
			
		
	
		
		
			
				
					
					|  |  |  |                         tf.compat.v1.reset_default_graph() |  |  |  |                         tf.compat.v1.reset_default_graph() | 
			
		
	
		
		
			
				
					
					|  |  |  |                 # |  |  |  |                 # | 
			
		
	
		
		
			
				
					
					|  |  |  |                 # let's sort the epochs we've logged thus far (if any) |  |  |  |                 # let's sort the epochs we've logged thus far (if any) | 
			
		
	
		
		
			
				
					
					|  |  |  |                 # Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted |  |  |  |                 # Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted | 
			
		
	
		
		
			
				
					
					|  |  |  |                 # |  |  |  |                 # | 
			
		
	
		
		
			
				
					
					|  |  |  |                 # self.logs['epochs']  = self.logs['epochs'][-5:] |  |  |  |                 # self.logs['epochs']  = self.logs['epochs'][-5:] | 
			
		
	
		
		
			
				
					
					|  |  |  |                 self.logs['epochs'].sort(key=lambda _item: _item['loss']) |  |  |  |                  | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |                 if self.logger : |  |  |  |                 if self.logger : | 
			
		
	
		
		
			
				
					
					|  |  |  |                         _log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']} |  |  |  |                         _log = {'module':'gan-train','context':self.CONTEXT,'action':'epochs','input':self.logs['epochs']} | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |                         self.logger.write(_log) |  |  |  |                         self.logger.write(_log) | 
			
		
	
		
		
			
				
					
					|  |  |  |                  |  |  |  |                  | 
			
		
	
		
		
			
				
					
					|  |  |  |                 # |  |  |  |                 # | 
			
		
	
	
		
		
			
				
					|  |  | 
 |