From ce594634e848a1956a5ff3dbd2c08a34028592de Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Tue, 11 Oct 2022 18:18:59 -0500 Subject: [PATCH] checkpoint enhancement --- data/gan.py | 49 ++++++++++++++++++++++-------------------- data/maker/__init__.py | 13 ++++++++--- setup.py | 4 ++-- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/data/gan.py b/data/gan.py index dae6ea0..eaf5124 100644 --- a/data/gan.py +++ b/data/gan.py @@ -100,13 +100,12 @@ class GNet : self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000) self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) - CHECKPOINT_SKIPS = 10 - if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS : - CHECKPOINT_SKIPS = 2 - self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() - - - + CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10) + # if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS : + # CHECKPOINT_SKIPS = 2 + # self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() + self.CHECKPOINTS = np.repeat(CHECKPOINT_SKIPS, self.MAX_EPOCHS/ CHECKPOINT_SKIPS).cumsum().astype(int).tolist() + self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100 self.CONTEXT = args['context'] self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None} @@ -469,7 +468,7 @@ class Train (GNet): else : dataset = tf.data.Dataset.from_tensor_slices(features_placeholder) # labels_placeholder = None - dataset = dataset.repeat(80000) + dataset = dataset.repeat(800000) dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU) dataset = dataset.prefetch(1) @@ -560,39 +559,43 @@ class Train (GNet): print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) # print (dir (w_distance)) - logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) }) - + # logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) }) + + suffix = str(self.CONTEXT) + _name = os.sep.join([self.train_dir,str(epoch),suffix]) if epoch in self.CHECKPOINTS else '' + _logentry = {"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))} # if epoch % self.MAX_EPOCHS == 0: # if epoch in [5,10,20,50,75, self.MAX_EPOCHS] : if epoch in self.CHECKPOINTS : # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] - suffix = self.CONTEXT #self.get.suffix() - _name = os.sep.join([self.train_dir,str(epoch),suffix]) + # suffix = self.CONTEXT #self.get.suffix() + # _name = os.sep.join([self.train_dir,str(epoch),suffix]) # saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch) saver.save(sess, _name, write_meta_graph=False, global_step=np.int64(epoch)) # # - logs = [{"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))}] - if self.logger : - # row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)} - # self.logger.write(row) - self.logs['epochs'] += logs - # - # @TODO: - # We should upload the files in the checkpoint - # This would allow the learnt model to be portable to another system + # logs = [] + # if self.logger : + # # row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)} + # # self.logger.write(row) + # self.logs['epochs'] += logs + # # + # # @TODO: + # # We should upload the files in the checkpoint + # # This would allow the learnt model to be portable to another system # + self.logs['epochs'].append(_logentry) tf.compat.v1.reset_default_graph() # # let's sort the epochs we've logged thus far (if any) # Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted # # self.logs['epochs'] = self.logs['epochs'][-5:] - self.logs['epochs'].sort(key=lambda _item: _item['loss']) + if self.logger : - _log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']} + _log = {'module':'gan-train','context':self.CONTEXT,'action':'epochs','input':self.logs['epochs']} self.logger.write(_log) # diff --git a/data/maker/__init__.py b/data/maker/__init__.py index 7f9c0f6..fdf2305 100644 --- a/data/maker/__init__.py +++ b/data/maker/__init__.py @@ -226,7 +226,7 @@ class Trainer(Learner): self.autopilot = _args['autopilot'] if 'autopilot' in _args else False self.generate = None self.candidates = int(_args['candidates']) if 'candidates' in _args else 1 - + self.checkpoint_skips = _args['checkpoint_skips'] if 'checkpoint_skips' in _args else None def run(self): self.initalize() if self._encoder is None : @@ -242,6 +242,8 @@ class Trainer(Learner): _args['candidates'] = self.candidates if 'logger' in self.store : _args['logger'] = transport.factory.instance(**self.store['logger']) + if self.checkpoint_skips : + _args['checkpoint_skips'] = self.checkpoint_skips # # At this point we have the binary matrix, we can initiate training # @@ -264,8 +266,13 @@ class Trainer(Learner): _args['gpu'] = self.gpu # - # Let us find the smallest, the item is sorted by loss ... - _args['network_args']['max_epochs'] = gTrain.logs['epochs'][0]['epochs'] + # Let us find the smallest, the item is sorted by loss on disk + # + _epochs = [_e for _e in gTrain.logs['epochs'] if _e['path'] != ''] + _epochs.sort(key=lambda _item: _item['loss'],reverse=False) + + _args['network_args']['max_epochs'] = _epochs[0]['epochs'] + self.log(action='autopilot',input={'epoch':_epochs[0]}) g = Generator(**_args) # g.run() diff --git a/setup.py b/setup.py index c28f366..3a2aaba 100644 --- a/setup.py +++ b/setup.py @@ -4,10 +4,10 @@ import sys def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() -args = {"name":"data-maker","version":"1.6.0", +args = {"name":"data-maker","version":"1.6.2", "author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT", "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} -args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow'] +args["install_requires"] = ['data-transport@git+https://github.com/lnyemba/data-transport.git','tensorflow'] args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git' if sys.version_info[0] == 2 :