commit f63ede2fc5 (branch dev)
parent cac2dd293d
Author: Steve L. Nyemba (5 years ago)

    tweak with batch size/gpu (bug with small data)

@@ -59,8 +59,8 @@ class GNet :
         self.logs = {}
         self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
-        if self.NUM_GPUS > 1 :
-            os.environ['CUDA_VISIBLE_DEVICES'] = "4"
+        # if self.NUM_GPUS > 1 :
+        #     os.environ['CUDA_VISIBLE_DEVICES'] = "4"
         self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
         self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE]
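The first hunk disables the hard-coded device pinning. As a side note, if single-device pinning were still wanted, the environment variable would have to be set before TensorFlow touches CUDA; a minimal sketch (the device index "4" is carried over from the old code and is only an assumption about the target machine):

import os

# Hypothetical re-enabling of the removed pinning; must run before TensorFlow
# initializes CUDA, otherwise the restriction is ignored.
os.environ['CUDA_VISIBLE_DEVICES'] = "4"
import tensorflow as tf  # imported only after the visible-device list is set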
@@ -78,9 +78,12 @@ class GNet :
             self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
         if 'real' in args :
             self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
-        # self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
-        self.BATCHSIZE_PER_GPU = 3000 if 'batch_size' not in args else int(args['batch_size'])
+        PROPOSED_BATCH_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
+        if args['real'].shape[0] < PROPOSED_BATCH_PER_GPU :
+            self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1)
+        else:
+            self.BATCHSIZE_PER_GPU = PROPOSED_BATCH_PER_GPU
+        # self.BATCHSIZE_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size'])
         self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
         self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
         self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
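The second hunk is the actual small-data fix: instead of always requesting a fixed batch of 3000 rows per GPU, the batch is capped at the number of rows actually available. A minimal standalone sketch of that rule, assuming args['real'] is a 2-D numpy array of training rows (the function name and parameters below are illustrative, not part of the class):

import numpy as np

def batch_size_per_gpu(real, batch_size=None, proposed=2000):
    """Clamp the per-GPU batch size to the dataset size for small inputs."""
    if batch_size is not None:
        proposed = int(batch_size)
    # Small datasets: use every available row instead of a fixed batch size,
    # which previously asked for more rows than the data contains.
    return real.shape[0] if real.shape[0] < proposed else proposed

real = np.random.rand(500, 854)         # fewer rows than the proposed batch
assert batch_size_per_gpu(real) == 500  # the batch shrinks to the data size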
@@ -254,7 +257,7 @@ class Generator (GNet):
         x = args['inputs']
         tmp_dim = self.Z_DIM if 'dim' not in args else args['dim']
         label = args['label']
-        print (self.NUM_LABELS)
         with tf.compat.v1.variable_scope('G', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)):
             for i, dim in enumerate(self.G_STRUCTURE[:-1]):
                 kernel = self.get.variables(name='W_' + str(i), shape=[tmp_dim, dim])
