|
|
|
@ -100,12 +100,11 @@ class GNet :
|
|
|
|
|
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
|
|
|
|
|
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
|
|
|
|
|
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
|
|
|
|
|
CHECKPOINT_SKIPS = 10
|
|
|
|
|
if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
|
|
|
|
|
CHECKPOINT_SKIPS = 2
|
|
|
|
|
self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10)
|
|
|
|
|
# if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
|
|
|
|
|
# CHECKPOINT_SKIPS = 2
|
|
|
|
|
# self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
|
|
|
|
self.CHECKPOINTS = np.repeat(CHECKPOINT_SKIPS, self.MAX_EPOCHS/ CHECKPOINT_SKIPS).cumsum().astype(int).tolist()
|
|
|
|
|
|
|
|
|
|
self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100
|
|
|
|
|
self.CONTEXT = args['context']
|
|
|
|
@ -469,7 +468,7 @@ class Train (GNet):
|
|
|
|
|
else :
|
|
|
|
|
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)
|
|
|
|
|
# labels_placeholder = None
|
|
|
|
|
dataset = dataset.repeat(80000)
|
|
|
|
|
dataset = dataset.repeat(800000)
|
|
|
|
|
|
|
|
|
|
dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU)
|
|
|
|
|
dataset = dataset.prefetch(1)
|
|
|
|
@ -560,39 +559,43 @@ class Train (GNet):
|
|
|
|
|
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
|
|
|
|
|
# print (dir (w_distance))
|
|
|
|
|
|
|
|
|
|
logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) })
|
|
|
|
|
# logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) })
|
|
|
|
|
|
|
|
|
|
suffix = str(self.CONTEXT)
|
|
|
|
|
_name = os.sep.join([self.train_dir,str(epoch),suffix]) if epoch in self.CHECKPOINTS else ''
|
|
|
|
|
_logentry = {"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))}
|
|
|
|
|
# if epoch % self.MAX_EPOCHS == 0:
|
|
|
|
|
# if epoch in [5,10,20,50,75, self.MAX_EPOCHS] :
|
|
|
|
|
if epoch in self.CHECKPOINTS :
|
|
|
|
|
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
|
|
|
|
suffix = self.CONTEXT #self.get.suffix()
|
|
|
|
|
_name = os.sep.join([self.train_dir,str(epoch),suffix])
|
|
|
|
|
# suffix = self.CONTEXT #self.get.suffix()
|
|
|
|
|
# _name = os.sep.join([self.train_dir,str(epoch),suffix])
|
|
|
|
|
# saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch)
|
|
|
|
|
saver.save(sess, _name, write_meta_graph=False, global_step=np.int64(epoch))
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
logs = [{"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))}]
|
|
|
|
|
if self.logger :
|
|
|
|
|
# row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)}
|
|
|
|
|
# self.logger.write(row)
|
|
|
|
|
self.logs['epochs'] += logs
|
|
|
|
|
#
|
|
|
|
|
# @TODO:
|
|
|
|
|
# We should upload the files in the checkpoint
|
|
|
|
|
# This would allow the learnt model to be portable to another system
|
|
|
|
|
# logs = []
|
|
|
|
|
# if self.logger :
|
|
|
|
|
# # row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)}
|
|
|
|
|
# # self.logger.write(row)
|
|
|
|
|
# self.logs['epochs'] += logs
|
|
|
|
|
# #
|
|
|
|
|
# # @TODO:
|
|
|
|
|
# # We should upload the files in the checkpoint
|
|
|
|
|
# # This would allow the learnt model to be portable to another system
|
|
|
|
|
#
|
|
|
|
|
self.logs['epochs'].append(_logentry)
|
|
|
|
|
tf.compat.v1.reset_default_graph()
|
|
|
|
|
#
|
|
|
|
|
# let's sort the epochs we've logged thus far (if any)
|
|
|
|
|
# Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted
|
|
|
|
|
#
|
|
|
|
|
# self.logs['epochs'] = self.logs['epochs'][-5:]
|
|
|
|
|
self.logs['epochs'].sort(key=lambda _item: _item['loss'])
|
|
|
|
|
|
|
|
|
|
if self.logger :
|
|
|
|
|
_log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']}
|
|
|
|
|
_log = {'module':'gan-train','context':self.CONTEXT,'action':'epochs','input':self.logs['epochs']}
|
|
|
|
|
self.logger.write(_log)
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|