dev
Steve L. Nyemba 4 years ago
parent 46f2fd7be4
commit 43873697a0

@ -58,7 +58,14 @@ class GNet :
self.layers.normalize = self.normalize
self.logs = {}
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
# self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
self.GPU_CHIPS = None if 'gpu' not in args else args['gpu']
if self.GPU_CHIPS is None:
self.GPU_CHIPS = [0]
if 'CUDA_VISIBLE_DEVICES' in os.environ :
os.environ.pop('CUDA_VISIBLE_DEVICES')
self.NUM_GPUS = len(self.GPU_CHIPS)
self.PARTITION = args['partition']
# if self.NUM_GPUS > 1 :
# os.environ['CUDA_VISIBLE_DEVICES'] = "4"
@ -150,6 +157,7 @@ class GNet :
"D_STRUCTURE":self.D_STRUCTURE,
"G_STRUCTURE":self.G_STRUCTURE,
"NUM_GPUS":self.NUM_GPUS,
"GPU_CHIPS":self.GPU_CHIPS,
"NUM_LABELS":self.NUM_LABELS,
"MAX_EPOCHS":self.MAX_EPOCHS,
"ROW_COUNT":self.ROW_COUNT
@ -443,7 +451,7 @@ class Train (GNet):
# - abstract hardware specification
# - determine if the GPU/CPU are busy
#
for i in range(self.NUM_GPUS):
for i in self.GPU_CHIPS : #range(self.NUM_GPUS):
with tf.device('/gpu:%d' % i):
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
if self._LABEL is not None :

@ -90,14 +90,14 @@ def train (**_args):
#
# Let us prepare the data by calling the utility function
#
if 'file' in _args :
#
# We are reading data from a file
_args['data'] = pd.read_csv(_args['file'])
else:
#
# data will be read from elsewhere (a data-store)...
pass
# if 'file' in _args :
# #
# # We are reading data from a file
# _args['data'] = pd.read_csv(_args['file'])
# else:
# #
# # data will be read from elsewhere (a data-store)...
# pass
# if 'ignore' in _args and 'columns' in _args['ignore']:
_inputhandler = prepare.Input(**_args)
@ -107,6 +107,7 @@ def train (**_args):
if 'store' in _args :
#
# This
args['store'] = copy.deepcopy(_args['store']['logs'])
args['store']['args']['doc'] = _args['context']
logger = factory.instance(**args['store'])

@ -13,7 +13,7 @@ import transport
import json
import pandas as pd
import numpy as np
import cupy as cp
# import cupy as cp
import sys
import os
# from multiprocessing import Process, Queue
@ -62,7 +62,7 @@ class Input :
self._schema = _args['schema'] if 'schema' in _args else {}
self.df = _args['data']
if 'sql' not in _args :
# self._initdata(**_args)
self._initdata(**_args)
#
pass
else:
@ -70,12 +70,12 @@ class Input :
self._map = {} if 'map' not in _args else _args['map']
# self._metadf = pd.DataFrame(self.df[self._columns].dtypes.values.astype(str)).T #,self._columns]
# self._metadf.columns = self._columns
if 'gpu' in _args and 'GPU' in os.environ:
# if 'gpu' in _args and 'GPU' in os.environ:
np = cp
index = int(_args['gpu'])
np.cuda.Device(index).use()
print(['..:: GPU ',index])
# np = cp
# index = int(_args['gpu'])
# np.cuda.Device(index).use()
# print(['..:: GPU ',index])
def _initsql(self,**_args):
"""
@ -107,6 +107,8 @@ class Input :
row_count = self.df.shape[0]
cols = None if 'columns' not in _args else _args['columns']
self.columns = self.df.columns.tolist()
self._io = []
if 'columns' in _args :
self._columns = _args['columns']
else:
@ -115,6 +117,8 @@ class Input :
_df = pd.DataFrame(self.df.apply(lambda col: col.dropna().unique().size )).T
MIN_SPACE_SIZE = 2
self._columns = cols if cols else _df.apply(lambda col:None if col[0] == row_count or col[0] < MIN_SPACE_SIZE else col.name).dropna().tolist()
self._io = _df.to_dict(orient='records')
def _initdata(self,**_args):
"""
This function will initialize the class with a data-frame and columns of interest (if any)

@ -4,7 +4,7 @@ import sys
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
args = {"name":"data-maker","version":"1.4.0","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
args = {"name":"data-maker","version":"1.4.1","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
args['url'] = 'https://hiplab.mc.vanderbilt.edu/git/aou/data-maker.git'

Loading…
Cancel
Save