@ -79,7 +79,8 @@ class GNet :
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                if  ' real '  in  args  :  
 
					 
					 
					 
					                if  ' real '  in  args  :  
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                self . D_STRUCTURE  =  [ args [ ' real ' ] . shape [ 1 ] , 256 , self . Z_DIM ] 
 
					 
					 
					 
					                                self . D_STRUCTURE  =  [ args [ ' real ' ] . shape [ 1 ] , 256 , self . Z_DIM ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					
 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . BATCHSIZE_PER_GPU  =  int ( args [ ' real ' ] . shape [ 0 ] *  1 )  if  ' real '  in  args  else  256 
 
					 
					 
					 
					                # self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256 
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                self . BATCHSIZE_PER_GPU  =  3000  if  ' batch_size '  not  in  args  else  int ( args [ ' batch_size ' ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . TOTAL_BATCHSIZE  =  self . BATCHSIZE_PER_GPU  *  self . NUM_GPUS 
 
					 
					 
					 
					                self . TOTAL_BATCHSIZE  =  self . BATCHSIZE_PER_GPU  *  self . NUM_GPUS 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . STEPS_PER_EPOCH  =  256  #int(np.load('ICD9/train.npy').shape[0] / 2000)        
 
					 
					 
					 
					                self . STEPS_PER_EPOCH  =  256  #int(np.load('ICD9/train.npy').shape[0] / 2000)        
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . MAX_EPOCHS  =  10  if  ' max_epochs '  not  in  args  else  int ( args [ ' max_epochs ' ] ) 
 
					 
					 
					 
					                self . MAX_EPOCHS  =  10  if  ' max_epochs '  not  in  args  else  int ( args [ ' max_epochs ' ] ) 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -410,7 +411,7 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        dataset  =  tf . data . Dataset . from_tensor_slices ( features_placeholder ) 
 
					 
					 
					 
					                        dataset  =  tf . data . Dataset . from_tensor_slices ( features_placeholder ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # labels_placeholder = None 
 
					 
					 
					 
					                # labels_placeholder = None 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                dataset  =  dataset . repeat ( 10000 ) 
 
					 
					 
					 
					                dataset  =  dataset . repeat ( 10000 ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                dataset  =  dataset . batch ( batch_size = 3000 ) 
 
					 
					 
					 
					                dataset  =  dataset . batch ( batch_size = self . BATCHSIZE_PER_GPU ) 
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					                dataset  =  dataset . prefetch ( 1 ) 
 
					 
					 
					 
					                dataset  =  dataset . prefetch ( 1 ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # iterator = dataset.make_initializable_iterator() 
 
					 
					 
					 
					                # iterator = dataset.make_initializable_iterator() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                iterator  =  tf . compat . v1 . data . make_initializable_iterator ( dataset ) 
 
					 
					 
					 
					                iterator  =  tf . compat . v1 . data . make_initializable_iterator ( dataset ) 
 
				
			 
			
		
	
	
		
		
			
				
					
						
						
						
							
								 
							 
						
					 
					 
					@ -430,7 +431,8 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        ( real ,  label )  =  iterator . get_next ( ) 
 
					 
					 
					 
					                                                        ( real ,  label )  =  iterator . get_next ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                else : 
 
					 
					 
					 
					                                                else : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        real  =  iterator . get_next ( ) 
 
					 
					 
					 
					                                                        real  =  iterator . get_next ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                loss ,  w  =  self . loss ( scope = scope ,  stage = stage ,  real = self . _REAL ,  label = self . _LABEL ) 
 
					 
					 
					 
					                                                        label =  None 
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                loss ,  w  =  self . loss ( scope = scope ,  stage = stage ,  real = real ,  label = label ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                #tf.get_variable_scope().reuse_variables() 
 
					 
					 
					 
					                                                #tf.get_variable_scope().reuse_variables() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                tf . compat . v1 . get_variable_scope ( ) . reuse_variables ( ) 
 
					 
					 
					 
					                                                tf . compat . v1 . get_variable_scope ( ) . reuse_variables ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                #vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) 
 
					 
					 
					 
					                                                #vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -465,6 +467,7 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        logs  =  [ ] 
 
					 
					 
					 
					                        logs  =  [ ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: 
 
					 
					 
					 
					                        #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        with  tf . compat . v1 . Session ( config = tf . compat . v1 . ConfigProto ( allow_soft_placement = True ,  log_device_placement = False ) )  as  sess : 
 
					 
					 
					 
					                        with  tf . compat . v1 . Session ( config = tf . compat . v1 . ConfigProto ( allow_soft_placement = True ,  log_device_placement = False ) )  as  sess : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                sess . run ( init ) 
 
					 
					 
					 
					                                sess . run ( init ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                
 
					 
					 
					 
					                                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                sess . run ( iterator_d . initializer , 
 
					 
					 
					 
					                                sess . run ( iterator_d . initializer ,