diff --git a/data/gan.py b/data/gan.py index a46740a..5975255 100644 --- a/data/gan.py +++ b/data/gan.py @@ -592,7 +592,8 @@ class Predict(GNet): # The code below will insure we have some acceptable cardinal relationships between id and synthetic values # - df = pd.DataFrame(np.round(f)).astype(np.int32) + # df = pd.DataFrame(np.round(f)).astype(np.int32) + df = pd.DataFrame(np.round(f),dtype=np.int32) p = 0 not in df.sum(axis=1).values x = df.sum(axis=1).values diff --git a/pipeline.py b/pipeline.py index 22c637d..c243ec3 100644 --- a/pipeline.py +++ b/pipeline.py @@ -125,9 +125,9 @@ class Components : logger.write({"module":"train","action":"train","input":info}) data.maker.train(**_args) - if set(['drone','autopilot']) in set( list(args.keys())) : + if 'autopilot' in ( list(args.keys())) : print (['drone mode enabled ....']) - self.generate(**args) + self.generate(args) pass