From 209a7b8ee5c04f094efa8ef33841e8464fd3f52c Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Fri, 16 Sep 2022 22:39:25 -0500 Subject: [PATCH] bug fix: checkpoints --- data/gan.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/data/gan.py b/data/gan.py index f864dbf..dae6ea0 100644 --- a/data/gan.py +++ b/data/gan.py @@ -103,7 +103,7 @@ class GNet : CHECKPOINT_SKIPS = 10 if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS : CHECKPOINT_SKIPS = 2 - self.CHECKPOINTS = np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() + self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() @@ -529,7 +529,7 @@ class Train (GNet): train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d) train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g) # saver = tf.train.Saver() - saver = tf.compat.v1.train.Saver() + saver = tf.compat.v1.train.Saver(max_to_keep=len(self.CHECKPOINTS)) # init = tf.global_variables_initializer() init = tf.compat.v1.global_variables_initializer() logs = [] @@ -564,7 +564,7 @@ class Train (GNet): # if epoch % self.MAX_EPOCHS == 0: # if epoch in [5,10,20,50,75, self.MAX_EPOCHS] : - if epoch in self.CHECKPOINTS : + if epoch in self.CHECKPOINTS : # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] suffix = self.CONTEXT #self.get.suffix() _name = os.sep.join([self.train_dir,str(epoch),suffix]) @@ -587,7 +587,9 @@ class Train (GNet): tf.compat.v1.reset_default_graph() # # let's sort the epochs we've logged thus far (if any) + # Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted # + # self.logs['epochs'] = self.logs['epochs'][-5:] self.logs['epochs'].sort(key=lambda _item: _item['loss']) if self.logger : _log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']}