bug fix with dimensions @TODO: GPU workload

dev
Steve L. Nyemba 5 years ago
parent 4024e508a8
commit ce55848cc8

@ -59,20 +59,27 @@ class GNet :
self.logs = {} self.logs = {}
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu'] self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
if self.NUM_GPUS > 1 :
os.environ['CUDA_VISIBLE_DEVICES'] = "4"
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854 self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE] self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE]
self.D_STRUCTURE = [self.X_SPACE_SIZE,256,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE*2, self.X_SPACE_SIZE] #-- change 854 to number of diagnosis self.D_STRUCTURE = [self.X_SPACE_SIZE,256,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE*2, self.X_SPACE_SIZE] #-- change 854 to number of diagnosis
# self.NUM_LABELS = 8 if 'label' not in args elif len(args['label'].shape) args['label'].shape[1] # self.NUM_LABELS = 8 if 'label' not in args elif len(args['label'].shape) args['label'].shape[1]
if 'label' in args and len(args['label'].shape) == 2 : if 'label' in args and len(args['label'].shape) == 2 :
self.NUM_LABELS = args['label'].shape[1] self.NUM_LABELS = args['label'].shape[1]
elif 'label' in args and len(args['label']) == 1 : elif 'label' in args and len(args['label']) == 1 :
self.NUM_LABELS = args['label'].shape[0] self.NUM_LABELS = args['label'].shape[0]
else: else:
self.NUM_LABELS = 8 self.NUM_LABELS = 8
self.Z_DIM = 128 #self.X_SPACE_SIZE # self.Z_DIM = 128 #self.X_SPACE_SIZE
self.BATCHSIZE_PER_GPU = args['real'].shape[0] if 'real' in args else 256 self.Z_DIM = 128 #-- used as rows down stream
self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
if 'real' in args :
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000) self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
@ -533,6 +540,8 @@ class Predict(GNet):
# The code below will insure we have some acceptable cardinal relationships between id and synthetic values # The code below will insure we have some acceptable cardinal relationships between id and synthetic values
# #
df = ( pd.DataFrame(np.round(f).astype(np.int32))) df = ( pd.DataFrame(np.round(f).astype(np.int32)))
print (df.head())
print ()
p = 0 not in df.sum(axis=1).values p = 0 not in df.sum(axis=1).values
if p: if p:

@ -4,7 +4,7 @@ import sys
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
args = {"name":"data-maker","version":"1.0.8","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT", args = {"name":"data-maker","version":"1.0.9","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo'] args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git' args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'

Loading…
Cancel
Save