@ -100,6 +100,13 @@ class GNet :
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . TOTAL_BATCHSIZE  =  self . BATCHSIZE_PER_GPU  *  self . NUM_GPUS 
 
					 
					 
					 
					                self . TOTAL_BATCHSIZE  =  self . BATCHSIZE_PER_GPU  *  self . NUM_GPUS 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . STEPS_PER_EPOCH  =  256  #int(np.load('ICD9/train.npy').shape[0] / 2000)        
 
					 
					 
					 
					                self . STEPS_PER_EPOCH  =  256  #int(np.load('ICD9/train.npy').shape[0] / 2000)        
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . MAX_EPOCHS  =  10  if  ' max_epochs '  not  in  args  else  int ( args [ ' max_epochs ' ] ) 
 
					 
					 
					 
					                self . MAX_EPOCHS  =  10  if  ' max_epochs '  not  in  args  else  int ( args [ ' max_epochs ' ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                CHECKPOINT_SKIPS  =  10 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                if  self . MAX_EPOCHS   <  2 * CHECKPOINT_SKIPS  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        CHECKPOINT_SKIPS  =  2 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                self . CHECKPOINTS  =  np . repeat (  np . divide ( self . MAX_EPOCHS , CHECKPOINT_SKIPS ) , CHECKPOINT_SKIPS  ) . cumsum ( ) . astype ( int ) . tolist ( )  
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . ROW_COUNT  =  args [ ' real ' ] . shape [ 0 ]  if  ' real '  in  args  else  100 
 
					 
					 
					 
					                self . ROW_COUNT  =  args [ ' real ' ] . shape [ 0 ]  if  ' real '  in  args  else  100 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . CONTEXT  =  args [ ' context ' ] 
 
					 
					 
					 
					                self . CONTEXT  =  args [ ' context ' ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . ATTRIBUTES  =  { " id " : args [ ' column_id ' ]  if  ' column_id '  in  args  else  None , " synthetic " : args [ ' column ' ]  if  ' column '  in  args  else  None } 
 
					 
					 
					 
					                self . ATTRIBUTES  =  { " id " : args [ ' column_id ' ]  if  ' column_id '  in  args  else  None , " synthetic " : args [ ' column ' ]  if  ' column '  in  args  else  None } 
 
				
			 
			
		
	
	
		
		
			
				
					
						
						
						
							
								 
							 
						
					 
					 
					@ -120,14 +127,18 @@ class GNet :
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                for  key  in  [ ' train ' , ' output ' ]  : 
 
					 
					 
					 
					                for  key  in  [ ' train ' , ' output ' ]  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        self . mkdir ( os . sep . join ( [ self . log_dir , key ] ) ) 
 
					 
					 
					 
					                        self . mkdir ( os . sep . join ( [ self . log_dir , key ] ) ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        self . mkdir  ( os . sep . join ( [ self . log_dir , key , self . CONTEXT ] ) ) 
 
					 
					 
					 
					                        self . mkdir  ( os . sep . join ( [ self . log_dir , key , self . CONTEXT ] ) ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        if  ' partition '  in  args  : 
 
					 
					 
					 
					                        # if 'partition' in args : 
 
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					 
					                                self . mkdir  ( os . sep . join ( [ self . log_dir , key , self . CONTEXT , str ( args [ ' partition ' ] ) ] ) ) 
 
					 
					 
					 
					                        #        self.mkdir (os.sep.join([self.log_dir,key,self.CONTEXT,str(args['partition'])])) 
 
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					 
					                        
 
					 
					 
					 
					 
				
			 
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					                self . train_dir   =  os . sep . join ( [ self . log_dir , ' train ' , self . CONTEXT ] )                 
 
					 
					 
					 
					                self . train_dir   =  os . sep . join ( [ self . log_dir , ' train ' , self . CONTEXT ] )                 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . out_dir  =  os . sep . join ( [ self . log_dir , ' output ' , self . CONTEXT ] ) 
 
					 
					 
					 
					                self . out_dir  =  os . sep . join ( [ self . log_dir , ' output ' , self . CONTEXT ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                if  ' partition '  in  args  : 
 
					 
					 
					 
					                if  ' partition '  in  args  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        self . train_dir  =  os . sep . join ( [ self . train_dir , str ( args [ ' partition ' ] ) ] ) 
 
					 
					 
					 
					                        self . train_dir  =  os . sep . join ( [ self . train_dir , str ( args [ ' partition ' ] ) ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        self . out_dir  =  os . sep . join ( [ self . out_dir , str ( args [ ' partition ' ] ) ] ) 
 
					 
					 
					 
					                        self . out_dir  =  os . sep . join ( [ self . out_dir , str ( args [ ' partition ' ] ) ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                for  checkpoint  in  self . CHECKPOINTS  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        self . mkdir  ( os . sep . join ( [ self . train_dir , str ( checkpoint ) ] ) ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        self . mkdir  ( os . sep . join ( [ self . out_dir , str ( checkpoint ) ] ) ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # if self.logger : 
 
					 
					 
					 
					                # if self.logger : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        
 
					 
					 
					 
					                        
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                #         We will clear the logs from the data-store  
 
					 
					 
					 
					                #         We will clear the logs from the data-store  
 
				
			 
			
		
	
	
		
		
			
				
					
						
						
						
							
								 
							 
						
					 
					 
					@ -150,12 +161,13 @@ class GNet :
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        attr  =  json . loads ( ( open ( _name ) ) . read ( ) ) 
 
					 
					 
					 
					                        attr  =  json . loads ( ( open ( _name ) ) . read ( ) ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        for  key  in  attr  : 
 
					 
					 
					 
					                        for  key  in  attr  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                value  =  attr [ key ] 
 
					 
					 
					 
					                                value  =  attr [ key ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                if  not  hasattr ( self , key ) : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                        setattr ( self , key , value ) 
 
					 
					 
					 
					                                        setattr ( self , key , value ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . train_dir   =  os . sep . join ( [ self . log_dir , ' train ' , self . CONTEXT ] )                 
 
					 
					 
					 
					                self . train_dir   =  os . sep . join ( [ self . log_dir , ' train ' , self . CONTEXT ] )                 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . out_dir  =  os . sep . join ( [ self . log_dir , ' output ' , self . CONTEXT ] ) 
 
					 
					 
					 
					                self . out_dir  =  os . sep . join ( [ self . log_dir , ' output ' , self . CONTEXT ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                if  ' partition '  in  args   :
 
					 
					 
					 
					                # if 'partition' in args   :
 
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					 
					                         self . train_dir   =   os . sep . join ( [ self . train_dir , str ( args [ ' partition ' ] ) ] )
 
					 
					 
					 
					                #           self.train_dir = os.sep.join([self.train_dir,str(args['partition'])])
 
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					 
					                         self . out_dir   =   os . sep . join ( [ self . out_dir , str ( args [ ' partition ' ] ) ] )
 
					 
					 
					 
					                #           self.out_dir = os.sep.join([self.out_dir,str(args['partition'])])
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					                                
 
					 
					 
					 
					                                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        
 
					 
					 
					 
					                        
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        def  log_meta ( self , * * args )  : 
 
					 
					 
					 
					        def  log_meta ( self , * * args )  : 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -183,15 +195,24 @@ class GNet :
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                suffix  =  self . CONTEXT  #self.get.suffix() 
 
					 
					 
					 
					                suffix  =  self . CONTEXT  #self.get.suffix() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                _name  =  os . sep . join ( [ self . out_dir , ' meta- ' + suffix ] ) 
 
					 
					 
					 
					                _name  =  os . sep . join ( [ self . out_dir , ' meta- ' + suffix ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                
 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                f  =  open ( _name + ' .json ' , ' w ' ) 
 
					 
					 
					 
					                # f = open(_name+'.json','w') 
 
				
			 
			
				
				
			
		
	
		
		
			
				
					
					 
					 
					 
					                f . write ( json . dumps ( _object ) ) 
 
					 
					 
					 
					                # f.write(json.dumps(_object)) 
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # f.close() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                for  _info  in  [ { " name " : os . sep . join ( [ self . out_dir , ' meta- ' + suffix + ' .json ' ] ) , " data " : _object } , { " name " : os . sep . join ( [ self . out_dir , ' epochs.json ' ] ) , " data " : self . logs [ ' epochs ' ]  if  ' epochs '  in  self . logs  else  [ ] } ]  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        f  =  open ( _info [ ' name ' ] , ' w ' ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        f . write ( json . dumps ( _info [ ' data ' ] ) ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        f . close ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                return  _object 
 
					 
					 
					 
					                return  _object 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        def  mkdir  ( self , path ) : 
 
					 
					 
					 
					        def  mkdir  ( self , path ) : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                if  not  os . path . exists ( path )  : 
 
					 
					 
					 
					                if  not  os . path . exists ( path )  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        if  os . sep  in  path  : 
 
					 
					 
					 
					                        if  os . sep  in  path  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                pass 
 
					 
					 
					 
					                                pass 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                root  =  [ ] 
 
					 
					 
					 
					                                root  =  [ ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                for  loc  in  path . split ( os . sep )  : 
 
					 
					 
					 
					                                
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                for  loc  in  path . strip ( ) . split ( os . sep )  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                        if  loc  ==  ' '  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                root . append ( os . sep ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                        root . append ( loc ) 
 
					 
					 
					 
					                                        root . append ( loc ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                        if  not  os . path . exists ( os . sep . join ( root ) )  :                                                 
 
					 
					 
					 
					                                        if  not  os . path . exists ( os . sep . join ( root ) )  :                                                 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                os . mkdir ( os . sep . join ( root ) ) 
 
					 
					 
					 
					                                                os . mkdir ( os . sep . join ( root ) ) 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -278,8 +299,10 @@ class Generator (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                tf . compat . v1 . add_to_collection ( ' glosses ' ,  loss ) 
 
					 
					 
					 
					                tf . compat . v1 . add_to_collection ( ' glosses ' ,  loss ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                return  loss ,  loss                 
 
					 
					 
					 
					                return  loss ,  loss                 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        def  load_meta ( self ,  * * args ) : 
 
					 
					 
					 
					        def  load_meta ( self ,  * * args ) : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                super ( ) . load_meta ( * * args  )
 
					 
					 
					 
					                # super().load_meta(**args  )
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					                self . discriminator . load_meta ( * * args ) 
 
					 
					 
					 
					                self . discriminator . load_meta ( * * args ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					               
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        def  network ( self , * * args )  : 
 
					 
					 
					 
					        def  network ( self , * * args )  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                """ 
 
					 
					 
					 
					                """ 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                This  function  will  build  the  network  that  will  generate  the  synthetic  candidates 
 
					 
					 
					 
					                This  function  will  build  the  network  that  will  generate  the  synthetic  candidates 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -381,6 +404,7 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        
 
					 
					 
					 
					                        
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        self . logger . write ( { " module " : " gan-train " , " action " : " start " , " input " : { " partition " : self . PARTITION , " meta " : self . meta }  }  ) 
 
					 
					 
					 
					                        self . logger . write ( { " module " : " gan-train " , " action " : " start " , " input " : { " partition " : self . PARTITION , " meta " : self . meta }  }  ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                
 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # self.log (real_shape=list(self._REAL.shape),label_shape = self._LABEL.shape,meta_data=self.meta) 
 
					 
					 
					 
					                # self.log (real_shape=list(self._REAL.shape),label_shape = self._LABEL.shape,meta_data=self.meta) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        def  load_meta ( self ,  column ) : 
 
					 
					 
					 
					        def  load_meta ( self ,  column ) : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                """ 
 
					 
					 
					 
					                """ 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -445,7 +469,7 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                else  : 
 
					 
					 
					 
					                else  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        dataset  =  tf . data . Dataset . from_tensor_slices ( features_placeholder ) 
 
					 
					 
					 
					                        dataset  =  tf . data . Dataset . from_tensor_slices ( features_placeholder ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # labels_placeholder = None 
 
					 
					 
					 
					                # labels_placeholder = None 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                dataset  =  dataset . repeat ( 1 0000) 
 
					 
					 
					 
					                dataset  =  dataset . repeat ( 2 0000) 
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					                
 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                dataset  =  dataset . batch ( batch_size = self . BATCHSIZE_PER_GPU ) 
 
					 
					 
					 
					                dataset  =  dataset . batch ( batch_size = self . BATCHSIZE_PER_GPU ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                dataset  =  dataset . prefetch ( 1 ) 
 
					 
					 
					 
					                dataset  =  dataset . prefetch ( 1 ) 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -472,9 +496,11 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                if  self . _LABEL  is  not  None  : 
 
					 
					 
					 
					                                                if  self . _LABEL  is  not  None  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        ( real ,  label )  =  iterator . get_next ( ) 
 
					 
					 
					 
					                                                        ( real ,  label )  =  iterator . get_next ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                else : 
 
					 
					 
					 
					                                                else : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                        
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        real  =  iterator . get_next ( ) 
 
					 
					 
					 
					                                                        real  =  iterator . get_next ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        label =  None 
 
					 
					 
					 
					                                                        label =  None 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                loss ,  w  =  self . loss ( scope = scope ,  stage = stage ,  real = real ,  label = label ) 
 
					 
					 
					 
					                                                loss ,  w  =  self . loss ( scope = scope ,  stage = stage ,  real = real ,  label = label ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                #tf.get_variable_scope().reuse_variables() 
 
					 
					 
					 
					                                                #tf.get_variable_scope().reuse_variables() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                tf . compat . v1 . get_variable_scope ( ) . reuse_variables ( ) 
 
					 
					 
					 
					                                                tf . compat . v1 . get_variable_scope ( ) . reuse_variables ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                #vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) 
 
					 
					 
					 
					                                                #vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) 
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -507,6 +533,7 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        # init    = tf.global_variables_initializer() 
 
					 
					 
					 
					                        # init    = tf.global_variables_initializer() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        init     =  tf . compat . v1 . global_variables_initializer ( ) 
 
					 
					 
					 
					                        init     =  tf . compat . v1 . global_variables_initializer ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        logs  =  [ ] 
 
					 
					 
					 
					                        logs  =  [ ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        self . logs [ ' epochs ' ]  =  [ ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: 
 
					 
					 
					 
					                        #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        with  tf . compat . v1 . Session ( config = tf . compat . v1 . ConfigProto ( allow_soft_placement = True ,  log_device_placement = False ) )  as  sess : 
 
					 
					 
					 
					                        with  tf . compat . v1 . Session ( config = tf . compat . v1 . ConfigProto ( allow_soft_placement = True ,  log_device_placement = False ) )  as  sess : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                
 
					 
					 
					 
					                                
 
				
			 
			
		
	
	
		
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
					 
					@ -536,25 +563,41 @@ class Train (GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                        logs . append ( { " epoch " :  int ( epoch ) , " distance " : float ( - w_sum / ( self . STEPS_PER_EPOCH * 2 ) )  } ) 
 
					 
					 
					 
					                                        logs . append ( { " epoch " :  int ( epoch ) , " distance " : float ( - w_sum / ( self . STEPS_PER_EPOCH * 2 ) )  } ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					
 
					 
					 
					 
					
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                        # if epoch % self.MAX_EPOCHS == 0: 
 
					 
					 
					 
					                                        # if epoch % self.MAX_EPOCHS == 0: 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                        if  epoch  in  [ 5 , 10 , 20 , 50 , 75 ,  self . MAX_EPOCHS ]  : 
 
					 
					 
					 
					                                        # if epoch in [5,10,20,50,75, self.MAX_EPOCHS] : 
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                        if  epoch  in  self . CHECKPOINTS   or  int ( epoch )  ==  1 : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] 
 
					 
					 
					 
					                                                # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                suffix  =  self . CONTEXT  #self.get.suffix() 
 
					 
					 
					 
					                                                suffix  =  self . CONTEXT  #self.get.suffix() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                _name   =  os . sep . join ( [ self . train_dir , suffix ] ) 
 
					 
					 
					 
					                                                _name   =  os . sep . join ( [ self . train_dir , str ( epoch ) , suffix ] ) 
 
				
			 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 
					 
					 
					                                                # saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch) 
 
					 
					 
					 
					                                                # saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                saver . save ( sess ,  _name ,  write_meta_graph = False ,  global_step = epoch ) 
 
					 
					 
					 
					                                                saver . save ( sess ,  _name ,  write_meta_graph = False ,  global_step = epoch ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                # 
 
					 
					 
					 
					                                                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                # 
 
					 
					 
					 
					                                                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                if  self . logger  : 
 
					 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        row  =  { " module " : " gan-train " , " action " : " logs " , " input " : { " partition " : self . PARTITION , " logs " : logs } }  #,"model":pickle.dump(sess)}                                                         
 
					 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        self . logger . write ( row ) 
 
					 
					 
					 
					 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                               
 
					 
					 
					 
					                                               
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                logs  =  [ { " path " : _name , " epochs " : int ( epoch ) , " loss " : float ( - w_sum / ( self . STEPS_PER_EPOCH * 2 ) ) } ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                if  self . logger  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                        # row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)}                                                         
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                        # self.logger.write(row) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                                                        self . logs [ ' epochs ' ]  + =  logs 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        # 
 
					 
					 
					 
					                                                        # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        # @TODO: 
 
					 
					 
					 
					                                                        # @TODO: 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        # We should upload the files in the checkpoint  
 
					 
					 
					 
					                                                        # We should upload the files in the checkpoint  
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        # This would allow the learnt model to be portable to another system 
 
					 
					 
					 
					                                                        # This would allow the learnt model to be portable to another system 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                                                        # 
 
					 
					 
					 
					                                                        # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                        tf . compat . v1 . reset_default_graph ( ) 
 
					 
					 
					 
					                        tf . compat . v1 . reset_default_graph ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # let's sort the epochs we've logged thus far (if any) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                self . logs [ ' epochs ' ] . sort ( key = lambda  _item :  _item [ ' loss ' ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                if  self . logger  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        _log  =  { ' module ' : ' gan-train ' , ' action ' : ' epochs ' , ' input ' : self . logs [ ' epochs ' ] } 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        self . logger . write ( _log ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                
 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # @TODO: 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # Make another copy of this on disk to be able to load it should we not have a logger setup 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                self . log_meta ( ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					class  Predict ( GNet ) : 
 
					 
					 
					 
					class  Predict ( GNet ) : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        """ 
 
					 
					 
					 
					        """ 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        This  class  uses  synthetic  data  given  a  learned  model 
 
					 
					 
					 
					        This  class  uses  synthetic  data  given  a  learned  model 
 
				
			 
			
		
	
	
		
		
			
				
					
						
						
						
							
								 
							 
						
					 
					 
					@ -565,6 +608,7 @@ class Predict(GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . values      =  args [ ' values ' ] 
 
					 
					 
					 
					                self . values      =  args [ ' values ' ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . ROW_COUNT   =  args [ ' row_count ' ] 
 
					 
					 
					 
					                self . ROW_COUNT   =  args [ ' row_count ' ] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . oROW_COUNT  =  self . ROW_COUNT 
 
					 
					 
					 
					                self . oROW_COUNT  =  self . ROW_COUNT 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # self.MISSING_VALUES = np.nan_to_num(np.nan) 
 
					 
					 
					 
					                # self.MISSING_VALUES = np.nan_to_num(np.nan) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # if 'no_value' in args and args['no_value'] not in ['na','','NA'] : 
 
					 
					 
					 
					                # if 'no_value' in args and args['no_value'] not in ['na','','NA'] : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                #         self.MISSING_VALUES = args['no_value'] 
 
					 
					 
					 
					                #         self.MISSING_VALUES = args['no_value'] 
 
				
			 
			
		
	
	
		
		
			
				
					
						
						
						
							
								 
							 
						
					 
					 
					@ -577,9 +621,20 @@ class Predict(GNet):
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                super ( ) . load_meta ( * * args ) 
 
					 
					 
					 
					                super ( ) . load_meta ( * * args ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . generator . load_meta ( * * args ) 
 
					 
					 
					 
					                self . generator . load_meta ( * * args ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                self . ROW_COUNT  =  self . oROW_COUNT 
 
					 
					 
					 
					                self . ROW_COUNT  =  self . oROW_COUNT 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # updating the input/output for the generator, so it points properly 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                for  object  in  [ self , self . generator ]  : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        _train_dir  =  os . sep . join ( [ self . log_dir , ' train ' , self . CONTEXT , str ( self . MAX_EPOCHS ) ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        _out_dir =  os . sep . join ( [ self . log_dir , ' output ' , self . CONTEXT , str ( self . MAX_EPOCHS ) ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        setattr ( object , ' train_dir ' , _train_dir ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                        setattr ( object , ' out_dir ' , _out_dir )                 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					        def  apply ( self , * * args ) : 
 
					 
					 
					 
					        def  apply ( self , * * args ) : 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                suffix  =  self . CONTEXT  #self.get.suffix() 
 
					 
					 
					 
					                suffix  =  self . CONTEXT  #self.get.suffix() 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                model_dir  =  os . sep . join ( [ self . train_dir , suffix + ' - ' + str ( self . MAX_EPOCHS ) ] ) 
 
					 
					 
					 
					                model_dir  =  os . sep . join ( [ self . train_dir , suffix + ' - ' + str ( self . MAX_EPOCHS ) ] ) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					                # model_dir = os.sep.join([self.train_dir,str(self.MAX_EPOCHS)]) 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					 
					 
					 
					 
					               
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                demo  =  self . _LABEL  #np.zeros([self.ROW_COUNT,self.NUM_LABELS]) #args['de"shape":{"LABEL":list(self._LABEL.shape)} mo'] 
 
					 
					 
					 
					                demo  =  self . _LABEL  #np.zeros([self.ROW_COUNT,self.NUM_LABELS]) #args['de"shape":{"LABEL":list(self._LABEL.shape)} mo'] 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # 
 
					 
					 
					 
					                # 
 
				
			 
			
		
	
		
		
			
				
					
					 
					 
					 
					                # setup computational graph 
 
					 
					 
					 
					                # setup computational graph