diff --git a/data/gan.py b/data/gan.py index e0f97b1..26f19a2 100644 --- a/data/gan.py +++ b/data/gan.py @@ -61,19 +61,16 @@ class GNet : self.logs = {} # self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu'] - # self.GPU_CHIPS = None if 'gpu' not in args else args['gpu'] - # if self.GPU_CHIPS is None: - # self.GPU_CHIPS = [0] - # if 'CUDA_VISIBLE_DEVICES' in os.environ : - # os.environ.pop('CUDA_VISIBLE_DEVICES') - # self.NUM_GPUS = 0 - # else: - # self.NUM_GPUS = len(self.GPU_CHIPS) + self.GPU_CHIPS = None if 'gpu' not in args else [args['gpu']] + if self.GPU_CHIPS is None: + self.GPU_CHIPS = [0] + if 'CUDA_VISIBLE_DEVICES' in os.environ : + os.environ.pop('CUDA_VISIBLE_DEVICES') + self.NUM_GPUS = 0 + else: + self.NUM_GPUS = len(self.GPU_CHIPS) # os.environ['CUDA_VISIBLE_DEVICES'] = str(self.GPU_CHIPS[0]) - self.NUM_GPUS = 0 if 'gpu' not in args else args['gpu'] - self.GPU_CHIPS = None if self.NUM_GPUS == 0 else [args['gpu']] - if self.GPU_CHIPS : - os.environ['CUDA_VISIBLE_DEVICES'] = str(self.GPU_CHIPS[0]) + self.PARTITION = args['partition'] if 'partition' in args else None # if self.NUM_GPUS > 1 : # os.environ['CUDA_VISIBLE_DEVICES'] = "4" diff --git a/data/maker/__init__.py b/data/maker/__init__.py index b7608d7..4c175e9 100644 --- a/data/maker/__init__.py +++ b/data/maker/__init__.py @@ -319,7 +319,8 @@ class Generator (Learner): _args['map'] = self._map _args['values'] = np.array(values) _args['row_count'] = self._df.shape[0] - + if self.gpu : + _args['gpu'] = self.gpu gHandler = gan.Predict(**_args) gHandler.load_meta(columns=None) _iomatrix = gHandler.apply()