from sentence_transformers import SentenceTransformer, CrossEncoder, util
#from torch import tensor as torch_tensor
#from datasets import load_dataset

from langchain.llms import OpenAI
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain import LLMMathChain, SQLDatabase, SQLDatabaseChain, LLMChain
from langchain.agents import initialize_agent, Tool

import sqlite3
#import pandas as pd
import json
import chromadb

# database
cxn = sqlite3.connect('./data/mbr.db')
"""# import models""" | |
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1') | |
bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens | |
#The bi-encoder will retrieve top_k documents. We use a cross-encoder, to re-rank the results list to improve the quality | |
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') | |
"""# import datasets""" | |
# dataset = load_dataset("gfhayworth/hack_policy", split='train') | |
# mypassages = list(dataset.to_pandas()['psg']) | |
# dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train') | |
# dataset_embed_pd = dataset_embed.to_pandas() | |
# mycorpus_embeddings = torch_tensor(dataset_embed_pd.values) | |
###########################################################################################################################

"""# set up vector db"""
from chromadb.config import Settings

chroma_client = chromadb.Client(settings=Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="./data/mychromadb/"  # optional; defaults to .chromadb/ in the current directory
))
collection = chroma_client.get_collection(name="benefit_collection")
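# The collection above is assumed to already exist on disk; the ingestion code is
# not part of this script. A minimal, hypothetical sketch of how it might have
# been populated (example ids, passages, and metadata values are made up). The
# metadata keys 'source', 'page', and 'url' are the ones consumed by the
# retrieval and citation code below.
#
# passages = ["Your plan covers routine footcare visits.", "Dental cleanings are covered twice per year."]
# collection = chroma_client.create_collection(name="benefit_collection")
# collection.add(
#     ids=[str(i) for i in range(len(passages))],
#     documents=passages,
#     embeddings=[bi_encoder.encode(p).tolist() for p in passages],
#     metadatas=[{"source": "H1036236000SB23.pdf", "page": 1, "url": "https://example.com/H1036236000SB23.pdf"} for _ in passages],
# )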
def vdb_rslt(qry, src, top_k=20):
    """Embed the query with the bi-encoder and retrieve top_k passages for one source document."""
    results = collection.query(
        query_embeddings=[bi_encoder.encode(qry)],
        n_results=top_k,
        where={"source": src},
    )
    return results
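# Example (commented out, following the demo-call convention used below; the
# question is hypothetical):
# vdb_rslt('how many footcare visits are covered?', 'H1036236000SB23.pdf')
# -> a dict of parallel lists keyed by 'ids', 'documents', 'metadatas', and
#    'distances', each nested one level per query embedding (hence the [0]
#    indexing in rernk below)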
##################################################################################################################################
# Semantic Search Functions

def rtrv(qry, src='H1036236000SB23.pdf', top_k=20):
    """Retrieve candidate passages from the vector db (defaults to the demo plan document)."""
    rslts = vdb_rslt(qry, src, top_k)
    return rslts
def rernk(query, collection=collection, top_k=20, top_n=5):
    """Retrieve top_k candidates, re-rank them with the cross-encoder, and return the best top_n."""
    rtrv_rslts = rtrv(query, top_k=top_k)
    rtrv_ids = rtrv_rslts.get('ids')[0]
    rtrv_docs = rtrv_rslts.get('documents')[0]

    ##### Re-Ranking #####
    cross_inp = [[query, doc] for doc in rtrv_docs]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores
    combined = list(zip(rtrv_ids, list(cross_scores)))
    sorted_tuples = sorted(combined, key=lambda x: x[1], reverse=True)
    sorted_ids = [t[0] for t in sorted_tuples[:top_n]]

    predictions = collection.get(ids=sorted_ids, include=["documents", "metadatas"])
    return predictions
def get_text_fmt(qry):
    """Wrap the re-ranked passages as langchain Documents so they can feed a qa chain."""
    prediction_text = []
    predictions = rernk(qry, collection=collection, top_k=20, top_n=5)
    docs = predictions['documents']
    meta = predictions['metadatas']
    for i in range(len(docs)):
        result = Document(page_content=docs[i], metadata=meta[i])
        prediction_text.append(result)
    return prediction_text
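# Example (commented out; hypothetical question):
# get_text_fmt('what is my dental benefit?')
# -> a list of Document objects whose metadata carries the 'source', 'page',
#    and 'url' fields used for citations in get_cot below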
##################################################################################################################################

"""# LLM based qa functions"""

template = """You are a friendly AI assistant for the insurance company Humana.
Given the following extracted parts of a long document and a question, create a succinct final answer.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana.

QUESTION: {question}
=========
{summaries}
=========
FINAL ANSWER:"""

PROMPT = PromptTemplate(template=template, input_variables=["summaries", "question"])

chain_qa = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT, verbose=True)
def get_llm_response(message):
    """Answer a benefits question over the re-ranked passages with the qa-with-sources chain."""
    mydocs = get_text_fmt(message)
    responses = chain_qa({"input_documents": mydocs, "question": message})
    return responses
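# Example (commented out; hypothetical question):
# get_llm_response('how much do I pay for a specialist visit?')
# -> dict containing 'output_text' (the answer) plus the 'input_documents' used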
"""# Database query""" | |
db = SQLDatabase.from_uri("sqlite:///./data/mbr.db") | |
llm = OpenAI(temperature=0) | |
# default model | |
# model_name: str = "text-davinci-003" | |
# instruction fine-tuned, sometimes referred to as GPT-3.5 | |
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, return_intermediate_steps=True) | |
def db_qry(qry): | |
responses = db_chain('my mbr_id is 456 ;'+str(qry) ) ############### hardcode mbr id 456 for demo | |
return responses | |
# db_qry('how many footcare visits have I had?')

"""## Math
- default version
"""
llm_math_chain = LLMMathChain(llm=llm, verbose=True)
# llm_math_chain.run('what is the square root of 49?')
"""# Greeting""" | |
template = """You are an AI assistant for the insurance company Humana. | |
Your name is Jarvis and you were created by Humana's AI research team. | |
Offer polite, friendly greetings and brief small talk. | |
Respond to thanks with, 'Glad to help.' | |
If the question is not about Humana, politely guide the user to ask questions about Humana insurance benefits. | |
QUESTION: {question} | |
========= | |
FINAL ANSWER:""" | |
greet_prompt = PromptTemplate(template=template, input_variables=["question"]) | |
greet_llm = LLMChain(prompt=greet_prompt, llm=llm, verbose=True) | |
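# Example (commented out; hypothetical greeting):
# greet_llm.run('hi, who are you?')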
"""# MRKL Chain""" | |
tools = [ | |
Tool( | |
name = "Benefit", | |
func=get_llm_response, | |
description='''useful for when you need to answer questions about plan benefits, premiums and payments. | |
This tool shows how much of a benefit is available in the plan. | |
You should ask targeted questions''' | |
), | |
Tool( | |
name="Calculator", | |
func=llm_math_chain.run, | |
description="useful for when you need to answer questions about math" | |
), | |
Tool( | |
name="Member DB", | |
func=db_qry, | |
description='''useful for when you need to answer questions about member details such their name, id and accumulated use of services. | |
This tool shows how much a benfit has already been consumed. | |
Input should be in the form of a question containing full context''' | |
), | |
Tool( | |
name="Greeting", | |
func=greet_llm.run, | |
description="useful for when you need to respond to greetings, thanks, answer questions about yourself, and make small talk" | |
), | |
] | |
mrkl = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=True,
                        return_intermediate_steps=True, max_iterations=5, early_stopping_method="generate")

def mrkl_rspnd(qry):
    """Run a user query through the agent and return the full response, including intermediate steps."""
    response = mrkl({"input": str(qry)})
    return response
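# Example (commented out; hypothetical question):
# mrkl_rspnd('how many footcare visits do I have left this year?')
# -> dict with 'input', 'output', and 'intermediate_steps' (consumed by get_cot below)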
def get_cot(r):
    """Render the agent's intermediate steps as an HTML 'chain of thought' block."""
    cot = '<p>'
    try:
        intermedObj = r['intermediate_steps']
        cot += '<b>Input:</b> ' + r['input'] + '<br>'
        for agnt_action, obs in intermedObj:
            al = '<br> '.join(agnt_action.log.split('\n'))
            cot += '<b>AI chain of thought:</b> ' + al + '<br>'
            if type(obs) is dict:
                if obs.get('input_documents') is not None:
                    # qa-with-sources result: cite each supporting passage with a link to its page
                    for d in obs['input_documents']:
                        cot += ' ' + '<i>- ' + str(d.page_content) + '</i>' + ' <a href="' + str(d.metadata['url']) + '">' + '''<span style="color: blue;">''' + str(d.metadata['page']) + '</span></a> ' + '<br>'
                    cot += '<b>Observation:</b> ' + str(obs['output_text']) + '<br><br>'
                elif obs.get('intermediate_steps') is not None:
                    # SQL chain result: show the generated query
                    cot += '<b>Query:</b> ' + str(obs.get('intermediate_steps')) + '<br><br>'
                else:
                    pass
            else:
                cot += '<b>Observation:</b> ' + str(obs) + '<br><br>'
    except Exception:
        # if the response carries no intermediate steps, return whatever was accumulated
        pass
    cot += '</p>'
    return cot
def chat(message, history):
    """Chat callback: run the agent, append the turn to history, and return the chain-of-thought HTML."""
    history = history or []
    message = message.lower()
    response = mrkl_rspnd(message)
    cot = get_cot(response)
    history.append((message, response['output']))
    return history, history, cot