|
|
|
"""
|
|
|
|
from data.params import SYS_ARGS
|
|
|
|
self.BATCHSIZE_PER_GPU = PROPOSED_BATCH_PER_GPU
|
|
|
|
|
|
|
|
"NUM_LABELS":self.NUM_LABELS,
|
|
|
|
os.mkdir(os.sep.join(root))
|
|
|
|
self.discriminator = Discriminator(**_args)
|
|
|
|
def loss(self, **args):
    """Build the generator's adversarial loss tensor for one tower.

    Keyword args:
        fake:  generated samples; fed to ``self.discriminator.network`` as ``inputs``.
        label: conditioning labels; forwarded to the discriminator as ``label``.

    Returns:
        A ``(loss, loss)`` pair, i.e. the same tensor twice. This gives the
        call site the same two-value shape as the discriminator's
        ``(w_distance, loss)`` return.

    Side effect:
        Appends the loss tensor to the graph collection ``'glosses'``.
    """
    fake_batch = args['fake']
    cond_label = args['label']

    # Score the generated batch with the discriminator. This runs before the
    # collection lookup below, so any regularization losses the call
    # registers are already in the collection when it is read.
    critic_scores = self.discriminator.network(inputs=fake_batch, label=cond_label)

    # No scope filter: every regularization loss in the graph is summed.
    # Python's ``sum`` yields the int 0 when the collection is empty.
    reg_terms = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)

    # The generator wants the critic to score fakes highly, hence the negation.
    adversarial_term = -tf.reduce_mean(critic_scores)
    g_loss = adversarial_term + sum(reg_terms)

    tf.compat.v1.add_to_collection('glosses', g_loss)
    return g_loss, g_loss
|
|
|
|
"""
|
|
|
|
This function compute the loss of
|
|
|
|
:real
|
|
|
|
:fake
|
|
|
|
:label
|
|
|
|
"""
|
|
|
|
real = args['real']
|
|
|
|
fake = args['fake']
|
|
|
|
label = args['label']
|
|
|
|
epsilon = tf.random.uniform(shape=[self.BATCHSIZE_PER_GPU,1],minval=0,maxval=1)
|
|
|
|
|
|
|
|
x_hat = real + epsilon * (fake - real)
|
|
|
|
y_hat_fake = self.network(inputs=fake, label=label)
|
|
|
|
|
|
|
|
y_hat_real = self.network(inputs=real, label=label)
|
|
|
|
y_hat = self.network(inputs=x_hat, label=label)
|
|
|
|
|
|
|
|
grad = tf.gradients(y_hat, [x_hat])[0]
|
|
|
|
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
|
|
|
|
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
|
|
|
|
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
|
|
|
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
|
|
|
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
|
|
|
|
loss = w_distance + 10 * gradient_penalty + sum(all_regs)
|
|
|
|
#tf.add_to_collection('dlosses', loss)
|
|
|
|
tf.compat.v1.add_to_collection('dlosses', loss)
|
|
|
|
|
|
|
|
return w_distance, loss
|
|
|
|
self.logger.write({"module":"gan-train","action":"start","input":{"partition":self.PARTITION,"meta":self.meta} } )
|
|
|
|
if stage == 'D':
|
|
|
|
w, loss = self.discriminator.loss(real=real, fake=fake, label=label)
|
|
|
|
#losses = tf.get_collection('dlosses', scope)
|
|
|
|
flag = 'dlosses'
|
|
|
|
losses = tf.compat.v1.get_collection('dlosses', scope)
|
|
|
|
else:
|
|
|
|
w, loss = self.generator.loss(fake=fake, label=label)
|
|
|
|
#losses = tf.get_collection('glosses', scope)
|
|
|
|
flag = 'glosses'
|
|
|
|
losses = tf.compat.v1.get_collection('glosses', scope)
|
|
|
|
# losses = tf.compat.v1.get_collection(flag, scope)
|
|
|
|
|
|
|
|
total_loss = tf.add_n(losses, name='total_loss')
|
|
|
|
dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU)
|
|
|
|
vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
|
|
|
grads = opt.compute_gradients(loss, vars_)
|
|
|
|
tower_grads.append(grads)
|
|
|
|
per_gpu_w.append(w)
|
|
|
|
|
|
|
|
grads = self.average_gradients(tower_grads)
|
|
|
|
apply_gradient_op = opt.apply_gradients(grads)
|
|
|
|
|
|
|
|
mean_w = tf.reduce_mean(per_gpu_w)
|
|
|
|
train_op = apply_gradient_op
|
|
|
|
return train_op, mean_w, iterator, features_placeholder, labels_placeholder
|
|
|
|
def apply(self,**args):
|
|
|
|
# max_epochs = args['max_epochs'] if 'max_epochs' in args else 10
|
|
|
|
REAL = self._REAL
|
|
|
|
LABEL= self._LABEL
|
|
|
|
if (self.logger):
|
|
|
|
pass
|
|
|
|
|
|
|
|
with tf.device('/cpu:0'):
|
|
|
|
opt_d = tf.compat.v1.train.AdamOptimizer(1e-4)
|
|
|
|
opt_g = tf.compat.v1.train.AdamOptimizer(1e-4)
|
|
|
|
|
|
|
|
train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d)
|
|
|
|
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
|
|
|
|
# saver = tf.train.Saver()
|
|
|
|
|
|
|
|
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
|
|
|
_log = {'module':'gan-train','context':self.CONTEXT,'action':'epochs','input':self.logs['epochs']}
|
|
|
|
#
|
|
|
|
# updating the input/output for the generator, so it points properly
|
|
|
|
#
|
|
|
|
|
|
|
|
for object in [self,self.generator] :
|
|
|
|
_train_dir = os.sep.join([self.log_dir,'train',self.CONTEXT,str(self.MAX_EPOCHS)])
|
|
|
|
_out_dir= os.sep.join([self.log_dir,'output',self.CONTEXT,str(self.MAX_EPOCHS)])
|
|
|
|
setattr(object,'train_dir',_train_dir)
|
|
|
|
setattr(object,'out_dir',_out_dir)
|
|
|
|
# df = (i * df).sum(axis=1)
|
|
|
|
# print(df.head())
|