bug fix: checkpoints

Branch: dev
Author: Steve Nyemba, 2 years ago
parent 4be340ec08
commit 209a7b8ee5

@@ -103,7 +103,7 @@ class GNet :
 CHECKPOINT_SKIPS = 10
 if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
     CHECKPOINT_SKIPS = 2
-self.CHECKPOINTS = np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
+self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
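For context, a minimal sketch of what the new CHECKPOINTS expression evaluates to when run outside the class. MAX_EPOCHS and CHECKPOINT_SKIPS mirror the names in the hunk; the value 100 is purely illustrative, not from the commit.

import numpy as np

MAX_EPOCHS = 100        # illustrative value, not from the commit
CHECKPOINT_SKIPS = 10   # default used in the hunk above

# Old expression: evenly spaced checkpoint epochs only.
old = np.repeat(np.divide(MAX_EPOCHS, CHECKPOINT_SKIPS), CHECKPOINT_SKIPS).cumsum().astype(int).tolist()
# old == [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

# New expression: epoch 1 and the final epoch are always checkpointed as well.
new = [1, MAX_EPOCHS] + np.repeat(np.divide(MAX_EPOCHS, CHECKPOINT_SKIPS), CHECKPOINT_SKIPS).cumsum().astype(int).tolist()
# new == [1, 100, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]  (12 entries, unsorted, MAX_EPOCHS appears twice)

The length of this list (12 in this sketch) is what the next hunk passes to the saver as max_to_keep.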
@@ -529,7 +529,7 @@ class Train (GNet):
 train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d)
 train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
 # saver = tf.train.Saver()
-saver = tf.compat.v1.train.Saver()
+saver = tf.compat.v1.train.Saver(max_to_keep=len(self.CHECKPOINTS))
 # init = tf.global_variables_initializer()
 init = tf.compat.v1.global_variables_initializer()
 logs = []
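The reason max_to_keep matters: tf.compat.v1.train.Saver defaults to max_to_keep=5 and silently deletes older checkpoint files, which is the behaviour the Stack Overflow link in the next hunk refers to. Below is a minimal, self-contained sketch of that fix under assumptions not taken from the repo (the dummy variable v, the /tmp/gnet-demo path, and the CHECKPOINTS list are illustrative):

import os
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Illustrative checkpoint schedule, matching the sketch after the first hunk.
CHECKPOINTS = [1, 100, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

v = tf.compat.v1.get_variable('v', shape=[1])   # dummy variable so there is something to save
init = tf.compat.v1.global_variables_initializer()

# max_to_keep=len(CHECKPOINTS) keeps one checkpoint file per planned epoch
# instead of only the five most recent ones.
saver = tf.compat.v1.train.Saver(max_to_keep=len(CHECKPOINTS))

os.makedirs('/tmp/gnet-demo', exist_ok=True)    # hypothetical output directory
with tf.compat.v1.Session() as sess:
    sess.run(init)
    for epoch in CHECKPOINTS:
        saver.save(sess, '/tmp/gnet-demo/model', global_step=epoch)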
@@ -587,7 +587,9 @@ class Train (GNet):
 tf.compat.v1.reset_default_graph()
 #
 # let's sort the epochs we've logged thus far (if any)
+# Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted
 #
+# self.logs['epochs'] = self.logs['epochs'][-5:]
 self.logs['epochs'].sort(key=lambda _item: _item['loss'])
 if self.logger :
     _log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']}
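The sort is what replaces the old "keep the last five" slice: ordering the logged epochs by ascending loss puts the best checkpoint first, regardless of when it occurred. A small sketch of that idea; the entry layout ({'epoch': ..., 'loss': ...}) is an assumption for illustration, and only the 'loss' key is implied by the diff.

# Hypothetical log entries; only the 'loss' key is implied by the diff.
logs = {'epochs': [
    {'epoch': 100, 'loss': 0.42},
    {'epoch': 50,  'loss': 0.31},
    {'epoch': 1,   'loss': 1.97},
]}

# Ascending sort by loss, as in the hunk above: the lowest-loss (best)
# checkpoint ends up first, so callers can simply take index 0 instead
# of slicing off the last five entries.
logs['epochs'].sort(key=lambda _item: _item['loss'])
print(logs['epochs'][0])    # -> {'epoch': 50, 'loss': 0.31}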
