|
|
@ -14,6 +14,7 @@ import sys
|
|
|
|
from data.params import SYS_ARGS
|
|
|
|
from data.params import SYS_ARGS
|
|
|
|
from data.bridge import Binary
|
|
|
|
from data.bridge import Binary
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
|
|
|
|
import pickle
|
|
|
|
|
|
|
|
|
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
|
|
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
|
|
@ -38,7 +39,7 @@ class GNet :
|
|
|
|
self.layers.normalize = self.normalize
|
|
|
|
self.layers.normalize = self.normalize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.NUM_GPUS = 1
|
|
|
|
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
|
|
|
|
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
|
|
|
@ -64,8 +65,8 @@ class GNet :
|
|
|
|
|
|
|
|
|
|
|
|
self.get = void()
|
|
|
|
self.get = void()
|
|
|
|
self.get.variables = self._variable_on_cpu
|
|
|
|
self.get.variables = self._variable_on_cpu
|
|
|
|
|
|
|
|
|
|
|
|
self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
|
|
|
self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
|
|
|
|
|
|
|
self.logger = args['logger'] if 'logger' in args and args['logger'] else None
|
|
|
|
self.init_logs(**args)
|
|
|
|
self.init_logs(**args)
|
|
|
|
|
|
|
|
|
|
|
|
def init_logs(self,**args):
|
|
|
|
def init_logs(self,**args):
|
|
|
@ -98,7 +99,7 @@ class GNet :
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_meta(self,**args) :
|
|
|
|
def log_meta(self,**args) :
|
|
|
|
object = {
|
|
|
|
_object = {
|
|
|
|
'CONTEXT':self.CONTEXT,
|
|
|
|
'CONTEXT':self.CONTEXT,
|
|
|
|
'ATTRIBUTES':self.ATTRIBUTES,
|
|
|
|
'ATTRIBUTES':self.ATTRIBUTES,
|
|
|
|
'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
|
|
|
|
'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
|
|
|
@ -120,7 +121,8 @@ class GNet :
|
|
|
|
_name = os.sep.join([self.out_dir,'meta-'+suffix])
|
|
|
|
_name = os.sep.join([self.out_dir,'meta-'+suffix])
|
|
|
|
|
|
|
|
|
|
|
|
f = open(_name+'.json','w')
|
|
|
|
f = open(_name+'.json','w')
|
|
|
|
f.write(json.dumps(object))
|
|
|
|
f.write(json.dumps(_object))
|
|
|
|
|
|
|
|
return _object
|
|
|
|
def mkdir (self,path):
|
|
|
|
def mkdir (self,path):
|
|
|
|
if not os.path.exists(path) :
|
|
|
|
if not os.path.exists(path) :
|
|
|
|
os.mkdir(path)
|
|
|
|
os.mkdir(path)
|
|
|
@ -295,7 +297,7 @@ class Train (GNet):
|
|
|
|
self.column = args['column']
|
|
|
|
self.column = args['column']
|
|
|
|
# print ([" *** ",self.BATCHSIZE_PER_GPU])
|
|
|
|
# print ([" *** ",self.BATCHSIZE_PER_GPU])
|
|
|
|
|
|
|
|
|
|
|
|
self.log_meta()
|
|
|
|
self.meta = self.log_meta()
|
|
|
|
def load_meta(self, column):
|
|
|
|
def load_meta(self, column):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
This function will delegate the calls to load meta data to it's dependents
|
|
|
|
This function will delegate the calls to load meta data to it's dependents
|
|
|
@ -393,7 +395,7 @@ class Train (GNet):
|
|
|
|
# saver = tf.train.Saver()
|
|
|
|
# saver = tf.train.Saver()
|
|
|
|
saver = tf.compat.v1.train.Saver()
|
|
|
|
saver = tf.compat.v1.train.Saver()
|
|
|
|
init = tf.global_variables_initializer()
|
|
|
|
init = tf.global_variables_initializer()
|
|
|
|
|
|
|
|
logs = []
|
|
|
|
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
|
|
|
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
|
|
|
sess.run(init)
|
|
|
|
sess.run(init)
|
|
|
|
sess.run(iterator_d.initializer,
|
|
|
|
sess.run(iterator_d.initializer,
|
|
|
@ -415,6 +417,10 @@ class Train (GNet):
|
|
|
|
|
|
|
|
|
|
|
|
format_str = 'epoch: %d, w_distance = %f (%.1f)'
|
|
|
|
format_str = 'epoch: %d, w_distance = %f (%.1f)'
|
|
|
|
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
|
|
|
|
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
|
|
|
|
|
|
|
|
# print (dir (w_distance))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
|
|
|
|
|
|
|
|
|
|
|
|
if epoch % self.MAX_EPOCHS == 0:
|
|
|
|
if epoch % self.MAX_EPOCHS == 0:
|
|
|
|
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
|
|
|
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
|
|
|
suffix = self.get.suffix()
|
|
|
|
suffix = self.get.suffix()
|
|
|
@ -423,6 +429,10 @@ class Train (GNet):
|
|
|
|
saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
|
|
|
|
saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
|
|
|
|
#
|
|
|
|
#
|
|
|
|
#
|
|
|
|
#
|
|
|
|
|
|
|
|
if self.logger :
|
|
|
|
|
|
|
|
row = {"logs":logs} #,"model":pickle.dump(sess)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.logger.write(row=row)
|
|
|
|
|
|
|
|
|
|
|
|
class Predict(GNet):
|
|
|
|
class Predict(GNet):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|