removing conditions, it blows up computational space

dev
Steve L. Nyemba 5 years ago
parent dab3ab7bf7
commit 4a25af6b13

@ -72,7 +72,7 @@ class GNet :
elif 'label' in args and len(args['label']) == 1 :
self.NUM_LABELS = args['label'].shape[0]
else:
self.NUM_LABELS = 8
self.NUM_LABELS = None
# self.Z_DIM = 128 #self.X_SPACE_SIZE
self.Z_DIM = 128 #-- used as rows down stream
self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
@ -180,14 +180,19 @@ class GNet :
shift = [0] if self.__class__.__name__.lower() == 'generator' else [1] #-- not sure what this is doing
mean, var = tf.nn.moments(inputs, shift, keep_dims=True)
shape = inputs.shape[1].value
offset_m = self.get.variables(shape=[n_labels,shape], name='offset'+name,
if labels is not None:
offset_m = self.get.variables(shape=[1,shape], name='offset'+name,
initializer=tf.zeros_initializer)
scale_m = self.get.variables(shape=[n_labels,shape], name='scale'+name,
initializer=tf.ones_initializer)
offset = tf.nn.embedding_lookup(offset_m, labels)
scale = tf.nn.embedding_lookup(scale_m, labels)
result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-8)
else:
offset = None
scale = None
result = tf.nn.batch_normalization(inputs, mean, var,offset,scale, 1e-8)
return result
def _variable_on_cpu(self,**args):
@ -248,7 +253,7 @@ class Generator (GNet):
x = args['inputs']
tmp_dim = self.Z_DIM if 'dim' not in args else args['dim']
label = args['label']
print (self.NUM_LABELS)
with tf.compat.v1.variable_scope('G', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)):
for i, dim in enumerate(self.G_STRUCTURE[:-1]):
kernel = self.get.variables(name='W_' + str(i), shape=[tmp_dim, dim])
@ -331,7 +336,7 @@ class Train (GNet):
self.generator = Generator(**args)
self.discriminator = Discriminator(**args)
self._REAL = args['real']
self._LABEL= args['label']
self._LABEL= args['label'] if 'label' in args else None
self.column = args['column']
# print ([" *** ",self.BATCHSIZE_PER_GPU])
@ -340,7 +345,7 @@ class Train (GNet):
self.logger.write( self.meta )
self.log (real_shape=list(self._REAL.shape),label_shape = list(self._LABEL.shape),meta_data=self.meta)
# self.log (real_shape=list(self._REAL.shape),label_shape = self._LABEL.shape,meta_data=self.meta)
def load_meta(self, column):
"""
This function will delegate the calls to load meta data to it's dependents
@ -363,6 +368,9 @@ class Train (GNet):
stage = args['stage']
real = args['real']
label = args['label']
if label is not None :
label = tf.cast(label, tf.int32)
#
# @TODO: Ziqi needs to explain what's going on here
@ -394,8 +402,13 @@ class Train (GNet):
This function seems to produce
"""
features_placeholder = tf.compat.v1.placeholder(shape=self._REAL.shape, dtype=tf.float32)
labels_placeholder = tf.compat.v1.placeholder(shape=self._LABEL.shape, dtype=tf.float32)
LABEL_SHAPE = [None,None] if self._LABEL is None else self._LABEL.shape
labels_placeholder = tf.compat.v1.placeholder(shape=LABEL_SHAPE, dtype=tf.float32)
if self._LABEL is not None :
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
else :
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)
# labels_placeholder = None
dataset = dataset.repeat(10000)
dataset = dataset.batch(batch_size=3000)
dataset = dataset.prefetch(1)
@ -413,7 +426,10 @@ class Train (GNet):
for i in range(self.NUM_GPUS):
with tf.device('/gpu:%d' % i):
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
if self._LABEL is not None :
(real, label) = iterator.get_next()
else:
real = iterator.get_next()
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
#tf.get_variable_scope().reuse_variables()
tf.compat.v1.get_variable_scope().reuse_variables()
@ -450,10 +466,11 @@ class Train (GNet):
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(init)
sess.run(iterator_d.initializer,
feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL})
feed_dict={features_placeholder_d: REAL})
sess.run(iterator_g.initializer,
feed_dict={features_placeholder_g: REAL, labels_placeholder_g: LABEL})
feed_dict={features_placeholder_g: REAL})
for epoch in range(1, self.MAX_EPOCHS + 1):
start_time = time.time()
@ -511,9 +528,11 @@ class Predict(GNet):
tf.compat.v1.reset_default_graph()
z = tf.random.normal(shape=[self.BATCHSIZE_PER_GPU, self.Z_DIM])
y = tf.compat.v1.placeholder(shape=[self.BATCHSIZE_PER_GPU, self.NUM_LABELS], dtype=tf.int32)
if self._LABEL is not None :
ma = [[i] for i in np.arange(self.NUM_LABELS - 2)]
label = y[:, 1] * len(ma) + tf.squeeze(tf.matmul(y[:, 2:], tf.constant(ma, dtype=tf.int32)))
else:
label = None
fake = self.generator.network(inputs=z, label=label)
init = tf.compat.v1.global_variables_initializer()
saver = tf.compat.v1.train.Saver()
@ -524,13 +543,19 @@ class Predict(GNet):
# sess.run(init)
saver.restore(sess, model_dir)
if self._LABEL is not None :
labels = np.zeros((self.ROW_COUNT,self.NUM_LABELS) )
labels= demo
else:
labels = None
found = []
labels= demo
for i in np.arange(CANDIDATE_COUNT) :
for i in np.arange(CANDIDATE_COUNT) :
if labels :
f = sess.run(fake,feed_dict={y:labels})
else:
f = sess.run(fake)
#
# if we are dealing with numeric values only we can perform a simple marginal sum against the indexes
# The code below will insure we have some acceptable cardinal relationships between id and synthetic values

@ -25,7 +25,7 @@ def train (**args) :
"""
column = args['column'] if (isinstance(args['column'],list)) else [args['column']]
column_id = args['id']
# column_id = args['id']
df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data'])
df.columns = [name.lower() for name in df.columns]
@ -35,7 +35,8 @@ def train (**args) :
#
handler = Binary()
# args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values
args['label'] = handler.Export(df[[column_id]])
# args['label'] = handler.Export(df[[column_id]])
# args['label'] = np.ones(df.shape[0]).reshape(df.shape[0],1)
for col in column :
# args['real'] = pd.get_dummies(df[col]).astype(np.float32).values
args['real'] = handler.Export(df[[col]])
@ -83,7 +84,7 @@ def generate(**args):
#
# args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values
bwrangler = Binary()
args['label'] = bwrangler.Export(df[[column_id]])
# args['label'] = bwrangler.Export(df[[column_id]])
_df = df.copy()
for col in column :
args['context'] = col

@ -7,7 +7,7 @@ def read(fname):
args = {"name":"data-maker","version":"1.1.0","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'
args['url'] = 'https://hiplab.mc.vanderbilt.edu/git/aou/data-maker.git'
if sys.version_info[0] == 2 :
args['use_2to3'] = False

Loading…
Cancel
Save