From f63ede2fc58c983635b4c5a89ef33031938232d9 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Tue, 18 Feb 2020 17:23:13 -0600 Subject: [PATCH] tweak with batch size/gpu (bug with small data) --- data/gan.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/data/gan.py b/data/gan.py index c18277c..ed8facd 100644 --- a/data/gan.py +++ b/data/gan.py @@ -59,8 +59,8 @@ class GNet : self.logs = {} self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu'] - if self.NUM_GPUS > 1 : - os.environ['CUDA_VISIBLE_DEVICES'] = "4" + # if self.NUM_GPUS > 1 : + # os.environ['CUDA_VISIBLE_DEVICES'] = "4" self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854 self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE] @@ -78,9 +78,12 @@ class GNet : self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM] if 'real' in args : self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM] - - # self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256 - self.BATCHSIZE_PER_GPU = 3000 if 'batch_size' not in args else int(args['batch_size']) + PROPOSED_BATCH_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size']) + if args['real'].shape[0] < PROPOSED_BATCH_PER_GPU : + self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) + else: + self.BATCHSIZE_PER_GPU = PROPOSED_BATCH_PER_GPU + # self.BATCHSIZE_PER_GPU = 2000 if 'batch_size' not in args else int(args['batch_size']) self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000) self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) @@ -254,7 +257,7 @@ class Generator (GNet): x = args['inputs'] tmp_dim = self.Z_DIM if 'dim' not in args else args['dim'] label = args['label'] - print (self.NUM_LABELS) + with tf.compat.v1.variable_scope('G', reuse=tf.compat.v1.AUTO_REUSE , regularizer=l2_regularizer(0.00001)): for i, dim in enumerate(self.G_STRUCTURE[:-1]): kernel = self.get.variables(name='W_' + str(i), shape=[tmp_dim, dim])