|
|
|
@ -103,7 +103,7 @@ class GNet :
|
|
|
|
|
CHECKPOINT_SKIPS = 10
|
|
|
|
|
if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
|
|
|
|
|
CHECKPOINT_SKIPS = 2
|
|
|
|
|
self.CHECKPOINTS = np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
|
|
|
|
self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -529,7 +529,7 @@ class Train (GNet):
|
|
|
|
|
train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d)
|
|
|
|
|
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
|
|
|
|
|
# saver = tf.train.Saver()
|
|
|
|
|
saver = tf.compat.v1.train.Saver()
|
|
|
|
|
saver = tf.compat.v1.train.Saver(max_to_keep=len(self.CHECKPOINTS))
|
|
|
|
|
# init = tf.global_variables_initializer()
|
|
|
|
|
init = tf.compat.v1.global_variables_initializer()
|
|
|
|
|
logs = []
|
|
|
|
@ -564,7 +564,7 @@ class Train (GNet):
|
|
|
|
|
|
|
|
|
|
# if epoch % self.MAX_EPOCHS == 0:
|
|
|
|
|
# if epoch in [5,10,20,50,75, self.MAX_EPOCHS] :
|
|
|
|
|
if epoch in self.CHECKPOINTS :
|
|
|
|
|
if epoch in self.CHECKPOINTS :
|
|
|
|
|
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
|
|
|
|
suffix = self.CONTEXT #self.get.suffix()
|
|
|
|
|
_name = os.sep.join([self.train_dir,str(epoch),suffix])
|
|
|
|
@ -587,7 +587,9 @@ class Train (GNet):
|
|
|
|
|
tf.compat.v1.reset_default_graph()
|
|
|
|
|
#
|
|
|
|
|
# let's sort the epochs we've logged thus far (if any)
|
|
|
|
|
# Take only the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted
|
|
|
|
|
#
|
|
|
|
|
# self.logs['epochs'] = self.logs['epochs'][-5:]
|
|
|
|
|
self.logs['epochs'].sort(key=lambda _item: _item['loss'])
|
|
|
|
|
if self.logger :
|
|
|
|
|
_log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']}
|
|
|
|
|