| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -59,8 +59,8 @@ class GNet :
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.logs = {}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                if self.NUM_GPUS > 1 :
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    os.environ['CUDA_VISIBLE_DEVICES'] = "4"
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                # if self.NUM_GPUS > 1 :
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                #     os.environ['CUDA_VISIBLE_DEVICES'] = "4"
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE]
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				 | 
				
					@ -78,9 +78,12 @@ class GNet :
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                if 'real' in args : 
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                # self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.BATCHSIZE_PER_GPU = 3000 if 'batch_size' not in args else int(args['batch_size'])
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                PROPOSED_BATCH_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                if args['real'].shape[0]  < PROPOSED_BATCH_PER_GPU :
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                        self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) 
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                else:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                        self.BATCHSIZE_PER_GPU = PROPOSED_BATCH_PER_GPU
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                # self.BATCHSIZE_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)       
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -254,7 +257,7 @@ class Generator (GNet):
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                x               = args['inputs']
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                tmp_dim = self.Z_DIM if 'dim' not in args else args['dim']
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                label   = args['label']
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                print (self.NUM_LABELS)
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                with tf.compat.v1.variable_scope('G', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)):
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                        for i, dim in enumerate(self.G_STRUCTURE[:-1]):
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                kernel = self.get.variables(name='W_' + str(i), shape=[tmp_dim, dim])
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |