data-maker/data/maker/__init__.py

"""
(c) 2019 Data Maker, hiplab.mc.vanderbilt.edu
version 1.0.0
This package serves as a proxy to the overall usage of the framework.
This package is designed to generate synthetic data from a dataset from an original dataset using deep learning techniques
@TODO:
- Make configurable GPU, EPOCHS
"""
import pandas as pd
import numpy as np
import data.gan as gan
from transport import factory
from data.bridge import Binary
import threading as thread
def train (**args) :
    """
    This function trains the GAN so that it learns the distribution of the features
    :column columns that need to be synthesized (discrete)
    :logs location on disk where the training output (checkpoints/logs) is written
    :id identifier of the dataset
    :data data-frame to be synthesized
    :context label of what we are synthesizing
    """
    column = args['column'] if isinstance(args['column'], list) else [args['column']]
    # column_id = args['id']
    df = args['data'] if not isinstance(args['data'], str) else pd.read_csv(args['data'])
    df.columns = [name.lower() for name in df.columns]
    #
    # If we have several columns we will proceed one at a time (it could be done in separate threads)
    # @TODO : Consider performing this task on several threads/GPUs simultaneously
    #
    handler = Binary()
    # args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values
    # args['label'] = handler.Export(df[[column_id]])
    # args['label'] = np.ones(df.shape[0]).reshape(df.shape[0],1)
    for col in column :
        # one-hot encode the column; missing values are mapped to the empty string
        args['real'] = pd.get_dummies(df[col].fillna('')).astype(np.float32).values
        # args['real'] = handler.Export(df[[col]])
        args['column'] = col
        args['context'] = col
        context = args['context']
        if 'store' in args :
            args['store']['args']['doc'] = context
            logger = factory.instance(**args['store'])
            args['logger'] = logger
        else:
            logger = None
        trainer = gan.Train(**args)
        trainer.apply()
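#
# Example usage of train() (a minimal sketch: the file name, column names and
# the 'store' configuration below are hypothetical; the exact keys the transport
# factory expects depend on the configured back-end):
#
#   import data.maker as maker
#   maker.train(
#       data='patients.csv',                  # path to a CSV, or a pandas.DataFrame
#       column=['race','gender'],             # discrete column(s) to learn
#       id='person_id',
#       logs='logs',
#       store={'type':'mongo.MongoWriter','args':{'db':'synthetic'}}
#   )
#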
def post(**args):
    """
    This uploads the tensorflow checkpoint to a data-store (mongodb, bigquery, s3)
    """
    pass
def get(**args):
    """
    This function will restore a checkpoint from persistent storage onto disk
    """
    pass
def generate(**args):
    """
    This function generates a synthetic dataset on the basis of a model that has been learnt for the dataset
    @return pandas.DataFrame
    :data data-frame to be synthesized
    :column columns that need to be synthesized (discrete)
    :id column identifying an entity
    :logs location on disk where the learnt knowledge of the dataset is stored
    """
    # df = args['data']
    df = args['data'] if not isinstance(args['data'], str) else pd.read_csv(args['data'])
    column = args['column'] if isinstance(args['column'], list) else [args['column']]
    # column_id = args['id']
    #
    # @TODO:
    # If the identifier is not present, we should find a way to determine or make one
    #
    _df = df.copy()
    for col in column :
        args['context'] = col
        args['column'] = col
        # the values observed in the original column bound what can be generated
        values = df[col].unique().tolist()
        args['values'] = values
        args['row_count'] = df.shape[0]
        #
        # we can determine the cardinalities here so we know what to allow or disallow
        handler = gan.Predict (**args)
        handler.load_meta(col)
        # handler.ROW_COUNT = df[col].shape[0]
        r = handler.apply()
        # print (r)
        _df[col] = r[col]
        # break
    return _df
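#
# Example usage of generate() (a minimal sketch: the identifiers are hypothetical
# and assume train() was previously run with the same columns and 'logs' location):
#
#   import data.maker as maker
#   _df = maker.generate(
#       data='patients.csv',                  # the original data, path or DataFrame
#       column=['race','gender'],             # columns for which a model was trained
#       id='person_id',
#       logs='logs'
#   )
#   _df.to_csv('patients-synthetic.csv', index=False)
#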