You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
288 lines
12 KiB
Python
288 lines
12 KiB
Python
import tensorflow as tf
|
|
from tensorflow.contrib.layers import l2_regularizer
|
|
import numpy as np
|
|
import time
|
|
import os
|
|
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
|
# os.environ['CUDA_VISIBLE_DEVICES'] = "4,5"
|
|
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
|
|
|
|
|
FLAGS = tf.app.flags.FLAGS

# Command-line configuration. All values are overridable via --flag=value.
tf.app.flags.DEFINE_string('train_dir', 'google_cloud_test/',
                           """Directory where to store checkpoint. """)
tf.app.flags.DEFINE_string('save_dir', 'google_cloud_test/',
                           """Directory where to save generated data. """)
tf.app.flags.DEFINE_integer('max_steps', 100,
                            """Number of batches to run in each epoch.""")
tf.app.flags.DEFINE_integer('max_epochs', 100,
                            """Number of epochs to run.""")
tf.app.flags.DEFINE_integer('batchsize', 10,
                            """Batchsize.""")
tf.app.flags.DEFINE_integer('z_dim', 10,
                            """Dimensionality of random input.""")
tf.app.flags.DEFINE_integer('data_dim', 30,
                            """Dimensionality of data.""")
tf.app.flags.DEFINE_integer('demo_dim', 8,
                            """Dimensionality of demographics.""")
tf.app.flags.DEFINE_float('reg', 0.0001,
                          """L2 regularization.""")

# Layer widths of the generator (G) and discriminator (D) MLPs.
g_structure = [FLAGS.z_dim, FLAGS.z_dim]
d_structure = [FLAGS.data_dim, int(FLAGS.data_dim / 2), FLAGS.z_dim]
|
|
|
|
|
|
def _variable_on_cpu(name, shape, initializer=None):
    """Create (or, under variable reuse, fetch) a variable pinned to the CPU.

    Args:
        name: variable name within the current variable scope.
        shape: variable shape.
        initializer: optional initializer passed to `tf.get_variable`.

    Returns:
        The `tf.Variable` placed on `/cpu:0`.
    """
    with tf.device('/cpu:0'):
        return tf.get_variable(name, shape, initializer=initializer)
|
|
|
|
|
|
def batchnorm(inputs, name, labels=None, n_labels=None):
    """Conditional batch normalization: per-label offset/scale parameters.

    Normalizes `inputs` by batch statistics (moments over axis 0), then
    applies an affine transform whose offset and scale are looked up per
    example from `labels`-indexed embedding tables.

    Args:
        inputs: 2-D float tensor (batch, features) — presumably; the code
            indexes `mean.shape[1]`, so rank 2 is required.
        name: suffix used to build the parameter variable names.
        labels: int tensor of per-example label ids in [0, n_labels).
        n_labels: number of distinct labels (rows in the parameter tables).

    Returns:
        The normalized, conditionally-scaled tensor.
    """
    mean, var = tf.nn.moments(inputs, [0], keep_dims=True)
    feat_dim = mean.shape[1].value
    # One (offset, scale) row per label id; names must stay 'offset'+name /
    # 'scale'+name so existing checkpoints keep loading.
    offset_table = _variable_on_cpu(shape=[n_labels, feat_dim], name='offset' + name,
                                    initializer=tf.zeros_initializer)
    scale_table = _variable_on_cpu(shape=[n_labels, feat_dim], name='scale' + name,
                                   initializer=tf.ones_initializer)
    offset = tf.nn.embedding_lookup(offset_table, labels)
    scale = tf.nn.embedding_lookup(scale_table, labels)
    return tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-8)
|
|
|
|
|
|
def layernorm(inputs, name, labels=None, n_labels=None):
    """Conditional layer normalization: per-label offset/scale parameters.

    Same shape contract as `batchnorm`, but the moments are taken over the
    feature axis (axis 1) — i.e. per example — which makes this a layer
    norm rather than a batch norm.

    Args:
        inputs: 2-D float tensor (batch, features).
        name: suffix used to build the parameter variable names.
        labels: int tensor of per-example label ids in [0, n_labels).
        n_labels: number of distinct labels (rows in the parameter tables).

    Returns:
        The normalized, conditionally-scaled tensor.
    """
    mean, var = tf.nn.moments(inputs, [1], keep_dims=True)
    feat_dim = inputs.shape[1].value
    # Variable names must stay 'offset'+name / 'scale'+name for checkpoints.
    offset_table = _variable_on_cpu(shape=[n_labels, feat_dim], name='offset' + name,
                                    initializer=tf.zeros_initializer)
    scale_table = _variable_on_cpu(shape=[n_labels, feat_dim], name='scale' + name,
                                   initializer=tf.ones_initializer)
    offset = tf.nn.embedding_lookup(offset_table, labels)
    scale = tf.nn.embedding_lookup(scale_table, labels)
    return tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-8)
|
|
|
|
|
|
def input_fn():
    """Build an initializable input pipeline over placeholder-fed arrays.

    Returns:
        A tuple `(iterator, features_placeholder, labels_placeholder)`.
        The caller runs `iterator.initializer` with the placeholders fed
        from numpy arrays of shape (N, FLAGS.data_dim) and (N, 6).
    """
    features_placeholder = tf.placeholder(shape=[None, FLAGS.data_dim], dtype=tf.float32)
    labels_placeholder = tf.placeholder(shape=[None, 6], dtype=tf.float32)
    # Repeat far longer than training needs, batch, and prefetch one batch.
    dataset = (tf.data.Dataset
               .from_tensor_slices((features_placeholder, labels_placeholder))
               .repeat(10000)
               .batch(batch_size=FLAGS.batchsize)
               .prefetch(1))
    iterator = dataset.make_initializable_iterator()
    return iterator, features_placeholder, labels_placeholder
|
|
|
|
|
|
def generator(z, label):
    """Residual MLP generator conditioned on a demographic label.

    Each hidden block computes `x + relu(cbn(x @ W_i))`; the last hidden
    block uses tanh; a final dense layer with sigmoid maps to data space.
    Residual addition requires every width in `g_structure` to equal
    FLAGS.z_dim (which it does by construction above).

    Args:
        z: (batch, FLAGS.z_dim) noise tensor.
        label: int tensor of per-example label ids in [0, FLAGS.demo_dim).

    Returns:
        (batch, FLAGS.data_dim) tensor of values in (0, 1).
    """
    x = z
    in_dim = FLAGS.z_dim
    with tf.variable_scope('G', reuse=tf.AUTO_REUSE, regularizer=l2_regularizer(FLAGS.reg)):
        # Hidden residual blocks with ReLU.
        for i, out_dim in enumerate(g_structure[:-1]):
            kernel = _variable_on_cpu('W_' + str(i), shape=[in_dim, out_dim])
            pre = batchnorm(tf.matmul(x, kernel), name='cbn' + str(i),
                            labels=label, n_labels=FLAGS.demo_dim)
            x = x + tf.nn.relu(pre)
            in_dim = out_dim
        # Final residual block uses tanh instead of ReLU.
        i = len(g_structure) - 1
        kernel = _variable_on_cpu('W_' + str(i), shape=[in_dim, g_structure[-1]])
        pre = batchnorm(tf.matmul(x, kernel), name='cbn' + str(i),
                        labels=label, n_labels=FLAGS.demo_dim)
        x = x + tf.nn.tanh(pre)

        # Output projection to data space with sigmoid activation.
        kernel = _variable_on_cpu('W_' + str(i + 1), shape=[FLAGS.z_dim, FLAGS.data_dim])
        bias = _variable_on_cpu('b_' + str(i + 1), shape=[FLAGS.data_dim])
        x = tf.nn.sigmoid(tf.add(tf.matmul(x, kernel), bias))
    return x
|
|
|
|
|
|
def discriminator(x, label):
    """MLP critic conditioned on a demographic label.

    Hidden layers: dense -> ReLU -> conditional layer norm; the output is a
    single unbounded score per example (WGAN critic — no sigmoid).

    Args:
        x: (batch, FLAGS.data_dim) input tensor.
        label: int tensor of per-example label ids in [0, FLAGS.demo_dim).

    Returns:
        (batch, 1) tensor of critic scores.
    """
    with tf.variable_scope('D', reuse=tf.AUTO_REUSE, regularizer=l2_regularizer(FLAGS.reg)):
        for i, out_dim in enumerate(d_structure[1:]):
            kernel = _variable_on_cpu('W_' + str(i), shape=[d_structure[i], out_dim])
            bias = _variable_on_cpu('b_' + str(i), shape=[out_dim])
            x = tf.nn.relu(tf.add(tf.matmul(x, kernel), bias))
            x = layernorm(x, name='cln' + str(i), labels=label, n_labels=FLAGS.demo_dim)
        # NOTE(review): this jumps to len(d_structure), leaving a gap in the
        # W_i numbering (the loop ends at len-2; generator uses len-1).
        # Harmless, but kept exactly as-is so checkpoint variable names match.
        i = len(d_structure)
        kernel = _variable_on_cpu('W_' + str(i), shape=[d_structure[-1], 1])
        bias = _variable_on_cpu('b_' + str(i), shape=[1])
        score = tf.add(tf.matmul(x, kernel), bias)
    return score
|
|
|
|
|
|
def compute_dloss(real, fake, label):
    """WGAN-GP critic loss with gradient penalty (coefficient 10).

    Args:
        real: (batch, data_dim) real samples.
        fake: (batch, data_dim) generated samples.
        label: int tensor of per-example label ids.

    Returns:
        A tuple `(w_distance, loss)` where `w_distance` is the (regularized)
        negative Wasserstein estimate and `loss` adds the gradient penalty.
        `loss` is also added to the 'dlosses' collection.
    """
    # Random interpolates between real and fake for the penalty term.
    epsilon = tf.random_uniform(shape=[FLAGS.batchsize, 1], minval=0., maxval=1.)
    interp = real + epsilon * (fake - real)

    d_fake = discriminator(fake, label)
    d_real = discriminator(real, label)
    d_interp = discriminator(interp, label)

    # Penalize deviation of the critic's gradient norm from 1 (Lipschitz-1).
    grad = tf.gradients(d_interp, [interp])[0]
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
    gradient_penalty = tf.reduce_mean((grad_norm - 1.) ** 2)

    all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    w_distance = -tf.reduce_mean(d_real) + tf.reduce_mean(d_fake) + sum(all_regs)
    loss = w_distance + 10 * gradient_penalty
    tf.add_to_collection('dlosses', loss)

    return w_distance, loss
|
|
|
|
|
|
def compute_gloss(fake, label):
    """Generator loss: maximize the critic's score on fakes (plus L2 regs).

    Args:
        fake: (batch, data_dim) generated samples.
        label: int tensor of per-example label ids.

    Returns:
        `(loss, loss)` — duplicated so the signature mirrors
        `compute_dloss`'s (w, loss) pair. Also added to 'glosses'.
    """
    d_fake = discriminator(fake, label)
    all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    loss = -tf.reduce_mean(d_fake) + sum(all_regs)
    tf.add_to_collection('glosses', loss)
    return loss, loss
|
|
|
|
|
|
def tower_loss(scope, stage, real, label):
    """Build the loss for one tower in the given stage ('D' or 'G').

    Args:
        scope: name scope of this tower, used to filter loss collections.
        stage: 'D' for the critic update, anything else for the generator.
        real: (batch, data_dim) batch of real samples.
        label: (batch, 6) float indicator tensor; column 1 is a binary
            attribute and columns 2..5 a 4-way one-hot — folded below into
            a single class id in [0, 8).

    Returns:
        `(total_loss, w)` — the summed collection losses for this tower and
        the Wasserstein estimate (or generator loss for stage 'G').
    """
    label = tf.cast(label, tf.int32)
    print([stage, label.shape])
    # class id = binary_attr * 4 + index of the one-hot in columns 2..5.
    class_id = label[:, 1] * 4 + tf.squeeze(
        tf.matmul(label[:, 2:], tf.constant([[0], [1], [2], [3]], dtype=tf.int32)))

    z = tf.random_normal(shape=[FLAGS.batchsize, FLAGS.z_dim])
    fake = generator(z, class_id)

    if stage == 'D':
        w, _ = compute_dloss(real, fake, class_id)
        losses = tf.get_collection('dlosses', scope)
    else:
        w, _ = compute_gloss(fake, class_id)
        losses = tf.get_collection('glosses', scope)

    total_loss = tf.add_n(losses, name='total_loss')
    return total_loss, w
|
|
|
|
|
|
def average_gradients(tower_grads):
    """Average per-tower gradients variable-by-variable.

    Args:
        tower_grads: list (one entry per tower) of lists of
            (gradient, variable) pairs, all towers in the same variable order.

    Returns:
        A single list of (averaged_gradient, variable) pairs.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Stack this variable's gradients from every tower and average them.
        stacked = tf.concat(axis=0,
                            values=[tf.expand_dims(g, 0) for g, _ in grad_and_vars])
        mean_grad = tf.reduce_mean(stacked, 0)
        # All towers share variables, so the first tower's variable suffices.
        shared_var = grad_and_vars[0][1]
        average_grads.append((mean_grad, shared_var))
    return average_grads
|
|
|
|
|
|
def graph(stage, opt):
    """Build the training graph for one stage ('D' or 'G').

    Args:
        stage: variable-scope name whose trainables get updated ('D' or 'G').
        opt: optimizer used to compute and apply gradients.

    Returns:
        `(train_op, mean_w, iterator, features_placeholder,
        labels_placeholder)` — the caller must run `iterator.initializer`
        with the two placeholders fed before running `train_op`.
    """
    tower_grads = []
    per_gpu_w = []
    iterator, features_placeholder, labels_placeholder = input_fn()

    with tf.variable_scope(tf.get_variable_scope()):
        # Single tower today; the loop shape is kept so more devices can be
        # added by widening the range and the device list.
        for tower_idx in range(1):
            with tf.device('/cpu:0'):
                with tf.name_scope('%s_%d' % ('TOWER', tower_idx)) as scope:
                    real, label = iterator.get_next()
                    loss, w = tower_loss(scope, stage, real, label)
                    tf.get_variable_scope().reuse_variables()
                    # Restrict the update to this stage's variables only.
                    trainables = tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
                    tower_grads.append(opt.compute_gradients(loss, trainables))
                    per_gpu_w.append(w)

    apply_gradient_op = opt.apply_gradients(average_gradients(tower_grads))
    mean_w = tf.reduce_mean(per_gpu_w)
    train_op = apply_gradient_op
    return train_op, mean_w, iterator, features_placeholder, labels_placeholder
|
|
|
|
|
|
def train(data, demo):
    """Run the alternating WGAN training loop and checkpoint at the end.

    Args:
        data: real samples, shape (N, FLAGS.data_dim).
        demo: demographic indicators, shape (N, 6).
    """
    with tf.device('/cpu:0'):
        opt_d = tf.train.AdamOptimizer(1e-4)
        opt_g = tf.train.AdamOptimizer(1e-4)
        # Separate graphs/pipelines for the critic and generator updates.
        train_d, w_distance, iterator_d, features_d, labels_d = graph('D', opt_d)
        train_g, _, iterator_g, features_g, labels_g = graph('G', opt_g)
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        with tf.Session(config=config) as sess:
            sess.run(init)
            # Both pipelines iterate over the same arrays.
            sess.run(iterator_d.initializer,
                     feed_dict={features_d: data, labels_d: demo})
            sess.run(iterator_g.initializer,
                     feed_dict={features_g: data, labels_g: demo})

            for epoch in range(1, FLAGS.max_epochs + 1):
                start_time = time.time()
                w_sum = 0
                for _ in range(FLAGS.max_steps):
                    # Two critic updates per generator update.
                    for _ in range(2):
                        _, w = sess.run([train_d, w_distance])
                        w_sum += w
                    sess.run(train_g)
                duration = time.time() - start_time

                assert not np.isnan(w_sum), 'Model diverged with loss = NaN'

                print('epoch: %d, w_distance = %f (%.1f)'
                      % (epoch, -w_sum / (FLAGS.max_steps * 2), duration))
                # Fires only when epoch == max_epochs, i.e. a single final save.
                if epoch % FLAGS.max_epochs == 0:
                    saver.save(sess, FLAGS.train_dir + 'emr_wgan',
                               write_meta_graph=False, global_step=epoch)
|
|
|
|
|
|
def generate(demo):
    """Generate synthetic samples for every (binary, group) demographic pair.

    For each combination of column m in {0, 1} and column n in {2..5}, this
    produces as many synthetic rows as there are real rows in `demo` carrying
    both attributes, rounding the generator's sigmoid outputs to {0, 1}, and
    saves them to FLAGS.save_dir + 'synthetic_<m><n>.npy'.

    Args:
        demo: (N, 6) array of one-hot demographic indicators.
    """
    z = tf.random_normal(shape=[FLAGS.batchsize, FLAGS.z_dim])
    y = tf.placeholder(shape=[FLAGS.batchsize, 6], dtype=tf.int32)
    # Same label folding as tower_loss: binary * 4 + 4-way one-hot index.
    label = y[:, 1] * 4 + tf.squeeze(
        tf.matmul(y[:, 2:], tf.constant([[0], [1], [2], [3]], dtype=tf.int32)))
    fake = generator(z, label)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore the checkpoint written at the final training epoch.
        saver.restore(sess, FLAGS.train_dir + 'emr_wgan-' + str(FLAGS.max_epochs))
        for m in range(2):
            for n in range(2, 6):
                # Count real rows carrying both attributes (vectorized; was a
                # per-element Python list comprehension over numpy bools).
                both = (demo[:, m] == 1) & (demo[:, n] == 1)
                num = int(np.sum(both))
                nbatch = int(np.ceil(num / FLAGS.batchsize))
                # Constant label batch: attributes m and n on for every row,
                # padded up to a whole number of batches.
                label_input = np.zeros((nbatch * FLAGS.batchsize, 6))
                label_input[:, n] = 1
                label_input[:, m] = 1
                output = []
                for i in range(nbatch):
                    batch_labels = label_input[i * FLAGS.batchsize:(i + 1) * FLAGS.batchsize]
                    f = sess.run(fake, feed_dict={y: batch_labels})
                    output.extend(np.round(f))
                # Drop the padding rows from the final partial batch.
                output = np.array(output)[:num]
                np.save(FLAGS.save_dir + 'synthetic_' + str(m) + str(n), output)
|
|
|
|
|
|
def load_data(n_samples=100, data_dim=30, n_positive=900):
    """Generate a random binary toy dataset with one-hot demographics.

    Args:
        n_samples: number of rows to generate (default 100, as before).
        data_dim: number of binary features per row (default 30).
        n_positive: exact number of 1-entries scattered over the whole
            data matrix (default 900, i.e. 30% density).

    Returns:
        data: (n_samples, data_dim) float array with exactly `n_positive`
            ones. (Bug fix: the original sampled positions WITH replacement,
            so collisions silently produced fewer than 900 ones.)
        demo: (n_samples, 6) float array; each row has exactly two ones —
            one in columns 0-1 (binary attribute) and one in columns 2-5
            (4-way group attribute).
    """
    total = n_samples * data_dim
    data = np.zeros(total)
    # replace=False guarantees n_positive distinct positions.
    idx = np.random.choice(np.arange(total), size=n_positive, replace=False)
    data[idx] = 1
    data = np.reshape(data, (n_samples, data_dim))

    rows = np.arange(n_samples)
    group_idx = np.random.randint(2, 6, size=n_samples)   # columns 2..5
    binary_idx = np.random.randint(2, size=n_samples)     # columns 0..1
    demo = np.zeros((n_samples, 6))
    demo[rows, group_idx] = 1
    demo[rows, binary_idx] = 1
    return data, demo
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke-test entry point: train on randomly generated toy data.
    data, demo = load_data()
    print([data.shape, demo.shape])
    train(data, demo)
    # generate(demo)
|
|
|