From dcc55eb1fbab75f32f8953d9b150dfe8fd567448 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Fri, 10 Jan 2020 13:12:58 -0600 Subject: [PATCH] bug fixes --- data/gan.py | 25 ++++++++++++++++++++----- data/maker/__init__.py | 39 +++++++++++++++++++++++---------------- data/maker/__main__.py | 33 +++++++++++++++++++++++++-------- 3 files changed, 68 insertions(+), 29 deletions(-) diff --git a/data/gan.py b/data/gan.py index 43d15ae..46ecb18 100644 --- a/data/gan.py +++ b/data/gan.py @@ -1,8 +1,23 @@ """ -usage : - optional : - --num_gpu number of gpus to use will default to 1 - --epoch steps per epoch default to 256 +This code was originally written by Ziqi Zhang in order to generate synthetic data. +The code is an implementation of a Generative Adversarial Network that uses the Wasserstein Distance (WGAN). +It is intended to be used in 2 modes (embedded in code or using CLI) + +USAGE : + +The following parameters should be provided in a configuration file (JSON format) +python data/maker --config + +CONFIGURATION FILE STRUCTURE : + + context what it is you are loading (stroke, hypertension, ...) 
+ data path of the file to be loaded + logs folder to store training model and meta data about learning + max_epochs number of iterations in learning + num_gpu number of gpus to be used (will still run if the GPUs are not available) + +EMBEDDED IN CODE : + """ import tensorflow as tf from tensorflow.contrib.layers import l2_regularizer @@ -426,7 +441,7 @@ class Train (GNet): print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) # print (dir (w_distance)) - logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) }) + logs.append({"epoch":epoch,"distance":-w_sum }) if epoch % self.MAX_EPOCHS == 0: # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] diff --git a/data/maker/__init__.py b/data/maker/__init__.py index f97e5f3..e0ca55d 100644 --- a/data/maker/__init__.py +++ b/data/maker/__init__.py @@ -24,21 +24,25 @@ def train (**args) : column = args['column'] column_id = args['id'] - df = args['data'] - logs = args['logs'] - real = pd.get_dummies(df[column]).astype(np.float32).values - labels = pd.get_dummies(df[column_id]).astype(np.float32).values - num_gpu = 1 if 'num_gpu' not in args else args['num_gpu'] - max_epochs = 10 if 'max_epochs' not in args else args['max_epochs'] + df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data']) + # logs = args['logs'] + # real = pd.get_dummies(df[column]).astype(np.float32).values + # labels = pd.get_dummies(df[column_id]).astype(np.float32).values + args['real'] = pd.get_dummies(df[column]).astype(np.float32).values + args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values + # num_gpu = 1 if 'num_gpu' not in args else args['num_gpu'] + # max_epochs = 10 if 'max_epochs' not in args else args['max_epochs'] context = args['context'] + if 'store' in args : args['store']['args']['doc'] = context logger = factory.instance(**args['store']) + args['logger'] = logger else: logger = 
None - - trainer = gan.Train(context=context,max_epochs=max_epochs,num_gpu=num_gpu,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs) + trainer = gan.Train(**args) + # trainer = gan.Train(context=context,max_epochs=max_epochs,num_gpu=num_gpu,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs) return trainer.apply() def generate(**args): @@ -51,14 +55,14 @@ def generate(**args): :id column identifying an entity :logs location on disk where the learnt knowledge of the dataset is """ - df = args['data'] - + # df = args['data'] + df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data']) column = args['column'] column_id = args['id'] - logs = args['logs'] - context = args['context'] - num_gpu = 1 if 'num_gpu' not in args else args['num_gpu'] - max_epochs = 10 if 'max_epochs' not in args else args['max_epochs'] + # logs = args['logs'] + # context = args['context'] + # num_gpu = 1 if 'num_gpu' not in args else args['num_gpu'] + # max_epochs = 10 if 'max_epochs' not in args else args['max_epochs'] # #@TODO: @@ -69,8 +73,11 @@ def generate(**args): values = df[column].unique().tolist() values.sort() - labels = pd.get_dummies(df[column_id]).astype(np.float32).values - handler = gan.Predict (context=context,label=labels,max_epochs=max_epochs,num_gpu=num_gpu,values=values,column=column,logs=logs) + # labels = pd.get_dummies(df[column_id]).astype(np.float32).values + args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values + args['values'] = values + # handler = gan.Predict (context=context,label=labels,max_epochs=max_epochs,num_gpu=num_gpu,values=values,column=column,logs=logs) + handler = gan.Predict (**args) handler.load_meta(column) r = handler.apply() _df = df.copy() diff --git a/data/maker/__main__.py b/data/maker/__main__.py index e77bf0a..56defec 100644 --- a/data/maker/__main__.py +++ b/data/maker/__main__.py @@ -1,10 +1,27 @@ import pandas as pd import data.maker 
- -df = pd.read_csv('sample.csv') -column = 'gender' -id = 'id' -context = 'demo' -store = {"type":"mongo.MongoWriter","args":{"host":"localhost:27017","dbname":"GAN"}} -max_epochs = 11 -data.maker.train(store=store,max_epochs=max_epochs,context=context,data=df,column=column,id=id,logs='foo') \ No newline at end of file +from data.params import SYS_ARGS +import json +from scipy.stats import wasserstein_distance as wd +import risk +import numpy as np +if 'config' in SYS_ARGS : + ARGS = json.loads(open(SYS_ARGS['config']).read()) + if 'generate' not in SYS_ARGS : + data.maker.train(**ARGS) + else: + # + # + _df = data.maker.generate(**ARGS) + odf = pd.read_csv (ARGS['data']) + odf.columns = [name.lower() for name in odf.columns] + column = [ARGS['column'] ] #+ ARGS['id'] + print (column) + print (_df[column].risk.evaluate()) + print (odf[column].risk.evaluate()) + _x = pd.get_dummies(_df[column]).values + y = pd.get_dummies(odf[column]).values + N = _df.shape[0] + print (np.mean([ wd(_x[i],y[i])for i in range(0,N)])) + # column = SYS_ARGS['column'] + # odf = open(SYS_ARGS['data']) \ No newline at end of file