From d469a4904fb5aaa090948ead3172c2d0eeb326f4 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Wed, 9 Nov 2022 14:28:34 -0600 Subject: [PATCH] fixes with new features --- data/gan.py | 2 ++ setup.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/data/gan.py b/data/gan.py index eaf5124..d2cc3ea 100644 --- a/data/gan.py +++ b/data/gan.py @@ -101,6 +101,8 @@ class GNet : self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000) self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10) + + CHECKPOINT_SKIPS = 1 if CHECKPOINT_SKIPS < 1 else CHECKPOINT_SKIPS # if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS : # CHECKPOINT_SKIPS = 2 # self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() diff --git a/setup.py b/setup.py index 3a2aaba..6327b10 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import sys def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() -args = {"name":"data-maker","version":"1.6.2", +args = {"name":"data-maker","version":"1.6.3", "author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT", "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} args["install_requires"] = ['data-transport@git+https://github.com/lnyemba/data-transport.git','tensorflow']