You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

219 lines
8.3 KiB
Python

"""
This module interacts with an LLM (AzureOpenAI/Ollama) through langchain to translate plain-text questions into SQL queries and execute them
"""
# _CATALOG = {
# 'sqlite':{'sql':'select tbl_name as table_name, json_group_array(y.name) as columns from sqlite_master x INNER JOIN PRAGMA_TABLE_INFO(tbl_name) y group by table_name'} ,
# 'postgresql':{'sql':"SELECT table_name, to_json(array_agg(column_name)) as columns FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = 'public' GROUP BY table_name"},
# 'bigquery':{'sql':'SELECT table_name, TO_JSON(ARRAY_AGG(column_name)) as columns FROM :dataset.INFORMATION_SCHEMA.COLUMNS','args':['dataset']},
# 'duckdb':{'sql' :'SELECT table_name, TO_JSON(ARRAY_AGG(column_name)) as columns FROM INFORMATION_SCHEMA.COLUMNS GROUP BY table_name'}
# }
# _CATALOG['sqlite3'] = _CATALOG['sqlite']
import transport
import json
import os
import cms
import pandas as pd
from langchain_openai import AzureOpenAI, AzureChatOpenAI
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain_core.messages import HumanMessage, AIMessage
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
class Agent :
    """
    Wrapper around a langchain model (AzureChatOpenAI or OllamaLLM) that is used to
    classify a text as SQL (isSQL) and to translate a question into SQL (toSQL)
    """
    def __init__(self,**_args) :
        """
        :backend    OpenAI or Ollama
        :kwargs     The initialization parameters associated with the backend
                    The arguments will contain temperature and model name to be used
        """
        _backend = _args['backend'].lower()
        #
        # any backend other than azureopenai/openai falls back to Ollama
        _instance = AzureChatOpenAI if _backend in ['azureopenai', 'openai'] else OllamaLLM
        self._llm = _instance(**_args['kwargs'])

    @staticmethod
    def _parse_json(_text):
        """
        Parse the JSON document found in a model answer (output is not to be trusted)
        The answer is first parsed as is; if that fails and the answer contains markdown
        code fences, the last fenced section tagged json (or the last fenced section
        if none is tagged) is parsed instead, so prose around the fence is ignored.
        :_text  raw text returned by the model
        :return the decoded JSON value
        :raises json.JSONDecodeError (a ValueError) if no valid JSON can be extracted
        """
        _text = _text.strip()
        try:
            return json.loads(_text)
        except ValueError:
            if '```' not in _text :
                raise
        #
        # splitting on the fence marker leaves the fenced sections at odd positions
        _chunks = _text.split('```')[1::2]
        _tagged = [_chunk[len('json'):] for _chunk in _chunks if _chunk.startswith('json')]
        return json.loads((_tagged or _chunks)[-1])

    def isSQL(self,_question):
        """
        Ask the model whether a text is a valid sql statement
        :_question  text to be assessed
        :return     the decoded JSON answer; the prompt asks for the attributes
                    class (1 or 0), explanation and original
        :raises json.JSONDecodeError if the answer can not be decoded
        """
        _template = """Is the provided text a valid sql statmement. Yes or No? Your answer is a properly formatted JSON object with three attributes.
class (1 for valid sql statement and 0 for not a valid sql statement), explanation place a short explanation for the answer and original containing with the original text.
text:
{input_text}
"""
        #
        # NOTE(review): temperature is handed to PromptTemplate as before; it is presumably
        # ignored there (sampling is configured on the model) -- confirm
        _prompt = PromptTemplate(temperature=0.1,input_variables=['input_text'],template=_template)
        r = self.apply(_prompt,input_text=_question)
        return self._parse_json(r)

    def apply(self,_prompt,**_args):
        """
        Run a prompt through the model and return the answer as text
        :_prompt    prompt template to be filled
        :_args      values of the input variables of the template
        :return     the text produced by the chain (not parsed)
        """
        chain = (
            RunnablePassthrough.assign()
            | _prompt
            | self._llm
            | StrOutputParser())
        return chain.invoke(_args)

    def toSQL(self,_question,_catalog,_about):
        """
        Ask the model to convert a question into an SQL query
        :_question  question to be translated
        :_catalog   database schema (csv text) the query will run against
        :_about     additional information about the database
        :return     the decoded JSON answer; the prompt asks for the attributes sql and tables
        :raises json.JSONDecodeError if the answer can not be decoded
        """
        _template="""Your task is to convert a question to an SQL query. The query will run on schema that will be provided in csv format.
Output:
The expected output will be a JSON object with two attributes sql and tables:
- "sql": the SQL query to be executed.
- "tables": list of relevant tables used.
Guidelines:
- If the question can not be answered with the provided schema return empty string in the sql attribute
- Parse the question word by word so as to be able to identify tables, fields and operations associated (Joins, filters ...)
- Under no circumstances will you provide an explanation of tables or reasoning detail.
- Avoid using subqueries, and use field names as represented in the provided schema and their data types
question:
{question}
Database schema:
{catalog}
additional information:
{context}
"""
        #
        # NOTE(review): temperature is handed to PromptTemplate as before; it is presumably
        # ignored there (sampling is configured on the model) -- confirm
        _prompt = PromptTemplate(temperature=0.1,input_variables=['question','catalog','context'],template=_template)
        r = self.apply(_prompt,question=_question,catalog=_catalog,context=_about)
        #
        # models often wrap the answer in a markdown fence, the helper strips it
        return self._parse_json(r)
#
# We are building an interface to execute an sql query
#
def AIProxy (_label,_query,_path,_CATALOG) :
    """
    Execute a query against the data source registered under a label. The query is first
    executed as is; if that fails it is treated as a question, translated to SQL by the
    LLM agent and the generated SQL is executed instead.
    :_label     label of the data source in the transport registry
    :_query     SQL query or plain text question
    :_path      path of the LLM configuration file (JSON with backend, kwargs), may be None
    :_CATALOG   catalog queries per provider {provider:{'sql':..., 'args':[...]}}
    :return     the data as a dict with columns and data (DataFrame split orientation without
                the index), or {'data':..., 'query':...} when the SQL was generated by the LLM
    """
    _qreader = transport.get.reader(label=_label)
    _entry = transport.registry.get(_label)
    _provider = _entry['provider']
    _about = _entry.get('about','')
    if 'dataset' in _entry :
        _database = _entry['dataset']
        _about = f'{_about}, with dataset name {_database}'
    else:
        _database = _entry.get('database','')
        _about = f'{_about}, with the database name {_database}'
    _data = pd.DataFrame()
    r = None
    _sql = None
    try:
        #
        # we should run the command here, assuming a valid query
        #
        _data = _qreader.apply(_query)
    except Exception:
        #
        # here we are assuming we are running text to be translated as a query
        # we need to make sure that we have a JSON configurator
        # (_path is None when no llm is configured, os.path.exists would raise on it)
        #
        if _path and _provider in _CATALOG and os.path.exists(_path):
            #
            # -- reading arguments from the LLM config file, {backend,kwargs}
            with open(_path) as f :
                _kwargs = json.load(f)
            _agent = Agent(**_kwargs)
            #
            # working on a copy, otherwise the placeholders of the shared catalog would be
            # replaced for good and later calls would reuse the values of this label
            _qcat = dict(_CATALOG[_provider])
            for _var in _qcat.get('args',[]) :
                if _var in _entry:
                    _value = _entry[_var]
                    _about = f'{_about}\n{_var} = {_value}'
                    _qcat['sql'] = _qcat['sql'].replace(f':{_var}',str(_value))
            _catalog = _qreader.read(**_qcat).to_csv(index=0)
            _about = f"The queries will run on {_provider} database.\n{_about}"
            r = _agent.toSQL(_query,_catalog, _about)
            #
            # the answer of the model is not to be trusted, the sql attribute may be missing
            _sql = r.get('sql','') if isinstance(r,dict) else ''
            try:
                #
                # NOTE(review): the SQL generated by the model is executed verbatim, the
                # reader should be restricted to read-only access -- confirm
                _data = _qreader.apply(_sql)
            except Exception:
                _data = pd.DataFrame()
    #
    # returning the data and the information needed
    #
    _data = _data.astype(str).to_dict(orient='split')
    del _data['index']
    if r :
        return {'data':_data,'query':_sql}
    return _data
# if not _path :
# #
# # execute the query as is !
# pass
# else:
# if _provider in _CATALOG and os.path.exists(_path) :
# #
# # Ask the agent if it is an SQL query
# f = open(_path)
# _kwargs = json.loads( f.read() )
# f.close()
# try:
# print ([f"Running Model {_kwargs['kwargs']['model']}"])
# _agent = Agent(**_kwargs)
# # r = _agent.isSQL(_query)
# # print (f"****** {_query}\n{r['class']}")
# # if r and int(r['class']) == 0 :
# _catalog = _qreader.read(**_CATALOG[_provider]).to_csv(index=0)
# # print (['****** TABLES FOUND ', _catalog])
# _about = _about if _about else ''
# r = _agent.toSQL(_query,_catalog,f"This is a {_provider} database, and queries generated should account for this. {_about}")
# _query = r['sql']
# # else:
# # #
# # # provided an sql query
# # pass
# except Exception as e:
# #
# # AI Service is unavailable ... need to report this somehow
# print (e)
# else:
# #
# # Not in catalog ...
# pass
# _data = _qreader.apply(_query)
# if _data.shape[0] :
# _data = _data.astype(str).to_dict(orient='split')
# if 'index' in _data :
# del _data['index']
# return _data
@cms.Plugin(mimetype="application/json",method="POST")
def apply (**_args):
    """
    Plugin entry point (POST, application/json): reads label and query from the JSON body
    of the request and hands them to AIProxy together with the llm configuration path and
    the catalogs found under system.source of the configuration
    :request    request object whose json attribute holds label and query
    :config     configuration with the section system.source
    """
    _payload = _args['request'].json
    _label,_query = _payload['label'],_payload['query']
    _settings = _args['config']['system']['source']
    #
    # both settings are optional: no llm path and an empty catalog by default
    _llmPath = _settings.get('llm',None)
    _catalogs = _settings.get('catalogs',{})
    return AIProxy(_label,_query,_llmPath,_catalogs)
@cms.Plugin(mimetype="text/plain")
def enabled(**_args):
    """
    Plugin entry point (text/plain): returns '1' if an llm entry is present under
    system.source of the configuration and '0' otherwise
    :config     configuration with the section system.source
    """
    _settings = _args['config']['system']['source']
    return '1' if 'llm' in _settings else '0'