From 832581303b623c4cc6cc3cf43f4716ac1427f773 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Wed, 4 Mar 2020 14:08:10 -0600 Subject: [PATCH] bug fix: gpu assignement error --- pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pipeline.py b/pipeline.py index 8c8a7d7..b53ba52 100644 --- a/pipeline.py +++ b/pipeline.py @@ -63,6 +63,7 @@ class Components : _args = {"batch_size":10000,"logs":log_folder,"context":args['context'],"max_epochs":150,"column":args['columns'],"id":"person_id","logger":logger} _args['max_epochs'] = 150 if 'max_epochs' not in args else int(args['max_epochs']) _args['num_gpu'] = int(args['num_gpu']) if 'num_gpu' in args else 1 + _args['gpu'] = args['gpu'] if 'gpu' in args else 0 MAX_ROWS = args['max_rows'] if 'max_rows' in args else 0 PART_SIZE = args['part_size'] if 'part_size' in args else 0 @@ -85,6 +86,7 @@ class Components : # _args['logs'] = os.sep.join([log_folder,str(part_index)]) _args['partition'] = str(part_index) _args['logger'] = {'args':{'dbname':'aou','doc':args['context']},'type':'mongo.MongoWriter'} + # # We should post the the partitions to a queue server (at least the instructions on ): # - where to get the data @@ -207,8 +209,9 @@ class Components : logger.write({'module':'process','action':'read-partition','input':info['info']}) df = pd.DataFrame(info['data']) args = info['args'] + args['gpu'] = int(info['info']['partition']) if int(args['num_gpu']) > 1 and args['gpu'] > 0: - args['gpu'] = args['gpu'] + args['num_gpu'] + args['gpu'] = args['gpu'] + args['num_gpu'] if args['gpu'] + args['num_gpu'] < 8 else 0 args['reader'] = lambda: df # # @TODO: Fix