|
|
|
"""
|
|
|
|
import pickle
|
|
|
|
self.NUM_LABELS = args['label'].shape[1]
|
|
|
|
self.init_logs(**args)
|
|
|
|
|
|
|
|
def init_logs(self, **args):
    """
    Prepare the folders used for logging and, when a logger is attached, archive and
    clear what was previously stored for the synthetic attribute.

    The folders <logs>/train/<CONTEXT> and <logs>/output/<CONTEXT> are created through
    self.mkdir and their paths are kept in self.train_dir and self.out_dir.

    :logs   root folder of the logs (defaults to 'logs')
    """
    self.log_dir = args.get('logs', 'logs')
    self.mkdir(self.log_dir)
    #
    # Every level is created explicitly: the root, the category and finally the context
    #
    for key in ['train', 'output']:
        self.mkdir(os.sep.join([self.log_dir, key]))
        self.mkdir(os.sep.join([self.log_dir, key, self.CONTEXT]))

    self.train_dir = os.sep.join([self.log_dir, 'train', self.CONTEXT])
    self.out_dir = os.sep.join([self.log_dir, 'output', self.CONTEXT])
    if self.logger:
        #
        # We will clear the logs from the data-store, a copy is kept in the backup collection
        # NOTE(review): db looks like a pymongo database - confirm; Collection.count() and
        # Collection.insert() are no longer available as of pymongo 4
        #
        column = self.ATTRIBUTES['synthetic']
        db = self.logger.db
        if db[column].count() > 0:
            db.backup.insert({'name': column, 'logs': list(db[column].find())})
            db[column].drop()
|
|
|
|
|
|
|
|
def load_meta(self, column):
    """
    This function is designed to accommodate the use of the sub-classes outside of a strict dependency model,
    because prediction and training can happen independently.

    It looks for <out_dir>/meta-<suffix>.json (the suffix is given by self.get.suffix()) and, when the
    file exists, copies every top-level entry of that document onto this object as an attribute.

    :column     not used by this implementation, kept so existing callers and overrides keep working
    """
    suffix = self.get.suffix()
    _name = os.sep.join([self.out_dir, 'meta-' + suffix + '.json'])
    if os.path.exists(_name):
        # the context manager makes sure the file handle is released once the document is parsed
        with open(_name) as f:
            attr = json.load(f)
        for key in attr:
            setattr(self, key, attr[key])
        #
        # The entries that were just loaded can replace CONTEXT (or log_dir),
        # the folders derived from them are therefore computed again
        #
        self.train_dir = os.sep.join([self.log_dir, 'train', self.CONTEXT])
        self.out_dir = os.sep.join([self.log_dir, 'output', self.CONTEXT])
|
|
|
|
|
|
|
|
|
|
|
|
def log_meta(self, **args):
    """
    Write the parameters of this object to <out_dir>/meta-<suffix>.json and return them.
    The suffix is given by self.get.suffix().

    :key    optional name of an extra entry to add to the document
    :value  value of the extra entry, it is only added when both key and value are provided
    :return the dictionary that was written to disk
    """
    _object = {
        'CONTEXT': self.CONTEXT,
        'ATTRIBUTES': self.ATTRIBUTES,
        'BATCHSIZE_PER_GPU': self.BATCHSIZE_PER_GPU,
        'Z_DIM': self.Z_DIM,
        "X_SPACE_SIZE": self.X_SPACE_SIZE,
        "D_STRUCTURE": self.D_STRUCTURE,
        "G_STRUCTURE": self.G_STRUCTURE,
        "NUM_GPUS": self.NUM_GPUS,
        "NUM_LABELS": self.NUM_LABELS,
        "MAX_EPOCHS": self.MAX_EPOCHS,
        "ROW_COUNT": self.ROW_COUNT,
    }
    if 'key' in args and 'value' in args:
        # the entry must land in the document being built (_object), not in the builtin `object`
        _object[args['key']] = args['value']
    suffix = self.get.suffix()
    _name = os.sep.join([self.out_dir, 'meta-' + suffix])
    # the context manager flushes and closes the file before the function returns
    with open(_name + '.json', 'w') as f:
        f.write(json.dumps(_object))
    return _object
|
|
|
|
def mkdir(self, path):
    """
    Create the folder designated by path; nothing is done when the path already exists.

    :path   folder to create
    """
    if not os.path.exists(path):
        # makedirs also creates the missing parent folders; exist_ok covers the case
        # where the folder shows up between the check above and this call
        os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
|
|
"""
|
|
|
|
def load_meta(self, column):
    """
    Load the stored meta data for this object first, then for the discriminator attached to it.

    :column     forwarded unchanged to both load_meta calls
    """
    parent = super()
    parent.load_meta(column)
    critic = self.discriminator
    critic.load_meta(column)
|
|
|
|
def network(self,**args) :
|
|
|
|
"""
|
|
|
|
This function will build the network that will generate the synthetic candidates
|
|
|
|
:inputs matrix of data that we need
|
|
|
|
:dim dimensions of ...
|
|
|
|
"""
|
|
|
|
x = args['inputs']
|
|
|
|
tmp_dim = self.Z_DIM if 'dim' not in args else args['dim']
|
|
|
|
label = args['label']
|
|
|
|
y_hat = self.network(inputs=x_hat, label=label)
|
|
|
|
|
|
|
|
grad = tf.gradients(y_hat, [x_hat])[0]
|
|
|
|
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
|
|
|
|
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
|
|
|
|
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
|
|
|
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
|
|
|
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
|
|
|
|
loss = w_distance + 10 * gradient_penalty + sum(all_regs)
|
|
|
|
#tf.add_to_collection('dlosses', loss)
|
|
|
|
tf.compat.v1.add_to_collection('dlosses', loss)
|
|
|
|
|
|
|
|
return w_distance, loss
|
|
|
|
Training will consist of having both the generator and the discriminators
|
|
|
|
"""
|
|
|
|
features_placeholder = tf.compat.v1.placeholder(shape=self._REAL.shape, dtype=tf.float32)
|
|
|
|
#tf.get_variable_scope().reuse_variables()
|
|
|
|
sess.run(init)
|
|
|
|
row = {"logs":logs} #,"model":pickle.dump(sess)}
|
|
|
|
fake = self.generator.network(inputs=z, label=label)
|
|
|
|
init = tf.compat.v1.global_variables_initializer()
|
|
|
|
saver = tf.compat.v1.train.Saver()
|
|
|
|
df = pd.DataFrame()
|
|
|
|
CANDIDATE_COUNT = 1000
|
|
|
|
NTH_VALID_CANDIDATE = count = np.random.choice(np.arange(2,60),2)[0]
|
|
|
|
with tf.compat.v1.Session() as sess:
|
|
|
|
|
|
|
|
# sess.run(init)
|
|
|
|
saver.restore(sess, model_dir)
|
|
|
|
df = pd.DataFrame( df.iloc[i].apply(lambda row: self.values[np.random.choice(np.where(row != 0)[0],1)[0]] ,axis=1))
|
|
|
|
df.columns = columns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
# We should train upon this data
|