You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

599 lines
26 KiB
Python

"""
self._REAL = args['real'] if 'real' in args else None
self._LABEL = args['label'] if 'label' in args else None
'Z_DIM':self.Z_DIM,
offset = tf.nn.embedding_lookup(offset_m, labels)
scale = tf.nn.embedding_lookup(scale_m, labels)
result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-8)
return result
def _variable_on_cpu(self, **args):
    """
    Obtain a variable through tf.compat.v1.get_variable while inside a
    '/cpu:0' device scope, so it is not created on the GPU.

    :name           name forwarded to get_variable (required, KeyError if absent)
    :shape          shape forwarded to get_variable (required, KeyError if absent)
    :initializer    optional initializer forwarded to get_variable; None when not supplied
    @return the object returned by tf.compat.v1.get_variable
    """
    # 'name' is looked up before 'shape', so a missing 'name' is reported first
    var_name, var_shape = args['name'], args['shape']
    init = args.get('initializer')
    with tf.device('/cpu:0'):
        return tf.compat.v1.get_variable(var_name, var_shape, initializer=init)
def average_gradients(self, tower_grads):
    """
    Average gradients across towers, position by position.

    :tower_grads    one sequence of (gradient, variable) pairs per tower; the
                    sequences are zipped together, so the k-th pair of every
                    tower is averaged as one group
    @return list of (averaged gradient, variable) pairs, where the variable is
            the one found in the first tower's pair
    """
    averaged = []
    for per_position in zip(*tower_grads):
        # give every gradient a new leading axis, then join them along it
        stacked = tf.concat(
            axis=0,
            values=[tf.expand_dims(g, 0) for g, _ in per_position],
        )
        # collapsing the leading axis yields the mean over the towers
        mean_grad = tf.reduce_mean(stacked, 0)
        averaged.append((mean_grad, per_position[0][1]))
    return averaged
class Generator (GNet):
"""
This class handles the generation of candidate datasets. To do so it aggregates a discriminator, which allows the generator not to be random.
"""
def __init__(self, **args):
    """
    Set up the generator together with the discriminator it aggregates.

    :args   keyword arguments; the same set is forwarded unchanged to
            GNet.__init__ and to the Discriminator constructor
    """
    # the parent initializer runs before the discriminator is built
    GNet.__init__(self, **args)
    self.discriminator = Discriminator(**args)
def loss(self,**args):
fake = args['fake']
label = args['label']
y_hat_fake = self.discriminator.network(inputs=fake, label=label)
kernel = self.get.variables(name='W_' + str(i+1), shape=[self.Z_DIM, self.X_SPACE_SIZE])
x_hat = real + epsilon * (fake - real)
y_hat_fake = self.network(inputs=fake, label=label)
y_hat_real = self.network(inputs=real, label=label)
y_hat = self.network(inputs=x_hat, label=label)
grad = tf.gradients(y_hat, [x_hat])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
:label
def network(self,**args):
# def graph(stage, opt):
# global_step = tf.get_variable(stage+'_step', [], initializer=tf.constant_initializer(0), trainable=False)
stage = args['stage']
opt = args['opt']
tower_grads = []
per_gpu_w = []
iterator, features_placeholder, labels_placeholder = self.input_fn()
with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
for i in range(self.NUM_GPUS):
with tf.device('/gpu:%d' % i):
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
(real, label) = iterator.get_next()
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
#
# output = []
# for i in range(nbatch):
# f = sess.run(fake,feed_dict={y: label_input[i* self.BATCHSIZE_PER_GPU:(i+1)* self.BATCHSIZE_PER_GPU]})
# output.extend(np.round(f))
# output = np.array(output)[:num]
# print ([m,n,output])
# np.save(self.out_dir + str(m) + str(n), output)
if __name__ == '__main__' :
#
# Now we get things done ...
column = SYS_ARGS['column']
column_id = SYS_ARGS['id'] if 'id' in SYS_ARGS else 'person_id'