Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from langchain.chains import ( | |
| ConversationalRetrievalChain, | |
| LLMChain, | |
| MapReduceDocumentsChain, | |
| ReduceDocumentsChain, | |
| StuffDocumentsChain, | |
| ) | |
| from langchain.embeddings import OpenAIEmbeddings | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.prompts import PromptTemplate | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.vectorstores import Chroma | |
| from langchain_community.chat_models import ChatOpenAI | |
| from langchain_community.document_loaders import WebBaseLoader | |
def wait_for_summarization(url):
    """Build the placeholder chat history shown while ``url`` is summarized.

    The history holds one bot-only message (the user side is ``None``) telling
    the user to wait while the page contents are processed.
    """
    notice = f"Please wait while I summarize the contents of {url}..."
    return [(None, notice)]
def _build_summarize_chain(llm):
    """Return a map-reduce chain that summarizes a list of page chunks.

    Map step: each chunk goes through a prompt asking for its main themes.
    Reduce step: the per-chunk answers are stuffed into a second prompt that
    distills them into one consolidated summary. If the mapped answers do not
    fit in ``token_max`` tokens, they are first collapsed with the same reduce
    prompt.
    """
    map_prompt = PromptTemplate.from_template(
        "The following is a set of snippets from a web page:\n"
        "{docs}\n"
        "Based on this list of snippets, please identify the main themes\n"
        "Helpful Answer:"
    )
    map_chain = LLMChain(llm=llm, prompt=map_prompt)

    reduce_prompt = PromptTemplate.from_template(
        "The following is set of summaries of a web page:\n"
        "{docs}\n"
        "Take these and distill it into a final, consolidated summary of the main themes.\n"
        "Helpful Answer:"
    )
    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

    # Joins a list of documents into the reduce prompt's {docs} slot.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    reduce_documents_chain = ReduceDocumentsChain(
        # Final chain that produces the summary.
        combine_documents_chain=combine_documents_chain,
        # Used when the mapped documents exceed the stuff chain's context.
        collapse_documents_chain=combine_documents_chain,
        # Maximum number of tokens to group documents into.
        token_max=4000,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        # Variable name in the map prompt that receives each chunk.
        document_variable_name="docs",
        return_intermediate_steps=False,
    )


def load_page(url, api_key, history):
    """Fetch ``url``, summarize it with the LLM and append the summary to the chat.

    Side effects: sets the module globals ``docs`` (the loaded documents),
    ``llm`` (the chat model) and ``summary`` (the text appended to the chat),
    which ``prepare_chat`` later reads.

    Args:
        url: Address of the web page to load.
        api_key: OpenAI API key used for the chat model.
        history: Current Gradio chatbot history (list of (user, bot) pairs).

    Returns:
        ``history`` with a bot-only ``(None, summary)`` message appended.
    """
    global docs, summary, llm
    docs = WebBaseLoader(url).load()
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo-1106", temperature=0, openai_api_key=api_key
    )

    # Chunk size is measured with the tiktoken encoder, i.e. in tokens.
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0
    )
    split_docs = text_splitter.split_documents(docs)
    if not split_docs:
        # The page yielded no extractable text: report that instead of making a
        # pointless LLM call over an empty input.
        summary = f"I could not find any text to summarize at {url}."
        return history + [(None, summary)]

    summary = _build_summarize_chain(llm).run(split_docs)
    return history + [(None, summary)]
def prepare_chat(api_key, history):
    """Index the loaded page and build the conversational QA chain.

    Reads the module globals ``docs``, ``summary`` and ``llm`` (set by
    ``load_page``) and stores the resulting chain in the global ``qa``.

    Args:
        api_key: OpenAI API key used for the embeddings.
        history: Current Gradio chatbot history (list of (user, bot) pairs).

    Returns:
        ``history`` with a bot-only message saying questions can now be asked.
    """
    global docs, summary, llm, qa
    import uuid  # local import keeps this block self-contained

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=128)
    documents = text_splitter.split_documents(docs)
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    # Use a fresh collection per page. Without an explicit name every call
    # writes into the same default collection of the in-process Chroma client,
    # so chunks from previously loaded pages would leak into retrieval.
    vectorstore = Chroma.from_documents(
        documents,
        embeddings,
        collection_name=f"webpage-{uuid.uuid4().hex}",
    )
    retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": 6}
    )

    # The summary is LLM output and may contain literal braces (e.g. for pages
    # about code or JSON). PromptTemplate parses "{...}" as template variables,
    # so escape them to keep the summary as plain text.
    escaped_summary = summary.replace("{", "{{").replace("}", "}}")
    qa_prompt_template = (
        "As an AI assistant you help in answering questions about the contents of a web page.\n"
        "The summary of the current web page is this:\n"
        + escaped_summary
        + "\n"
        "Also, consider this additional context that may be relevant for the user's question:\n"
        "{context}\n"
        "Please answer following question: {question}"
    )
    qa_prompt = PromptTemplate(
        template=qa_prompt_template, input_variables=["context", "question"]
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
    return history + [(None, "You can now ask me specific questions about the page.")]
def chatbot_function(message, history):
    """Answer ``message`` with the global QA chain and append the exchange.

    Args:
        message: The user's question from the input textbox.
        history: Current Gradio chatbot history (list of (user, bot) pairs).

    Returns:
        A ``("", new_history)`` tuple. The empty string clears the textbox, and
        ``new_history`` is ``history`` plus the new ``(message, answer)`` pair.
        Blank or whitespace-only messages are ignored: ``history`` is returned
        unchanged and the chain is not called.
    """
    global qa
    if not message or not message.strip():
        # Avoid a retrieval + LLM round-trip (and an empty turn in the chain's
        # memory) for an accidental empty submit.
        return "", history
    answer = qa.run(message)
    return "", history + [(message, answer)]
def build_demo():
    """Build the Gradio Blocks UI for chatting with a web page.

    Layout: a config row (API key, URL, Load button) that is hidden when Load
    is clicked, and an initially hidden chat row. The chat row's input row is
    revealed only after the page has been summarized and indexed.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        with gr.Row() as config_row:
            with gr.Column():
                api_key_box = gr.Textbox(
                    show_label=False,
                    placeholder="OpenAI API Key",
                    container=False,
                    autofocus=True,
                    # Mask the secret key instead of showing it in clear text.
                    type="password",
                )
                url_box = gr.Textbox(
                    show_label=False,
                    placeholder="URL",
                    container=False,
                )
                load_btn = gr.Button(value="Load", variant="primary")
        with gr.Row(visible=False) as chat_row:
            with gr.Column():
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        label="Web Chat",
                        height=550,
                    )
                with gr.Row(visible=False) as inputs_row:
                    with gr.Column(scale=8):
                        text_box = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press ENTER",
                            autofocus=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(
                            value="Send",
                            variant="primary",
                        )

        # `.then()` runs even if the previous step raised; `.success()` only runs
        # after the previous step completed without error. The indexing and the
        # reveal of the question box depend on the load step having worked, so
        # they use `.success()` to avoid indexing stale/undefined globals after a
        # failed load (bad URL, bad API key).
        load_btn.click(
            lambda: gr.update(visible=False),
            outputs=[config_row],
        ).then(
            lambda: gr.update(visible=True),
            outputs=[chat_row],
        ).then(
            wait_for_summarization,
            inputs=[url_box],
            outputs=[chatbot],
        ).then(
            load_page,
            inputs=[url_box, api_key_box, chatbot],
            outputs=[chatbot],
        ).success(
            prepare_chat,
            inputs=[api_key_box, chatbot],
            outputs=[chatbot],
        ).success(
            lambda: gr.update(visible=True),
            outputs=[inputs_row],
        )
        text_box.submit(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
        submit_btn.click(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Blocks app, then start the Gradio server.
    # The app is kept in a variable named `demo`, the name Gradio tooling
    # conventionally looks for.
    demo = build_demo()
    demo.launch()