diff --git a/data/gan.py b/data/gan.py index c61d1b1..767a24b 100644 --- a/data/gan.py +++ b/data/gan.py @@ -603,7 +603,8 @@ class Predict(GNet): # # df = pd.DataFrame(np.round(f)).astype(np.int32) - candidates.append (np.round(_matrix).astype(np.int64)) + # candidates.append (np.round(_matrix).astype(np.int64)) + candidates.append( [np.round(row).astype(int) for row in _matrix]) # return candidates[0] if len(candidates) == 1 else candidates return candidates