From 266bdc8bd282ca5b1588434a18f8dcbc3067fb1b Mon Sep 17 00:00:00 2001
From: Steve Nyemba
Date: Sun, 8 Mar 2020 15:00:26 -0500
Subject: [PATCH] bug fix with batch_size (GPU load)

---
 pipeline.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pipeline.py b/pipeline.py
index 0f2c258..418ccbf 100644
--- a/pipeline.py
+++ b/pipeline.py
@@ -78,7 +78,8 @@ class Components :
         log_folder = os.sep.join([log_folder,args['context'],str(partition)])
         _args = {"batch_size":10000,"logs":log_folder,"context":args['context'],"max_epochs":150,"column":args['columns'],"id":"person_id","logger":logger}
         _args['max_epochs'] = 150 if 'max_epochs' not in args else int(args['max_epochs'])
-
+        if 'batch_size' in args :
+            _args['batch_size'] = int(args['batch_size'])
         #
         # We ask the process to assume 1 gpu given the system number of GPU and that these tasks can run in parallel
         #
@@ -118,6 +119,8 @@ class Components :
         _args = {"batch_size":2000,"logs":log_folder,"context":args['context'],"max_epochs":150,"column":args['columns'],"id":"person_id","logger":logger}
         _args['max_epochs'] = 150 if 'max_epochs' not in args else int(args['max_epochs'])
         # _args['num_gpu'] = int(args['num_gpu']) if 'num_gpu' in args else 1
+        if 'batch_size' in args :
+            _args['batch_size'] = int(args['batch_size'])
 
         if int(args['num_gpu']) > 1 :
             _args['gpu'] = int(args['gpu']) if int(args['gpu']) < 8 else np.random.choice(np.arange(8)).astype(int)[0]
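
Note: the change above lets a caller-supplied batch_size override the hard-coded defaults (10000 and 2000) before the arguments reach the trainer, which is how the GPU memory load gets capped. The snippet below is a minimal, self-contained sketch of that override pattern only; build_train_args and default_batch_size are hypothetical names, not part of pipeline.py.

# Sketch (not the project's code): optional batch_size override before training.
def build_train_args(args, default_batch_size=10000):
    """Assemble training keyword arguments, allowing the caller to cap the
    batch size (and hence GPU memory load) via args['batch_size']."""
    _args = {
        "batch_size": default_batch_size,               # default when no override is given
        "max_epochs": int(args.get("max_epochs", 150)),
    }
    # The fix applied by this patch: honor an explicit batch_size if one is supplied.
    if "batch_size" in args:
        _args["batch_size"] = int(args["batch_size"])
    return _args

# Usage: a smaller batch keeps GPU memory usage lower.
print(build_train_args({"batch_size": "2000"}))  # {'batch_size': 2000, 'max_epochs': 150}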