| 
						
						
							
								
							
						
						
					 | 
					 | 
					@ -14,6 +14,7 @@ import sys
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					from data.params import SYS_ARGS
 | 
					 | 
					 | 
					 | 
					from data.params import SYS_ARGS
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					from data.bridge import Binary
 | 
					 | 
					 | 
					 | 
					from data.bridge import Binary
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					import json
 | 
					 | 
					 | 
					 | 
					import json
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					import pickle
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 | 
					 | 
					 | 
					 | 
					os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					os.environ['CUDA_VISIBLE_DEVICES'] = "0"
 | 
					 | 
					 | 
					 | 
					os.environ['CUDA_VISIBLE_DEVICES'] = "0"
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -38,7 +39,7 @@ class GNet :
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.layers.normalize = self.normalize
 | 
					 | 
					 | 
					 | 
					        self.layers.normalize = self.normalize
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.NUM_GPUS = 1
 | 
					 | 
					 | 
					 | 
					        self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					       
 | 
					 | 
					 | 
					 | 
					       
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
 | 
					 | 
					 | 
					 | 
					        self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -64,8 +65,8 @@ class GNet :
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.get = void()
 | 
					 | 
					 | 
					 | 
					        self.get = void()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.get.variables = self._variable_on_cpu
 | 
					 | 
					 | 
					 | 
					        self.get.variables = self._variable_on_cpu
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
 | 
					 | 
					 | 
					 | 
					        self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.logger = args['logger'] if 'logger' in args and args['logger'] else None
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.init_logs(**args)
 | 
					 | 
					 | 
					 | 
					        self.init_logs(**args)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def init_logs(self,**args):
 | 
					 | 
					 | 
					 | 
					    def init_logs(self,**args):
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					 | 
					@ -98,7 +99,7 @@ class GNet :
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                
 | 
					 | 
					 | 
					 | 
					                
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            
 | 
					 | 
					 | 
					 | 
					            
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def log_meta(self,**args) :
 | 
					 | 
					 | 
					 | 
					    def log_meta(self,**args) :
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        object = {
 | 
					 | 
					 | 
					 | 
					        _object = {
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            'CONTEXT':self.CONTEXT,
 | 
					 | 
					 | 
					 | 
					            'CONTEXT':self.CONTEXT,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            'ATTRIBUTES':self.ATTRIBUTES,
 | 
					 | 
					 | 
					 | 
					            'ATTRIBUTES':self.ATTRIBUTES,
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
 | 
					 | 
					 | 
					 | 
					            'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -120,7 +121,8 @@ class GNet :
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        _name = os.sep.join([self.out_dir,'meta-'+suffix])
 | 
					 | 
					 | 
					 | 
					        _name = os.sep.join([self.out_dir,'meta-'+suffix])
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        
 | 
					 | 
					 | 
					 | 
					        
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        f = open(_name+'.json','w')
 | 
					 | 
					 | 
					 | 
					        f = open(_name+'.json','w')
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        f.write(json.dumps(object))
 | 
					 | 
					 | 
					 | 
					        f.write(json.dumps(_object))
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					        return _object
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def mkdir (self,path):
 | 
					 | 
					 | 
					 | 
					    def mkdir (self,path):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        if not os.path.exists(path) :
 | 
					 | 
					 | 
					 | 
					        if not os.path.exists(path) :
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            os.mkdir(path)        
 | 
					 | 
					 | 
					 | 
					            os.mkdir(path)        
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					 | 
					@ -295,7 +297,7 @@ class Train (GNet):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.column = args['column']
 | 
					 | 
					 | 
					 | 
					        self.column = args['column']
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        # print ([" *** ",self.BATCHSIZE_PER_GPU])
 | 
					 | 
					 | 
					 | 
					        # print ([" *** ",self.BATCHSIZE_PER_GPU])
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        
 | 
					 | 
					 | 
					 | 
					        
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        self.log_meta()
 | 
					 | 
					 | 
					 | 
					        self.meta = self.log_meta()
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    def load_meta(self, column):
 | 
					 | 
					 | 
					 | 
					    def load_meta(self, column):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        """
 | 
					 | 
					 | 
					 | 
					        """
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					        This function will delegate the calls to load meta data to it's dependents
 | 
					 | 
					 | 
					 | 
					        This function will delegate the calls to load meta data to it's dependents
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					 | 
					@ -393,7 +395,7 @@ class Train (GNet):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            # saver = tf.train.Saver()
 | 
					 | 
					 | 
					 | 
					            # saver = tf.train.Saver()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            saver   = tf.compat.v1.train.Saver()
 | 
					 | 
					 | 
					 | 
					            saver   = tf.compat.v1.train.Saver()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            init    = tf.global_variables_initializer()
 | 
					 | 
					 | 
					 | 
					            init    = tf.global_variables_initializer()
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					            logs = []
 | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					            with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
 | 
					 | 
					 | 
					 | 
					            with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                sess.run(init)
 | 
					 | 
					 | 
					 | 
					                sess.run(init)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                sess.run(iterator_d.initializer,
 | 
					 | 
					 | 
					 | 
					                sess.run(iterator_d.initializer,
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -415,6 +417,10 @@ class Train (GNet):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    format_str = 'epoch: %d, w_distance = %f (%.1f)'
 | 
					 | 
					 | 
					 | 
					                    format_str = 'epoch: %d, w_distance = %f (%.1f)'
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
 | 
					 | 
					 | 
					 | 
					                    print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                    # print (dir (w_distance))
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                    logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                    if epoch % self.MAX_EPOCHS == 0:
 | 
					 | 
					 | 
					 | 
					                    if epoch % self.MAX_EPOCHS == 0:
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
 | 
					 | 
					 | 
					 | 
					                        # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        suffix = self.get.suffix()
 | 
					 | 
					 | 
					 | 
					                        suffix = self.get.suffix()
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					 | 
					@ -423,6 +429,10 @@ class Train (GNet):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
 | 
					 | 
					 | 
					 | 
					                        saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        #
 | 
					 | 
					 | 
					 | 
					                        #
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					                        #
 | 
					 | 
					 | 
					 | 
					                        #
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                        if self.logger :
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                            row = {"logs":logs} #,"model":pickle.dump(sess)}
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                            
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					 | 
					                            self.logger.write(row=row)
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					
 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					class Predict(GNet):
 | 
					 | 
					 | 
					 | 
					class Predict(GNet):
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					 | 
					    """
 | 
					 | 
					 | 
					 | 
					    """
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					 | 
					
 
 |