You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

771 lines
43 KiB
Python

"""
5 years ago
from data.params import SYS_ARGS
3 years ago
self.BATCHSIZE_PER_GPU = PROPOSED_BATCH_PER_GPU
loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs)
#tf.add_to_collection('glosses', loss)
tf.compat.v1.add_to_collection('glosses', loss)
return loss, loss
bias = self.get.variables(name='b_' + str(i+1), shape=[self.X_SPACE_SIZE])
x_hat = real + epsilon * (fake - real)
y_hat_fake = self.network(inputs=fake, label=label)
y_hat_real = self.network(inputs=real, label=label)
y_hat = self.network(inputs=x_hat, label=label)
grad = tf.gradients(y_hat, [x_hat])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
loss = w_distance + 10 * gradient_penalty + sum(all_regs)
#tf.add_to_collection('dlosses', loss)
tf.compat.v1.add_to_collection('dlosses', loss)
return w_distance, loss
losses = tf.compat.v1.get_collection('glosses', scope)
# losses = tf.compat.v1.get_collection(flag, scope)
total_loss = tf.add_n(losses, name='total_loss')
opt = args['opt']
return train_op, mean_w, iterator, features_placeholder, labels_placeholder
def apply(self,**args):
# max_epochs = args['max_epochs'] if 'max_epochs' in args else 10
REAL = self._REAL
LABEL= self._LABEL
if (self.logger):
pass
with tf.device('/cpu:0'):
opt_d = tf.compat.v1.train.AdamOptimizer(1e-4)
opt_g = tf.compat.v1.train.AdamOptimizer(1e-4)
train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d)
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
# saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver()
# init = tf.global_variables_initializer()
init = tf.compat.v1.global_variables_initializer()
logs = []
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
w_sum += w
# This would allow the learnt model to be portable to another system
#
tf.compat.v1.reset_default_graph()
demo = self._LABEL #np.zeros([self.ROW_COUNT,self.NUM_LABELS]) #args['de"shape":{"LABEL":list(self._LABEL.shape)} mo']
# df = pd.DataFrame(np.round(f)).astype(np.int32)
# if np.divide( np.sum(x), x.size) > .9 or p and np.sum(x) == x.size :