From f85d5fd87054266332f2e0e6142cf200f53eef41 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Fri, 10 Jan 2020 09:53:23 -0600 Subject: [PATCH] bug fix with number of GPU, columns as identifiers --- data/gan.py | 6 ++---- data/maker/__init__.py | 9 ++++++--- setup.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/data/gan.py b/data/gan.py index 7bd17ee..43d15ae 100644 --- a/data/gan.py +++ b/data/gan.py @@ -245,15 +245,12 @@ class Discriminator(GNet): :label """ x = args['inputs'] - print () - print (x[:3,:]) - print() label = args['label'] with tf.compat.v1.variable_scope('D', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)): for i, dim in enumerate(self.D_STRUCTURE[1:]): kernel = self.get.variables(name='W_' + str(i), shape=[self.D_STRUCTURE[i], dim]) bias = self.get.variables(name='b_' + str(i), shape=[dim]) - print (["\t",bias,kernel]) + # print (["\t",bias,kernel]) x = tf.nn.relu(tf.add(tf.matmul(x, kernel), bias)) x = self.normalize(inputs=x, name='cln' + str(i), shift=1,labels=label, n_labels=self.NUM_LABELS) i = len(self.D_STRUCTURE) @@ -538,6 +535,7 @@ if __name__ == '__main__' : # Now we get things done ... column = SYS_ARGS['column'] column_id = SYS_ARGS['id'] if 'id' in SYS_ARGS else 'person_id' + column_id = column_id.split(',') if ',' in column_id else column_id df = pd.read_csv(SYS_ARGS['raw-data']) LABEL = pd.get_dummies(df[column_id]).astype(np.float32).values diff --git a/data/maker/__init__.py b/data/maker/__init__.py index 967dbc8..f97e5f3 100644 --- a/data/maker/__init__.py +++ b/data/maker/__init__.py @@ -38,7 +38,7 @@ def train (**args) : else: logger = None - trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs) + trainer = gan.Train(context=context,max_epochs=max_epochs,num_gpu=num_gpu,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs) return trainer.apply() def generate(**args): @@ -57,6 +57,9 @@ def generate(**args): column_id = args['id'] logs = args['logs'] context = args['context'] + num_gpu = 1 if 'num_gpu' not in args else args['num_gpu'] + max_epochs = 10 if 'max_epochs' not in args else args['max_epochs'] + # #@TODO: # If the identifier is not present, we should fine a way to determine or make one @@ -67,9 +70,9 @@ def generate(**args): values.sort() labels = pd.get_dummies(df[column_id]).astype(np.float32).values - handler = gan.Predict (context=context,label=labels,values=values,column=column) + handler = gan.Predict (context=context,label=labels,max_epochs=max_epochs,num_gpu=num_gpu,values=values,column=column,logs=logs) handler.load_meta(column) r = handler.apply() _df = df.copy() _df[column] = r[column] - return _df + return _df \ No newline at end of file diff --git a/setup.py b/setup.py index beda18e..db4029b 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import sys def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() -args = {"name":"data-maker","version":"1.0.3","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT", +args = {"name":"data-maker","version":"1.0.5","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT", "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo'] args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'