diff --git a/pipeline.py b/pipeline.py index 72dea06..7082b71 100644 --- a/pipeline.py +++ b/pipeline.py @@ -239,7 +239,7 @@ class Components : real_df = pd.DataFrame() if x_cols : args['data'] = args['data'][list(set(args['data'].columns) - set(x_cols))] - real_df = args[x_cols].copy() + real_df = args['data'][x_cols].copy() args['candidates'] = 1 if 'candidates' not in args else int(args['candidates']) if 'gpu' in args :