parent
f85d5fd870
commit
dcc55eb1fb
@ -1,10 +1,27 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import data.maker
|
import data.maker
|
||||||
|
from data.params import SYS_ARGS
|
||||||
df = pd.read_csv('sample.csv')
|
import json
|
||||||
column = 'gender'
|
from scipy.stats import wasserstein_distance as wd
|
||||||
id = 'id'
|
import risk
|
||||||
context = 'demo'
|
import numpy as np
|
||||||
store = {"type":"mongo.MongoWriter","args":{"host":"localhost:27017","dbname":"GAN"}}
|
if 'config' in SYS_ARGS :
|
||||||
max_epochs = 11
|
ARGS = json.loads(open(SYS_ARGS['config']).read())
|
||||||
data.maker.train(store=store,max_epochs=max_epochs,context=context,data=df,column=column,id=id,logs='foo')
|
if 'generate' not in SYS_ARGS :
|
||||||
|
data.maker.train(**ARGS)
|
||||||
|
else:
|
||||||
|
#
|
||||||
|
#
|
||||||
|
_df = data.maker.generate(**ARGS)
|
||||||
|
odf = pd.read_csv (ARGS['data'])
|
||||||
|
odf.columns = [name.lower() for name in odf.columns]
|
||||||
|
column = [ARGS['column'] ] #+ ARGS['id']
|
||||||
|
print (column)
|
||||||
|
print (_df[column].risk.evaluate())
|
||||||
|
print (odf[column].risk.evaluate())
|
||||||
|
_x = pd.get_dummies(_df[column]).values
|
||||||
|
y = pd.get_dummies(odf[column]).values
|
||||||
|
N = _df.shape[0]
|
||||||
|
print (np.mean([ wd(_x[i],y[i])for i in range(0,N)]))
|
||||||
|
# column = SYS_ARGS['column']
|
||||||
|
# odf = open(SYS_ARGS['data'])
|
Loading…
Reference in new issue