diff --git a/pipeline.py b/pipeline.py index e54e746..22c637d 100644 --- a/pipeline.py +++ b/pipeline.py @@ -74,6 +74,13 @@ class Components : # pointer = args['reader'] if 'reader' in args else lambda: Components.get(**args) df = args['data'] + if 'slice' in args and 'max_rows' in args['slice']: + max_rows = args['slice']['max_rows'] + if df.shape[0] > max_rows : + print (".. slicing ") + i = np.random.choice(df.shape[0],max_rows,replace=False) + df = df.iloc[i] + # if df.shape[0] == 0 : # print ("CAN NOT TRAIN EMPTY DATASET ") @@ -117,9 +124,10 @@ class Components : logger.write({"module":"train","action":"train","input":info}) data.maker.train(**_args) + if set(['drone','autopilot']) in set( list(args.keys())) : print (['drone mode enabled ....']) - data.maker.generate(**args) + self.generate(**args) pass @@ -155,6 +163,7 @@ class Components : # reader = args['reader'] # df = reader() df = args['reader']() if 'reader' in args else args['data'] + # bounds = Components.split(df,MAX_ROWS,PART_SIZE) # if partition != '' :