import os
import json
import gradio as gr
import openai
from typing import Iterable
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.agents import load_tools, initialize_agent
from langchain.agents import AgentType
from langchain.tools import Tool
from langchain.utilities import GoogleSearchAPIWrapper
openai.api_key = os.environ['OPENAI_API_KEY']
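# Assumption: the GoogleSearchAPIWrapper used below also reads GOOGLE_API_KEY
# and GOOGLE_CSE_ID from the environment, per the LangChain wrapper's
# requirements; set both before enabling internet access.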


def save_docs_to_jsonl(array: Iterable[Document], file_path: str) -> None:
    with open(file_path, 'w') as jsonl_file:
        for doc in array:
            jsonl_file.write(doc.json() + '\n')


def load_docs_from_jsonl(file_path: str) -> Iterable[Document]:
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return []
    array = []
    with open(file_path, 'r') as jsonl_file:
        for line in jsonl_file:
            data = json.loads(line)
            obj = Document(**data)
            array.append(obj)
    return array
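
# A minimal round trip using the helpers above (hypothetical: `scraped_docs`
# stands for whatever Iterable[Document] the scraping step produced):
#   save_docs_to_jsonl(scraped_docs, 'striim_docs.jsonl')
#   scraped_docs = load_docs_from_jsonl('striim_docs.jsonl')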

# Load the pre-scraped Striim documentation from the local JSONL file
documents = load_docs_from_jsonl('striim_docs.jsonl')
# Split the documents into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
docs = text_splitter.split_documents(documents)
# Convert the document chunks to embeddings and save them in the vector store
vectordb = FAISS.from_documents(docs, embedding=OpenAIEmbeddings())
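# Sketch for skipping re-embedding on restart, assuming LangChain's FAISS
# save_local/load_local API (not used in this app as written):
#   vectordb.save_local('faiss_index')
#   vectordb = FAISS.load_local('faiss_index', OpenAIEmbeddings())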
# create our Q&A chain
pdf_qa = ConversationalRetrievalChain.from_llm(
ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo'),
retriever=vectordb.as_retriever(search_type="similarity", search_kwargs={'k': 4}),
return_generated_question=True,
return_source_documents=True,
verbose=False,
)
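# The chain is invoked as pdf_qa({"question": ..., "chat_history": [...]}) and
# returns a dict with "answer", plus "source_documents" and
# "generated_question" because of the flags above.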


# Function to query Google when the user enables internet access
def get_query_from_internet(user_query, temperature=0):
    # Check whether the user's question is flagged as inappropriate
    response = openai.Moderation.create(input=user_query["question"])
    moderation_output = response["results"][0]
    if moderation_output["flagged"]:
        return "Your query was flagged as inappropriate. Please try again."
    search = GoogleSearchAPIWrapper()
    tool = Tool(
        name="Google Search",
        description="Search Google for recent results.",
        func=search.run,
    )
    llm = ChatOpenAI(temperature=temperature, model_name='gpt-3.5-turbo')
    tools = load_tools(["requests_all"])
    tools += [tool]
    agent_chain = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it conforms!",
    )
    return agent_chain.run({'input': user_query})
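
# Hypothetical call, mirroring the payload the QA chain receives:
#   answer = get_query_from_internet({"question": "...", "chat_history": []})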


# Front-end web application using Gradio
CSS = """
footer.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq { display: none; }
#chatbot { height: 70vh !important;}
#submit-button { background: #00A7E5; color: white; }
#submit-button:hover { background: #00A7E5; color: white; box-shadow: 0 8px 10px 1px #9d9ea124, 0 3px 14px 2px #9d9ea11f, 0 5px 5px -3px #9d9ea133; }
"""

with gr.Blocks(theme='samayg/StriimTheme', css=CSS) as demo:
    # image = gr.Image('striim-logo-light.png', height=68, width=200, show_label=False, show_download_button=False, show_share_button=False)
    chatbot = gr.Chatbot(show_label=False, elem_id="chatbot")
    msg = gr.Textbox(label="Question:")
    user = gr.State("gradio")
    examples = gr.Examples(
        examples=[
            ["What's new in Striim version 4.2.0?"],
            ["My Striim application keeps crashing. What should I do?"],
            ["How can I improve Striim performance?"],
            ["It says could not connect to source or target. What should I do?"],
        ],
        inputs=msg,
        label="Examples",
    )
    submit = gr.Button("Submit", elem_id="submit-button")
    # with gr.Accordion(label="Advanced options", open=False):
    #     slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0, label="Temperature", info="The temperature of StriimGPT, default 0. Higher values may allow for better inference but may fabricate false information.")
    #     internet_access = gr.Checkbox(value=False, label="Allow internet access?", info="If the chatbot cannot answer your question, this setting allows internet access. Warning: this may take longer and produce inaccurate results.")

    chat_history = []

    def getResponse(query, history, userId):
        # `history` is the Chatbot component's value; conversation state is
        # tracked in the module-level chat_history list instead
        global chat_history
        # if allow_internet:
        #     # Get the response from the internet-based query function
        #     result = get_query_from_internet({"question": query, "chat_history": chat_history}, temperature=slider.value)
        #     answer = result
        # else:
        # Get the response from the QA chain
        result = pdf_qa({"question": query, "chat_history": chat_history})
        answer = result["answer"]
        # Append the user message and response to the chat history
        chat_history.append((query, answer))
        # Keep only the last 5 exchanges so the prompt stays within the token limit
        chat_history = chat_history[-5:]
        # Clear the textbox, update the chat window, and pass the userId through
        return gr.update(value=""), chat_history, userId

    # Wire both the textbox submit and the Submit button to the same handler
    msg.submit(getResponse, [msg, chatbot, user], [msg, chatbot, user], queue=False)
    submit.click(getResponse, [msg, chatbot, user], [msg, chatbot, user], queue=False)

if __name__ == "__main__":
    # demo.launch(debug=True)
    demo.launch(debug=True, share=True)