You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
912 B
Python
27 lines
912 B
Python
import pandas as pd
|
|
import data.maker
|
|
from data.params import SYS_ARGS
|
|
import json
|
|
from scipy.stats import wasserstein_distance as wd
|
|
import risk
|
|
import numpy as np
|
|
# Script entry point. Everything is driven by SYS_ARGS (imported from
# data.params); nothing runs unless a 'config' argument was supplied.
if 'config' in SYS_ARGS :
    # The configuration is a JSON document; its keys are forwarded verbatim
    # as keyword arguments to data.maker.train / data.maker.generate.
    # The context manager closes the file handle once it has been parsed.
    with open(SYS_ARGS['config']) as config_file :
        ARGS = json.load(config_file)

    if 'generate' not in SYS_ARGS :
        data.maker.train(**ARGS)
    else:
        #
        # Generate a frame with data.maker, load the original CSV named by
        # ARGS['data'], and print a comparison of the configured column.
        #
        _df = data.maker.generate(**ARGS)
        odf = pd.read_csv(ARGS['data'])
        # Lower-case the CSV headers before selecting the column.
        # NOTE(review): presumably the generated frame uses lower-case
        # column names -- confirm against data.maker.
        odf.columns = [name.lower() for name in odf.columns]
        column = [ARGS['column']] #+ ARGS['id']
        print (column)
        # NOTE(review): `.risk` is not a stock pandas attribute; presumably
        # the `risk` import registers it as a DataFrame accessor -- confirm.
        print (_df[column].risk.evaluate())
        print (odf[column].risk.evaluate())
        # One-hot encode the selected column of each frame so that rows can
        # be compared numerically.
        _x = pd.get_dummies(_df[column]).values
        y = pd.get_dummies(odf[column]).values
        # Mean Wasserstein distance between rows at the same position.
        # zip stops at the shorter of the two arrays, so an original file
        # with fewer rows than the generated frame no longer raises an
        # IndexError.
        print (np.mean([wd(generated_row, original_row) for generated_row, original_row in zip(_x, y)]))