"""Gradio chatbot that answers questions about "How to Win Friends &
Influence People" via a retrieval-augmented LangChain pipeline.

The book PDF is parsed, chunked and embedded once at startup; each chat
turn runs a ConversationalRetrievalChain restricted (by prompt) to the
retrieved excerpts.
"""

import json
import os
from threading import Lock
from typing import Any, Dict, Optional, Tuple

import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

from src.core.chunking import chunk_file
from src.core.embedding import embed_files
from src.core.parsing import read_file

VECTOR_STORE = "faiss"
EMBEDDING = "openai"
# NOTE(review): the original assigned MODEL twice ("openai", then the model
# name below); the first assignment was dead code and has been removed.
MODEL = "gpt-3.5-turbo-16k"
K = 5                # number of chunks retrieved per question
USE_VERBOSE = True
API_KEY = os.environ["OPENAI_API_KEY"]  # fail fast at import if the key is absent

# Single module-level lock serializing chain invocations.  The original code
# created a fresh local Lock() inside each call, which provided no mutual
# exclusion at all.
_CHAIN_LOCK = Lock()

# System prompt: constrains the model to the retrieved context.  Typos fixed
# relative to the original ("Dail" -> "Dale", stray comma removed).
system_template = """
The context below contains excerpts from 'How to Win Friends & Influence People,' by Dale Carnegie. You must only use the information in the context below to formulate your response. If there is not enough information to formulate a response, you must respond with "I'm sorry, but I can't find the answer to your question in the book How to Win Friends & Influence People.". However, if there is enough information to formulate a response, you must start your response with "Dale says: ". Begin context: {context} End context.
{chat_history} """

# Chat prompt shared by every chain instance.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
qa_prompt = ChatPromptTemplate.from_messages(messages)


class AnswerConversationBufferMemory(ConversationBufferMemory):
    """Buffer memory that records only the chain's 'answer' output key.

    ConversationalRetrievalChain returns multiple output keys (answer plus
    source documents); plain ConversationBufferMemory cannot handle that,
    so we project the outputs down to the answer before saving.
    """

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        super().save_context(inputs, {"response": outputs["answer"]})


def getretriever():
    """Parse, chunk and embed the book PDF; return a similarity retriever.

    Raises whatever `read_file` raises: the original caught the exception,
    printed it and then used the unbound `file` variable, guaranteeing a
    confusing NameError — propagating the real error is strictly better.
    """
    book_path = "./resources/How_To_Win_Friends_And_Influence_People_-_Dale_Carnegie.pdf"
    with open(book_path, "rb") as uploaded_file:
        file = read_file(uploaded_file)
    chunked_file = chunk_file(file, chunk_size=512, chunk_overlap=0)
    folder_index = embed_files(
        files=[chunked_file],
        embedding=EMBEDDING,
        vector_store=VECTOR_STORE,
        openai_api_key=API_KEY,
    )
    return folder_index.index.as_retriever(
        verbose=True, search_type="similarity", search_kwargs={"k": K}
    )


retriever = getretriever()


def predict(message):
    """Answer a JSON-encoded request, rebuilding memory from its history.

    `message` is a JSON string: {"question": str, "history": [[user, bot], ...]}.
    Returns the chain's answer string.
    """
    msg_json = json.loads(message)

    llm = ChatOpenAI(openai_api_key=API_KEY, model_name=MODEL, verbose=True)
    memory = AnswerConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )
    # Replay the prior conversation into a fresh memory object.
    for user_msg, bot_msg in msg_json["history"]:
        memory.save_context({"input": user_msg}, {"answer": bot_msg})

    chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        return_source_documents=USE_VERBOSE,
        memory=memory,
        verbose=USE_VERBOSE,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
    # Ask the question verbatim instead of letting the chain rephrase it.
    chain.rephrase_question = False

    with _CHAIN_LOCK:
        output = chain({"question": msg_json["question"]})["answer"]
    return output


def getanswer(chain, question, history):
    """Gradio callback: run `question` through `chain` and extend `history`.

    Returns (history, history, textbox-clear-update) to feed the Chatbot,
    the State, and the input Textbox respectively.
    """
    # Inputs may arrive as Gradio wrapper objects; unwrap .value if present
    # (the suggestion Buttons pass themselves as the question).
    if hasattr(chain, "value"):
        chain = chain.value
    if hasattr(history, "value"):
        history = history.value
    if hasattr(question, "value"):
        question = question.value

    history = history or []
    with _CHAIN_LOCK:
        output = chain({"question": question})["answer"]
        history.append((question, output))
    return history, history, gr.update(value="")


def load_chain(inputs=None):
    """Build a fresh ConversationalRetrievalChain with its own memory.

    Used as the factory for the per-session `chain_state` gr.State.
    """
    llm = ChatOpenAI(openai_api_key=API_KEY, model_name=MODEL, verbose=True)
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        return_source_documents=USE_VERBOSE,
        memory=AnswerConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        ),
        verbose=USE_VERBOSE,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )


with gr.Blocks() as block:
    with gr.Row():
        with gr.Column(scale=0.75):
            with gr.Row():
                # NOTE(review): the original heading markup was garbled by
                # whitespace mangling; reconstructed as a markdown heading.
                gr.Markdown("# How to Win Friends & Influence People")
            with gr.Row():
                gr.Markdown("by Dale Carnegie")
            chatbot = gr.Chatbot(elem_id="chatbot").style(height=600)
            with gr.Row():
                message = gr.Textbox(
                    label="",
                    placeholder="How to Win Friends...",
                    lines=1,
                )
            with gr.Row():
                submit = gr.Button(value="Send", variant="primary", scale=1)
            state = gr.State()
            chain_state = gr.State(load_chain)
            submit.click(
                getanswer,
                inputs=[chain_state, message, state],
                outputs=[chatbot, state, message],
            )
            message.submit(
                getanswer,
                inputs=[chain_state, message, state],
                outputs=[chatbot, state, message],
            )
        with gr.Column(scale=0.25):
            with gr.Row():
                gr.Markdown("### Suggestions")
            # Suggestion buttons pass themselves as the "question" input;
            # getanswer unwraps their .value.
            ex1 = gr.Button(
                value="How do I know if I'm talking about myself too much?",
                variant="primary",
            )
            ex1.click(
                getanswer,
                inputs=[chain_state, ex1, state],
                outputs=[chatbot, state, message],
            )
            ex2 = gr.Button(
                value="What do people enjoy talking about the most?",
                variant="primary",
            )
            ex2.click(
                getanswer,
                inputs=[chain_state, ex2, state],
                outputs=[chatbot, state, message],
            )
            ex4 = gr.Button(
                value="Why should I try to get along with people better?",
                variant="primary",
            )
            ex4.click(
                getanswer,
                inputs=[chain_state, ex4, state],
                outputs=[chatbot, state, message],
            )
            ex5 = gr.Button(
                value="How do I cite a Reddit thread?",
                variant="primary",
            )
            ex5.click(
                getanswer,
                inputs=[chain_state, ex5, state],
                outputs=[chatbot, state, message],
            )
            # Hidden button exposing the JSON `predict` API endpoint.
            predictBtn = gr.Button(value="Predict", visible=False)
            predictBtn.click(predict, inputs=[message], outputs=[message])

block.launch(debug=True)