parent
							
								
									f85d5fd870
								
							
						
					
					
						commit
						dcc55eb1fb
					
				@ -1,10 +1,27 @@
 | 
				
			|||||||
import pandas as pd
 | 
					import pandas as pd
 | 
				
			||||||
import data.maker
 | 
					import data.maker
 | 
				
			||||||
 | 
					from data.params import SYS_ARGS
 | 
				
			||||||
df      = pd.read_csv('sample.csv')
 | 
					import json
 | 
				
			||||||
column  = 'gender'
 | 
					from scipy.stats import wasserstein_distance as wd
 | 
				
			||||||
id      = 'id' 
 | 
					import risk
 | 
				
			||||||
context = 'demo'
 | 
					import numpy as np
 | 
				
			||||||
store = {"type":"mongo.MongoWriter","args":{"host":"localhost:27017","dbname":"GAN"}}
 | 
					if 'config' in SYS_ARGS :
 | 
				
			||||||
max_epochs = 11
 | 
					    ARGS = json.loads(open(SYS_ARGS['config']).read())
 | 
				
			||||||
data.maker.train(store=store,max_epochs=max_epochs,context=context,data=df,column=column,id=id,logs='foo')
 | 
					    if 'generate' not in SYS_ARGS :
 | 
				
			||||||
 | 
					        data.maker.train(**ARGS)    
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        _df = data.maker.generate(**ARGS)
 | 
				
			||||||
 | 
					        odf = pd.read_csv (ARGS['data'])
 | 
				
			||||||
 | 
					        odf.columns = [name.lower() for name in odf.columns]
 | 
				
			||||||
 | 
					        column = [ARGS['column'] ] #+ ARGS['id']
 | 
				
			||||||
 | 
					        print (column)
 | 
				
			||||||
 | 
					        print (_df[column].risk.evaluate())
 | 
				
			||||||
 | 
					        print (odf[column].risk.evaluate())
 | 
				
			||||||
 | 
					        _x = pd.get_dummies(_df[column]).values
 | 
				
			||||||
 | 
					        y  = pd.get_dummies(odf[column]).values
 | 
				
			||||||
 | 
					        N = _df.shape[0]
 | 
				
			||||||
 | 
					        print (np.mean([ wd(_x[i],y[i])for i in range(0,N)]))
 | 
				
			||||||
 | 
					        # column = SYS_ARGS['column']
 | 
				
			||||||
 | 
					        # odf = open(SYS_ARGS['data'])
 | 
				
			||||||
					Loading…
					
					
				
		Reference in new issue