# Panel app: "ChadGPT" — a conversational RAG chatbot over Trade Talks
# podcast transcripts, backed by a persisted Chroma vector store and
# Azure OpenAI chat/embedding deployments.
from langchain_openai import AzureChatOpenAI | |
from langchain_openai import AzureOpenAIEmbeddings | |
from langchain_community.vectorstores import Chroma | |
import os | |
import panel as pn | |
import param | |
#from langchain.vectorstores import Chroma | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain.schema import HumanMessage | |
from langchain_core.messages import SystemMessage | |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_community.chat_message_histories import ChatMessageHistory | |
from langchain_core.chat_history import BaseChatMessageHistory | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_core.runnables.history import RunnableWithMessageHistory | |
import uuid | |
# download existing Chroma database and create a chatbot | |
def load_db(k, persist_directory='chroma_1500ch_new/'):
    """Build a history-aware conversational RAG chain over a Chroma store.

    Parameters
    ----------
    k : int
        Number of document chunks the similarity retriever returns per query.
    persist_directory : str, optional
        Path of the pre-built Chroma database on disk
        (default ``'chroma_1500ch_new/'`` — the original hard-coded location).

    Returns
    -------
    RunnableWithMessageHistory
        A chain that (1) rewrites a follow-up question into a standalone one
        using the chat history, (2) retrieves context from Chroma, and
        (3) answers strictly from that context, keeping per-session history
        keyed by the ``session_id`` passed in the run config.
    """
    # Embedding deployment must match the one used when the index was built.
    embeddings = AzureOpenAIEmbeddings(
        azure_deployment="text-embedding-ada-002",
        openai_api_version="2023-05-15",
    )
    # Load the existing Chroma vector database from disk (no re-indexing).
    db = Chroma(persist_directory=persist_directory,
                embedding_function=embeddings)
    # Cheaper model: only reformulates follow-up questions.
    llm = AzureChatOpenAI(
        openai_api_version="2023-05-15",
        azure_deployment="gpt-35-turbo-0613",
        model_version="0613",
        temperature=0,
    )
    # Stronger model: produces the final grounded answer.
    llm4 = AzureChatOpenAI(
        openai_api_version="2023-05-15",
        azure_deployment="gpt-4o",
        model_version="2024-05-13",
        temperature=0,
    )
    retriever = db.as_retriever(search_type="similarity",
                                search_kwargs={"k": k})

    ### Contextualize question ###
    # Turn "what about him?"-style follow-ups into standalone questions so
    # retrieval works without the chat history.
    contextualize_q_system_prompt = ("""
    Given a chat history and the latest user question
    which might reference context in the chat history,
    formulate a standalone question which can be understood
    without the chat history. Do NOT answer the question,
    just reformulate it if needed and otherwise return it as is.
    """)
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm,
        retriever,
        contextualize_q_prompt,
    )

    ### Answer question ###
    # The answering prompt forbids outside knowledge: the model must quote
    # from the retrieved {context} or decline.
    system_prompt = ("""
    You are an AI assistant for question-answering tasks.
    Use the following pieces of retrieved context to answer the question.
    You must ONLY use the information provided in the context to answer the question.
    If the context does not contain relevant information to answer the question,
    respond with 'I don't have enough information to answer this question based on the provided context.'
    Do not use any external knowledge. Do not try to make up an answer.
    If you cannot quote from the context, you cannot answer the question. \n\n
    {context}
    """
    )
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm4, qa_prompt)
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    ### Statefully manage chat history ###
    # One in-memory history per session_id; lives as long as this closure.
    store = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        if session_id not in store:
            store[session_id] = ChatMessageHistory()
        return store[session_id]

    conversational_rag_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
    return conversational_rag_chain
# create a class for the app | |
class cbfs(param.Parameterized):
    """Panel chatbot frontend wrapping the conversational RAG chain.

    Holds the last query/answer/context for display in the dashboard tabs
    and manages the Panel widgets that render the conversation.
    """

    # Reactive state surfaced in the "Database" and "Chat History" tabs.
    chat_history = param.List([])
    answer = param.String("")
    db_query = param.String("")
    db_response = param.List([])

    def __init__(self, **params):
        super().__init__(**params)
        self.panels = []
        self.qa = load_db(5)  # number of chunks retrieved per query
        # Also creates the first session_id via clr_history().
        self.clr_history()

    def convchain(self, query):
        """Run one conversational turn and return the updated chat panel.

        Parameters
        ----------
        query : str
            The user's question; an empty string renders an empty row.
        """
        if not query:
            return pn.WidgetBox(pn.Row('You:', pn.pane.Markdown("", width=600)), scroll=True)
        result = self.qa.invoke({"input": query}, config={"configurable": {"session_id": self.session_id}})
        self.chat_history = result["chat_history"]
        self.db_query = result["input"]
        self.db_response = result["context"]
        self.answer = result['answer']
        self.panels.extend([
            pn.Row('You:', pn.pane.Markdown(query, width=600)),
            pn.Row('ChadGPT:', pn.pane.Markdown(self.answer, width=600, styles={'background-color': '#F6F6F6'}))
        ])
        # NOTE(review): mutates the module-level TextInput `inp` defined
        # below the class — clears it (and the loading indicator).
        inp.value = ''
        return pn.WidgetBox(*self.panels, scroll=True)

    def get_lquest(self):
        """Render the last question actually sent to the vector DB."""
        if not self.db_query:
            return pn.Column(
                pn.Row(pn.pane.Markdown("Last question to DB:", styles={'background-color': '#F6F6F6'})),
                pn.Row(pn.pane.Str("No DB accesses yet"))
            )
        return pn.Column(
            pn.Row(pn.pane.Markdown("DB query:", styles={'background-color': '#F6F6F6'})),
            pn.pane.Str(self.db_query)
        )

    def get_sources(self):
        """Render the documents returned by the last retrieval (or nothing)."""
        if not self.db_response:
            return
        rlist = [pn.Row(pn.pane.Markdown("Result of DB lookup:", styles={'background-color': '#F6F6F6'}))]
        for doc in self.db_response:
            rlist.append(pn.Row(pn.pane.Str(doc)))
        return pn.WidgetBox(*rlist, width=600, scroll=True)

    def get_chats(self):
        """Render the accumulated chat history for the current session."""
        if not self.chat_history:
            return pn.WidgetBox(pn.Row(pn.pane.Str("No history yet")), width=600, scroll=True)
        rlist = [pn.Row(pn.pane.Markdown("Current chat history:", styles={'background-color': '#F6F6F6'}))]
        for exchange in self.chat_history:
            rlist.append(pn.Row(pn.pane.Str(exchange.content)))
        return pn.WidgetBox(*rlist, width=600, scroll=True)

    def clr_history(self, count=0):
        """Reset the displayed history and start a fresh chain session.

        `count` receives the click event when used as a Button callback;
        its value is ignored.
        """
        self.chat_history = []
        # New session_id -> the RunnableWithMessageHistory store starts a
        # fresh ChatMessageHistory on the next invoke.
        self.session_id = str(uuid.uuid4())
# --- Panel application wiring (order matters: `inp` must exist before the
# --- first convchain call, which clears it) ---
pn.extension()
cb = cbfs()
# "Clear History" button resets the chat and rotates the session id.
button_clearhistory = pn.widgets.Button(name="Clear History", button_type='warning')
button_clearhistory.on_click(cb.clr_history)
inp = pn.widgets.TextInput(placeholder='Enter your question here …', width=700)
# Re-runs convchain reactively whenever the input value changes.
conversation = pn.bind(cb.convchain, inp)
image_url = "https://tradetalkspodcast.com/wp-content/uploads/2019/09/Bown_2019-08-22_email-header-medium.png"
# NOTE(review): variable is named jpg_pane but the pane (and URL) are PNG.
jpg_pane = pn.pane.PNG(object=image_url)
# Tab 1: the live conversation.
tab1 = pn.Column(
    pn.Row(inp),
    pn.layout.Divider(),
    pn.panel(conversation, loading_indicator=True, height=400),
    pn.layout.Divider(),
)
# Tab 2: what was last sent to / returned from the vector DB.
tab2 = pn.Column(
    pn.Row(pn.pane.Markdown("""DB is built off of transcripts of Trade Talks recordings. \
    Some transcripts are provided by Trade Talks. Others are transcribed
    from the recordings using OpenAI's Whisper model.""")),
    pn.layout.Divider(),
    pn.panel(cb.get_lquest),
    pn.layout.Divider(),
    pn.panel(cb.get_sources),
)
# Tab 3: full chat history for the current session.
tab3 = pn.Column(
    pn.panel(cb.get_chats),
    pn.layout.Divider(),
)
# Tab 4: reset controls.
tab4 = pn.Column(
    pn.Row(button_clearhistory, pn.pane.Markdown("Clears chat history. Use it to start a new topic.")),
    pn.layout.Divider()
)
dashboard = pn.Column(
    pn.Row(pn.pane.Markdown('# ChadGPT: Your AI Guide to Trade Talks Wisdom')),
    pn.Row(jpg_pane.clone(width=700)),
    pn.Tabs(('Conversation', tab1), ('Database', tab2), ('Chat History', tab3), ('Reset', tab4))
)
# Expose the layout to `panel serve`.
dashboard.servable()