From 6de816fc5053386e961af7d92a06367eb0bd57e3 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Fri, 3 Jan 2020 21:47:05 -0600 Subject: [PATCH] bug fixes with operations --- data/gan.py | 22 ++++++++++++++++------ data/maker/__init__.py | 27 +++++++++++++++++---------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/data/gan.py b/data/gan.py index e349018..3391b78 100644 --- a/data/gan.py +++ b/data/gan.py @@ -14,6 +14,7 @@ import sys from data.params import SYS_ARGS from data.bridge import Binary import json +import pickle os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ['CUDA_VISIBLE_DEVICES'] = "0" @@ -38,7 +39,7 @@ class GNet : self.layers.normalize = self.normalize - self.NUM_GPUS = 1 + self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu'] self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854 @@ -64,8 +65,8 @@ class GNet : self.get = void() self.get.variables = self._variable_on_cpu - self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] + self.logger = args['logger'] if 'logger' in args and args['logger'] else None self.init_logs(**args) def init_logs(self,**args): @@ -98,7 +99,7 @@ class GNet : def log_meta(self,**args) : - object = { + _object = { 'CONTEXT':self.CONTEXT, 'ATTRIBUTES':self.ATTRIBUTES, 'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU, @@ -120,7 +121,8 @@ class GNet : _name = os.sep.join([self.out_dir,'meta-'+suffix]) f = open(_name+'.json','w') - f.write(json.dumps(object)) + f.write(json.dumps(_object)) + return _object def mkdir (self,path): if not os.path.exists(path) : os.mkdir(path) @@ -295,7 +297,7 @@ class Train (GNet): self.column = args['column'] # print ([" *** ",self.BATCHSIZE_PER_GPU]) - self.log_meta() + self.meta = self.log_meta() def load_meta(self, column): """ This function will delegate the calls to load meta data to it's dependents @@ -393,7 +395,7 @@ class Train (GNet): # saver = tf.train.Saver() saver = tf.compat.v1.train.Saver() init = tf.global_variables_initializer() - + logs = [] with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: sess.run(init) sess.run(iterator_d.initializer, @@ -415,6 +417,10 @@ class Train (GNet): format_str = 'epoch: %d, w_distance = %f (%.1f)' print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) + # print (dir (w_distance)) + + logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) }) + if epoch % self.MAX_EPOCHS == 0: # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] suffix = self.get.suffix() @@ -423,6 +429,10 @@ class Train (GNet): saver.save(sess, _name, write_meta_graph=False, global_step=epoch) # # + if self.logger : + row = {"logs":logs} #,"model":pickle.dump(sess)} + + self.logger.write(row=row) class Predict(GNet): """ diff --git a/data/maker/__init__.py b/data/maker/__init__.py index 7a441f8..075bfd3 100644 --- a/data/maker/__init__.py +++ b/data/maker/__init__.py @@ -11,7 +11,7 @@ This package is designed to generate synthetic data from a dataset from an origi import pandas as pd import numpy as np from data import gan - +from transport import factory def train (**args) : """ This function is intended to train the GAN in order to learn about the distribution of the features @@ -21,17 +21,24 @@ def train (**args) : :data data-frame to be synthesized :context label of what we are synthesizing """ - column = args['column'] + column = args['column'] - column_id = args['id'] - df = args['data'] - logs = args['logs'] - real = pd.get_dummies(df[column]).astype(np.float32).values + column_id = args['id'] + df = args['data'] + logs = args['logs'] + real = pd.get_dummies(df[column]).astype(np.float32).values + labels = pd.get_dummies(df[column_id]).astype(np.float32).values - labels = pd.get_dummies(df[column_id]).astype(np.float32).values - max_epochs = 10 - context = args['context'] - trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id) + max_epochs = 10 if 'max_epochs' not in args else args['max_epochs'] + context = args['context'] + if 'store' in args : + args['store']['args']['doc'] = context + logger = factory.instance(**args['store']) + + else: + logger = None + + trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs) return trainer.apply() def generate(**args):