diff --git a/data/gan.py b/data/gan.py index 5975255..4a0fa48 100644 --- a/data/gan.py +++ b/data/gan.py @@ -593,7 +593,7 @@ class Predict(GNet): # # df = pd.DataFrame(np.round(f)).astype(np.int32) - df = pd.DataFrame(np.round(f),dtype=np.int32) + df = pd.DataFrame(np.round(f),dtype=int) p = 0 not in df.sum(axis=1).values x = df.sum(axis=1).values diff --git a/pipeline.py b/pipeline.py index c243ec3..7017592 100644 --- a/pipeline.py +++ b/pipeline.py @@ -163,6 +163,13 @@ class Components : # reader = args['reader'] # df = reader() df = args['reader']() if 'reader' in args else args['data'] + + if 'slice' in args and 'max_rows' in args['slice']: + max_rows = args['slice']['max_rows'] + if df.shape[0] > max_rows : + print (".. slicing ") + i = np.random.choice(df.shape[0],max_rows,replace=False) + df = df.iloc[i] # bounds = Components.split(df,MAX_ROWS,PART_SIZE)