checkpoint enhancement

dev
Steve Nyemba 2 years ago
parent 3b0903bd4a
commit ce594634e8

@ -100,13 +100,12 @@ class GNet :
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000) self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
CHECKPOINT_SKIPS = 10 CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10)
if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS : # if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
CHECKPOINT_SKIPS = 2 # CHECKPOINT_SKIPS = 2
self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() # self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
self.CHECKPOINTS = np.repeat(CHECKPOINT_SKIPS, self.MAX_EPOCHS/ CHECKPOINT_SKIPS).cumsum().astype(int).tolist()
self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100 self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100
self.CONTEXT = args['context'] self.CONTEXT = args['context']
self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None} self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None}
@ -469,7 +468,7 @@ class Train (GNet):
else : else :
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder) dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)
# labels_placeholder = None # labels_placeholder = None
dataset = dataset.repeat(80000) dataset = dataset.repeat(800000)
dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU) dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU)
dataset = dataset.prefetch(1) dataset = dataset.prefetch(1)
@ -560,39 +559,43 @@ class Train (GNet):
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
# print (dir (w_distance)) # print (dir (w_distance))
logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) }) # logs.append({"epoch": int(epoch),"distance":float(-w_sum/(self.STEPS_PER_EPOCH*2)) })
suffix = str(self.CONTEXT)
_name = os.sep.join([self.train_dir,str(epoch),suffix]) if epoch in self.CHECKPOINTS else ''
_logentry = {"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))}
# if epoch % self.MAX_EPOCHS == 0: # if epoch % self.MAX_EPOCHS == 0:
# if epoch in [5,10,20,50,75, self.MAX_EPOCHS] : # if epoch in [5,10,20,50,75, self.MAX_EPOCHS] :
if epoch in self.CHECKPOINTS : if epoch in self.CHECKPOINTS :
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
suffix = self.CONTEXT #self.get.suffix() # suffix = self.CONTEXT #self.get.suffix()
_name = os.sep.join([self.train_dir,str(epoch),suffix]) # _name = os.sep.join([self.train_dir,str(epoch),suffix])
# saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch) # saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch)
saver.save(sess, _name, write_meta_graph=False, global_step=np.int64(epoch)) saver.save(sess, _name, write_meta_graph=False, global_step=np.int64(epoch))
# #
# #
logs = [{"path":_name,"epochs":int(epoch),"loss":float(-w_sum/(self.STEPS_PER_EPOCH*2))}] # logs = []
if self.logger : # if self.logger :
# row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)} # # row = {"module":"gan-train","action":"epochs","input":{"logs":logs}} #,"model":pickle.dump(sess)}
# self.logger.write(row) # # self.logger.write(row)
self.logs['epochs'] += logs # self.logs['epochs'] += logs
# # #
# @TODO: # # @TODO:
# We should upload the files in the checkpoint # # We should upload the files in the checkpoint
# This would allow the learnt model to be portable to another system # # This would allow the learnt model to be portable to another system
# #
self.logs['epochs'].append(_logentry)
tf.compat.v1.reset_default_graph() tf.compat.v1.reset_default_graph()
# #
# let's sort the epochs we've logged thus far (if any) # let's sort the epochs we've logged thus far (if any)
# Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted # Take on the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted
# #
# self.logs['epochs'] = self.logs['epochs'][-5:] # self.logs['epochs'] = self.logs['epochs'][-5:]
self.logs['epochs'].sort(key=lambda _item: _item['loss'])
if self.logger : if self.logger :
_log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']} _log = {'module':'gan-train','context':self.CONTEXT,'action':'epochs','input':self.logs['epochs']}
self.logger.write(_log) self.logger.write(_log)
# #

@ -226,7 +226,7 @@ class Trainer(Learner):
self.autopilot = _args['autopilot'] if 'autopilot' in _args else False self.autopilot = _args['autopilot'] if 'autopilot' in _args else False
self.generate = None self.generate = None
self.candidates = int(_args['candidates']) if 'candidates' in _args else 1 self.candidates = int(_args['candidates']) if 'candidates' in _args else 1
self.checkpoint_skips = _args['checkpoint_skips'] if 'checkpoint_skips' in _args else None
def run(self): def run(self):
self.initalize() self.initalize()
if self._encoder is None : if self._encoder is None :
@ -242,6 +242,8 @@ class Trainer(Learner):
_args['candidates'] = self.candidates _args['candidates'] = self.candidates
if 'logger' in self.store : if 'logger' in self.store :
_args['logger'] = transport.factory.instance(**self.store['logger']) _args['logger'] = transport.factory.instance(**self.store['logger'])
if self.checkpoint_skips :
_args['checkpoint_skips'] = self.checkpoint_skips
# #
# At this point we have the binary matrix, we can initiate training # At this point we have the binary matrix, we can initiate training
# #
@ -264,8 +266,13 @@ class Trainer(Learner):
_args['gpu'] = self.gpu _args['gpu'] = self.gpu
# #
# Let us find the smallest, the item is sorted by loss ... # Let us find the smallest, the item is sorted by loss on disk
_args['network_args']['max_epochs'] = gTrain.logs['epochs'][0]['epochs'] #
_epochs = [_e for _e in gTrain.logs['epochs'] if _e['path'] != '']
_epochs.sort(key=lambda _item: _item['loss'],reverse=False)
_args['network_args']['max_epochs'] = _epochs[0]['epochs']
self.log(action='autopilot',input={'epoch':_epochs[0]})
g = Generator(**_args) g = Generator(**_args)
# g.run() # g.run()

@ -4,10 +4,10 @@ import sys
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
args = {"name":"data-maker","version":"1.6.0", args = {"name":"data-maker","version":"1.6.2",
"author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT", "author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT",
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow'] args["install_requires"] = ['data-transport@git+https://github.com/lnyemba/data-transport.git','tensorflow']
args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git' args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'
if sys.version_info[0] == 2 : if sys.version_info[0] == 2 :

Loading…
Cancel
Save