diff --git a/data/gan.py b/data/gan.py index 80c3f8e..7952f23 100644 --- a/data/gan.py +++ b/data/gan.py @@ -584,7 +584,7 @@ class Predict(GNet): p = 0 not in df.sum(axis=1).values x = df.sum(axis=1).values - if x.max() == 1 and np.divide( np.sum(x), x.size) > .9 or p and np.sum(x) == x.size and x.size == self.values.size: + if np.divide( np.sum(x), x.size) > .9 or p and np.sum(x) == x.size and x.size == self.values.size: ratio.append(np.divide( np.sum(x), x.size)) found.append(df) if i == CANDIDATE_COUNT: