diff --git a/pipeline.py b/pipeline.py index 6f28eac..8b1dd9e 100644 --- a/pipeline.py +++ b/pipeline.py @@ -136,6 +136,8 @@ class Components : # We need to make sure that continuous columns are removed if x_cols : _args['data'] = df[list(set(df.columns) - set(x_cols))] + if 'gpu' in args : + _args['gpu'] = args['gpu'] data.maker.train(**_args) if 'autopilot' in ( list(args.keys())) :