bug fixes with operations

master
Steve L. Nyemba 5 years ago
parent 65a3e84c8f
commit 6de816fc50

@ -14,6 +14,7 @@ import sys
from data.params import SYS_ARGS from data.params import SYS_ARGS
from data.bridge import Binary from data.bridge import Binary
import json import json
import pickle
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "0" os.environ['CUDA_VISIBLE_DEVICES'] = "0"
@ -38,7 +39,7 @@ class GNet :
self.layers.normalize = self.normalize self.layers.normalize = self.normalize
self.NUM_GPUS = 1 self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854 self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
@ -64,8 +65,8 @@ class GNet :
self.get = void() self.get = void()
self.get.variables = self._variable_on_cpu self.get.variables = self._variable_on_cpu
self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
self.logger = args['logger'] if 'logger' in args and args['logger'] else None
self.init_logs(**args) self.init_logs(**args)
def init_logs(self,**args): def init_logs(self,**args):
@ -98,7 +99,7 @@ class GNet :
def log_meta(self,**args) : def log_meta(self,**args) :
object = { _object = {
'CONTEXT':self.CONTEXT, 'CONTEXT':self.CONTEXT,
'ATTRIBUTES':self.ATTRIBUTES, 'ATTRIBUTES':self.ATTRIBUTES,
'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU, 'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
@ -120,7 +121,8 @@ class GNet :
_name = os.sep.join([self.out_dir,'meta-'+suffix]) _name = os.sep.join([self.out_dir,'meta-'+suffix])
f = open(_name+'.json','w') f = open(_name+'.json','w')
f.write(json.dumps(object)) f.write(json.dumps(_object))
return _object
def mkdir (self,path): def mkdir (self,path):
if not os.path.exists(path) : if not os.path.exists(path) :
os.mkdir(path) os.mkdir(path)
@ -295,7 +297,7 @@ class Train (GNet):
self.column = args['column'] self.column = args['column']
# print ([" *** ",self.BATCHSIZE_PER_GPU]) # print ([" *** ",self.BATCHSIZE_PER_GPU])
self.log_meta() self.meta = self.log_meta()
def load_meta(self, column): def load_meta(self, column):
""" """
This function will delegate the calls to load meta data to it's dependents This function will delegate the calls to load meta data to it's dependents
@ -393,7 +395,7 @@ class Train (GNet):
# saver = tf.train.Saver() # saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver() saver = tf.compat.v1.train.Saver()
init = tf.global_variables_initializer() init = tf.global_variables_initializer()
logs = []
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(init) sess.run(init)
sess.run(iterator_d.initializer, sess.run(iterator_d.initializer,
@ -415,6 +417,10 @@ class Train (GNet):
format_str = 'epoch: %d, w_distance = %f (%.1f)' format_str = 'epoch: %d, w_distance = %f (%.1f)'
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
# print (dir (w_distance))
logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
if epoch % self.MAX_EPOCHS == 0: if epoch % self.MAX_EPOCHS == 0:
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
suffix = self.get.suffix() suffix = self.get.suffix()
@ -423,6 +429,10 @@ class Train (GNet):
saver.save(sess, _name, write_meta_graph=False, global_step=epoch) saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
# #
# #
if self.logger :
row = {"logs":logs} #,"model":pickle.dump(sess)}
self.logger.write(row=row)
class Predict(GNet): class Predict(GNet):
""" """

@ -11,7 +11,7 @@ This package is designed to generate synthetic data from a dataset from an origi
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from data import gan from data import gan
from transport import factory
def train (**args) : def train (**args) :
""" """
This function is intended to train the GAN in order to learn about the distribution of the features This function is intended to train the GAN in order to learn about the distribution of the features
@ -27,11 +27,18 @@ def train (**args) :
df = args['data'] df = args['data']
logs = args['logs'] logs = args['logs']
real = pd.get_dummies(df[column]).astype(np.float32).values real = pd.get_dummies(df[column]).astype(np.float32).values
labels = pd.get_dummies(df[column_id]).astype(np.float32).values labels = pd.get_dummies(df[column_id]).astype(np.float32).values
max_epochs = 10
max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
context = args['context'] context = args['context']
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id) if 'store' in args :
args['store']['args']['doc'] = context
logger = factory.instance(**args['store'])
else:
logger = None
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
return trainer.apply() return trainer.apply()
def generate(**args): def generate(**args):

Loading…
Cancel
Save