@ -59,20 +59,27 @@ class GNet :
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . logs  =  { } 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . NUM_GPUS  =  1  if  ' num_gpu '  not  in  args  else  args [ ' num_gpu ' ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					       
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                if  self . NUM_GPUS  >  1  : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                    os . environ [ ' CUDA_VISIBLE_DEVICES ' ]  =  " 4 " 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . X_SPACE_SIZE  =  args [ ' real ' ] . shape [ 1 ]  if  ' real '  in  args  else  854 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . G_STRUCTURE  =  [ 128 , 128 ]  #[self.X_SPACE_SIZE, self.X_SPACE_SIZE] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . D_STRUCTURE  =  [ self . X_SPACE_SIZE , 256 , 128 ]  #[self.X_SPACE_SIZE, self.X_SPACE_SIZE*2, self.X_SPACE_SIZE] #-- change 854 to number of diagnosis 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                # self.NUM_LABELS         = 8 if 'label' not in args elif len(args['label'].shape) args['label'].shape[1] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                if  ' label '  in  args  and  len ( args [ ' label ' ] . shape )  ==  2  : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                        self . NUM_LABELS  =  args [ ' label ' ] . shape [ 1 ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                elif  ' label '  in  args  and  len ( args [ ' label ' ] )  ==  1  : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                        self . NUM_LABELS  =  args [ ' label ' ] . shape [ 0 ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                else : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                        self . NUM_LABELS  =  8 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . Z_DIM  =  128  #self.X_SPACE_SIZE      
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . BATCHSIZE_PER_GPU  =  args [ ' real ' ] . shape [ 0 ]  if  ' real '  in  args  else  256 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                # self.Z_DIM = 128 #self.X_SPACE_SIZE      
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . Z_DIM  =  128   #-- used as rows down stream 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . G_STRUCTURE  =  [ self . Z_DIM , self . Z_DIM ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                if  ' real '  in  args  :  
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                self . D_STRUCTURE  =  [ args [ ' real ' ] . shape [ 1 ] , 256 , self . Z_DIM ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . BATCHSIZE_PER_GPU  =  int ( args [ ' real ' ] . shape [ 0 ] *  1 )  if  ' real '  in  args  else  256 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . TOTAL_BATCHSIZE  =  self . BATCHSIZE_PER_GPU  *  self . NUM_GPUS 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . STEPS_PER_EPOCH  =  256  #int(np.load('ICD9/train.npy').shape[0] / 2000)        
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . MAX_EPOCHS  =  10  if  ' max_epochs '  not  in  args  else  int ( args [ ' max_epochs ' ] ) 
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -533,6 +540,8 @@ class Predict(GNet):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                # The code below will insure we have some acceptable cardinal relationships between id and synthetic values 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                # 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                df  =   (  pd . DataFrame ( np . round ( f ) . astype ( np . int32 ) ) ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                print  ( df . head ( ) ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                print  ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                p  =  0  not  in  df . sum ( axis = 1 ) . values 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                if       p :