# web-chat / app.py
# Bryan Lincoln
# fix: update configs
# 1a92f9a
import gradio as gr
from langchain.chains import (
ConversationalRetrievalChain,
LLMChain,
MapReduceDocumentsChain,
ReduceDocumentsChain,
StuffDocumentsChain,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import WebBaseLoader
def wait_for_summarization(url):
    """Build the placeholder chat history shown while a page is summarized.

    Args:
        url: Address of the page about to be loaded; interpolated into the
            message text.

    Returns:
        A one-element list containing a ``(None, message)`` tuple, i.e. an
        entry whose first slot is empty and whose second slot is the notice.
    """
    notice = f"Please wait while I summarize the contents of {url}..."
    return [(None, notice)]
def load_page(url, api_key, history):
    """Load a web page and append a map-reduce summary of it to the history.

    Side effects: rebinds the module-level globals ``docs`` (the loaded
    documents), ``llm`` (the chat model) and ``summary`` (the final summary
    text), which ``prepare_chat`` reads afterwards.

    Args:
        url: Address of the page to load.
        api_key: OpenAI API key handed to the chat model.
        history: Current chat history; it is not mutated.

    Returns:
        A new list made of ``history`` followed by ``(None, summary)``.
    """
    global docs, summary, llm

    docs = WebBaseLoader(url).load()
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo-1106",
        temperature=0,
        openai_api_key=api_key,
    )

    # Map step: ask the model for the main themes of each chunk.
    theme_prompt = PromptTemplate.from_template(
        "The following is a set of snippets from a web page:\n"
        "{docs}\n"
        "Based on this list of snippets, please identify the main themes\n"
        "Helpful Answer:"
    )
    per_chunk_chain = LLMChain(llm=llm, prompt=theme_prompt)

    # Reduce step: merge the per-chunk themes into one consolidated summary.
    consolidate_prompt = PromptTemplate.from_template(
        "The following is set of summaries of a web page:\n"
        "{docs}\n"
        "Take these and distill it into a final, consolidated summary of the main themes.\n"
        "Helpful Answer:"
    )
    consolidate_chain = LLMChain(llm=llm, prompt=consolidate_prompt)

    # Joins a list of documents into a single string and feeds it to the
    # reduce LLMChain through the "docs" variable.
    stuff_chain = StuffDocumentsChain(
        llm_chain=consolidate_chain,
        document_variable_name="docs",
    )

    # The same stuffing chain serves both as the final combiner and as the
    # collapse step used when the documents exceed the token budget.
    reducer = ReduceDocumentsChain(
        combine_documents_chain=stuff_chain,
        collapse_documents_chain=stuff_chain,
        token_max=4000,  # maximum number of tokens to group documents into
    )

    summarizer = MapReduceDocumentsChain(
        llm_chain=per_chunk_chain,
        reduce_documents_chain=reducer,
        document_variable_name="docs",
        return_intermediate_steps=False,  # only the final output is returned
    )

    chunker = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000,
        chunk_overlap=0,
    )
    chunks = chunker.split_documents(docs)
    summary = summarizer.run(chunks)

    summary_entry = (None, summary)
    return history + [summary_entry]
def prepare_chat(api_key, history):
    """Index the loaded page and build the conversational QA chain.

    Reads the module-level globals ``docs``, ``summary`` and ``llm`` (set by
    ``load_page``) and rebinds the global ``qa`` used by ``chatbot_function``.

    Args:
        api_key: OpenAI API key used for the embedding model.
        history: Current chat history; it is not mutated.

    Returns:
        A new list made of ``history`` followed by a ``(None, message)`` entry
        telling the user that questions can now be asked.
    """
    global docs, summary, llm, qa

    import uuid  # local import: only needed for the per-call collection name

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=128)
    documents = text_splitter.split_documents(docs)
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)

    # Use a fresh collection for every page instead of the default name.
    # NOTE(review): in-memory Chroma clients within one process appear to share
    # collections, so reusing the default name would mix chunks from previously
    # loaded pages into the retriever - confirm against the installed chromadb.
    vectorstore = Chroma.from_documents(
        documents,
        embeddings,
        collection_name=f"web-chat-{uuid.uuid4().hex}",
    )
    retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": 6}
    )

    # The summary is model-generated text that gets spliced into a template
    # parsed with str.format-style placeholders. Any literal brace in it (code
    # snippets, JSON, ...) would be treated as a placeholder and break prompt
    # formatting, so braces are doubled to make them literal.
    escaped_summary = summary.replace("{", "{{").replace("}", "}}")

    qa_prompt_template = (
        "As an AI assistant you help in answering questions about the contents of a web page.\n"
        "The summary of the current web page is this:\n"
        + escaped_summary
        + "\n"
        "Also, consider this additional context that may be relevant for the user's question:\n"
        "{context}\n"
        "Please answer following question: {question}"
    )
    qa_prompt = PromptTemplate(
        template=qa_prompt_template, input_variables=["context", "question"]
    )

    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
    return history + [(None, "You can now ask me specific questions about the page.")]
def chatbot_function(message, history):
    """Answer a user message with the QA chain and extend the chat history.

    Reads the module-level ``qa`` chain created by ``prepare_chat``.

    Args:
        message: Text submitted by the user.
        history: Current chat history; it is not mutated.

    Returns:
        A ``(text, history)`` pair. ``text`` is always the empty string (the
        first output is wired to the input text box in ``build_demo``).
        ``history`` is the input history followed by ``(message, answer)``, or
        the unchanged input history when ``message`` is empty or only
        whitespace.
    """
    # Nothing to ask: skip the chain call instead of spending an API request
    # and appending an empty exchange to the conversation.
    if not message or not message.strip():
        return "", history

    answer = qa.run(message)
    return "", history + [(message, answer)]
def build_demo():
    """Assemble the Gradio Blocks interface and wire its events.

    The UI has a configuration row (API key, URL, "Load" button) and an
    initially hidden chat row. Clicking "Load" hides the configuration row,
    shows the chat row, posts a waiting message, summarizes the page, prepares
    the QA chain and finally reveals the message inputs.

    Returns:
        The configured ``gr.Blocks`` instance (not launched).
    """
    # NOTE(review): the container nesting below was reconstructed from a copy
    # of this file whose indentation was lost - confirm the intended layout.
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        with gr.Row() as config_row:
            with gr.Column():
                api_key_box = gr.Textbox(
                    show_label=False,
                    placeholder="OpenAI API Key",
                    container=False,
                    autofocus=True,
                )
                url_box = gr.Textbox(
                    show_label=False,
                    placeholder="URL",
                    container=False,
                )
                load_btn = gr.Button(value="Load", variant="primary")

        with gr.Row(visible=False) as chat_row:
            with gr.Column():
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        label="Web Chat",
                        height=550,
                    )
                with gr.Row(visible=False) as inputs_row:
                    with gr.Column(scale=8):
                        text_box = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press ENTER",
                            autofocus=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(
                            value="Send",
                            variant="primary",
                        )

        load_btn.click(
            lambda: gr.update(visible=False),
            outputs=[config_row],
        ).then(
            lambda: gr.update(visible=True),
            outputs=[chat_row],
        ).then(
            wait_for_summarization,
            inputs=[url_box],
            outputs=[chatbot],
        ).then(
            load_page,
            inputs=[url_box, api_key_box, chatbot],
            outputs=[chatbot],
        ).success(
            # Only continue when the page was loaded and summarized: a plain
            # .then() would run prepare_chat even after load_page raised.
            prepare_chat,
            inputs=[api_key_box, chatbot],
            outputs=[chatbot],
        ).success(
            # Reveal the message inputs only once the QA chain exists, so the
            # user cannot send messages to an uninitialized chain.
            lambda: gr.update(visible=True),
            outputs=[inputs_row],
        )

        # Pressing ENTER in the text box and clicking "Send" do the same thing.
        text_box.submit(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
        submit_btn.click(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
    return demo
if __name__ == "__main__":
    # Build the interface and start serving it only when this file is run as a
    # script; importing the module just defines the functions above.
    demo = build_demo()
    demo.launch()