import os
from operator import itemgetter

import gradio as gr
from dotenv import load_dotenv

from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AIMessage, HumanMessage

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq

from pinecone import Pinecone, ServerlessSpec
from langchain_pinecone import PineconeVectorStore

load_dotenv()
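
# Configuration is read from a local .env file. Based on the variables used below,
# it is expected to define EMBEDDINGS_MODEL and PINECONE_API_KEY; ChatGroq also
# looks for GROQ_API_KEY in the environment by default.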

# Name of the Pinecone index holding the document embeddings
setid = "global"

# Embedding model; must be the same model that was used when the documents were indexed
embeddings = HuggingFaceEmbeddings(model_name=os.getenv("EMBEDDINGS_MODEL"))

# LLM used for both question rewriting and answering
model = ChatGroq(model_name='mixtral-8x7b-32768')

# Connect to the existing Pinecone index and expose it as a LangChain retriever
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index(setid)
vectorstore = PineconeVectorStore(index, embeddings, "text")
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
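
# Optional sanity check (commented out): verify retrieval against the index before
# wiring up the UI. The query string is only an illustrative placeholder.
# for doc in retriever.invoke("example query"):
#     print(doc.metadata.get("source"), doc.page_content[:80])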

template_no_history = """Answer the question based only on the following context:
{context}

Question: {question}
"""
PROMPT_NH = ChatPromptTemplate.from_template(template_no_history)

template_with_history = """Given the following conversation history, answer the follow up question:
Chat History:
{chat_history}

Question: {question}
"""
PROMPT_WH = ChatPromptTemplate.from_template(template_with_history)
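
# PROMPT_NH ({context}, {question}) drives the history-free chain below;
# PROMPT_WH ({chat_history}, {question}) drives the history branch of rag_query.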


def pipeLog(x):
    # Debug helper: print whatever flows through the chain and pass it on unchanged
    print("***", x)
    return x


# Retrieval + passthrough block; not referenced by the chains below, kept for reference
setup_and_retrieval = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
)


def format_docs(docs):
    # Join the retrieved documents into a single context string for the prompt
    return "\n\n".join(doc.page_content for doc in docs)


rag_chain_from_docs = (
    RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
    | PROMPT_NH
    | model
    | StrOutputParser()
)

# First version of the retrieval chain (no chat history). Renamed with the _nh
# suffix so that rag_query below keeps using this version even after the
# history-aware rag_chain_with_source is defined further down.
rag_chain_with_source_nh = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)


def rag_query(question: str, history: list[list[str]]):
    # First-iteration Gradio handler (superseded by rag_query_2 below)
    if len(history) == 0:
        # No history yet: plain retrieval-augmented answer
        response = rag_chain_with_source_nh.invoke(question)
        sources = [doc.metadata['source'] for doc in response['context']]
        print(response, '\n', sources)
        return response['answer']
    else:
        # Flatten the Gradio [user, bot] pairs into a plain-text transcript
        chat_history = ""
        for l in history:
            chat_history += " : ".join(l)
            chat_history += "\n"
        chain = (
            {"chat_history": itemgetter('chat_history'), "question": itemgetter('question')}
            | PROMPT_WH
            | pipeLog
            | model
        )
        response = chain.invoke({"chat_history": chat_history, "question": question})
        return response.content


def pipeLogTagged(s: str, x):
    # Tagged debug helper; renamed from pipeLog so it does not shadow the
    # single-argument pipeLog that rag_query's chain relies on
    print(s, x)
    return x

pipe_a = RunnableLambda(lambda x: pipeLogTagged("a:", x))
pipe_b = RunnableLambda(lambda x: pipeLogTagged("b:", x))
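
# Second version of the pipeline: history-aware RAG. A follow-up question is first
# rewritten into a standalone question, that rewrite feeds the retriever, and the
# retrieved context plus the original chat history go into the answering prompt.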

contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""

contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)

contextualize_q_chain = contextualize_q_prompt | model | StrOutputParser()
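
# Illustrative only (not executed): the rewriting chain takes prior messages plus the
# new question and returns a standalone question string, roughly like
#   contextualize_q_chain.invoke({
#       "chat_history": [HumanMessage(content="Tell me about Pinecone"),
#                        AIMessage(content="Pinecone is a managed vector database.")],
#       "question": "How do I query it?",
#   })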


qa_system_prompt = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise.

{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)


def contextualized_question(input: dict):
    # If there is chat history, return the rewriting chain itself; LCEL invokes a
    # Runnable returned from a RunnableLambda, so the rewritten question flows on
    # to the retriever. Otherwise the raw question is passed through unchanged.
    if input.get("chat_history"):
        return contextualize_q_chain
    else:
        return input["question"]


rag_chain = (
    RunnablePassthrough.assign(
        context=pipe_b | contextualized_question | retriever | format_docs
    )
    | qa_prompt
    | model
)

# Same chain with the retrieved documents kept in the output; the "xx" key only
# exists to trigger the pipe_a debug logger on the incoming input
rag_chain_with_source = RunnableParallel(
    {
        "xx": pipe_a,
        "context": itemgetter('question') | retriever,
        "question": itemgetter('question'),
        "chat_history": itemgetter('chat_history'),
    }
).assign(answer=rag_chain)


def rag_query_2(question: str, history: list[list[str]]):
    # Gradio passes history as [user, bot] string pairs; convert them to messages
    # so the MessagesPlaceholder in the prompts can consume them
    chat_history = []
    for user_msg, bot_msg in history:
        chat_history += [HumanMessage(content=user_msg), AIMessage(content=bot_msg)]
    response = rag_chain_with_source.invoke({'question': question, 'chat_history': chat_history})
    print(response)
    return response['answer'].content
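
# The chain output is a dict with 'xx', 'context', 'question', 'chat_history' and
# 'answer'; only the answer text is returned to the UI, the rest is just printed.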


gr.ChatInterface(
    rag_query_2,
    title="RAG Chatbot demo",
    description="A chatbot doing Retrieval Augmented Generation, backed by a Pinecone vector database",
).launch()
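
# launch() serves the chat UI locally; pass share=True to get a temporary public link.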