|
|
@ -59,8 +59,8 @@ class GNet :
|
|
|
|
self.logs = {}
|
|
|
|
self.logs = {}
|
|
|
|
|
|
|
|
|
|
|
|
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
|
|
|
|
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
|
|
|
|
if self.NUM_GPUS > 1 :
|
|
|
|
# if self.NUM_GPUS > 1 :
|
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = "4"
|
|
|
|
# os.environ['CUDA_VISIBLE_DEVICES'] = "4"
|
|
|
|
|
|
|
|
|
|
|
|
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
|
|
|
|
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
|
|
|
|
self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE]
|
|
|
|
self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE]
|
|
|
@ -78,9 +78,12 @@ class GNet :
|
|
|
|
self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
|
|
|
|
self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
|
|
|
|
if 'real' in args :
|
|
|
|
if 'real' in args :
|
|
|
|
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
|
|
|
|
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
|
|
|
|
|
|
|
|
PROPOSED_BATCH_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
|
|
|
|
# self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
|
|
|
|
if args['real'].shape[0] < PROPOSED_BATCH_PER_GPU :
|
|
|
|
self.BATCHSIZE_PER_GPU = 3000 if 'batch_size' not in args else int(args['batch_size'])
|
|
|
|
self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.BATCHSIZE_PER_GPU = PROPOSED_BATCH_PER_GPU
|
|
|
|
|
|
|
|
# self.BATCHSIZE_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
|
|
|
|
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
|
|
|
|
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
|
|
|
|
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
|
|
|
|
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
|
|
|
|
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
|
|
|
|
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
|
|
|
@ -254,7 +257,7 @@ class Generator (GNet):
|
|
|
|
x = args['inputs']
|
|
|
|
x = args['inputs']
|
|
|
|
tmp_dim = self.Z_DIM if 'dim' not in args else args['dim']
|
|
|
|
tmp_dim = self.Z_DIM if 'dim' not in args else args['dim']
|
|
|
|
label = args['label']
|
|
|
|
label = args['label']
|
|
|
|
print (self.NUM_LABELS)
|
|
|
|
|
|
|
|
with tf.compat.v1.variable_scope('G', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)):
|
|
|
|
with tf.compat.v1.variable_scope('G', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)):
|
|
|
|
for i, dim in enumerate(self.G_STRUCTURE[:-1]):
|
|
|
|
for i, dim in enumerate(self.G_STRUCTURE[:-1]):
|
|
|
|
kernel = self.get.variables(name='W_' + str(i), shape=[tmp_dim, dim])
|
|
|
|
kernel = self.get.variables(name='W_' + str(i), shape=[tmp_dim, dim])
|
|
|
|