# StriimGPT — Gradio chat app answering questions over Striim documentation.
# (Hugging Face Spaces page banner removed; the Space was showing "Runtime error".)
import json
import os
from typing import Iterable

import gradio as gr
import openai
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import Tool
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.vectorstores import FAISS
# Authenticate the OpenAI SDK from the environment. A missing variable raises
# KeyError here, so a misconfigured deployment fails fast at startup.
openai.api_key = os.environ["OPENAI_API_KEY"]
def save_docs_to_jsonl(array: Iterable[Document], file_path: str) -> None:
    """Serialize documents to a JSON-Lines file, one JSON object per line.

    Args:
        array: Documents to persist; each must expose a ``.json()`` method
            returning its JSON representation as a string.
        file_path: Destination path; an existing file is overwritten.
    """
    # Explicit utf-8 so the cache file is portable across platforms/locales
    # instead of depending on the process's locale encoding.
    with open(file_path, 'w', encoding='utf-8') as jsonl_file:
        for doc in array:
            jsonl_file.write(doc.json() + '\n')
def load_docs_from_jsonl(file_path: str) -> Iterable[Document]:
    """Deserialize documents from a JSON-Lines file written by save_docs_to_jsonl.

    Args:
        file_path: Path to the ``.jsonl`` cache file.

    Returns:
        A list of reconstructed ``Document`` objects. When the file does not
        exist, a notice is printed and an empty list is returned (rather than
        raising) so the pipeline can continue with no local cache.
    """
    if not os.path.exists(file_path):
        print("Invalid file path.")
        return []
    docs = []
    # utf-8 matches the encoding the writer uses; one JSON object per line,
    # each holding the keyword fields of one Document.
    with open(file_path, 'r', encoding='utf-8') as jsonl_file:
        for line in jsonl_file:
            docs.append(Document(**json.loads(line)))
    return docs
# ---------------------------------------------------------------------------
# Retrieval pipeline: load the cached Striim docs, chunk them, embed them into
# an in-memory FAISS index, and build a conversational QA chain on top.
# ---------------------------------------------------------------------------
documents = load_docs_from_jsonl('striim_docs.jsonl')

# Overlapping 1000-char windows (500-char overlap) so answers that span chunk
# boundaries still retrieve enough surrounding context.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
docs = text_splitter.split_documents(documents)

# Embed every chunk once and keep the vectors in a FAISS similarity index.
vectordb = FAISS.from_documents(docs, embedding=OpenAIEmbeddings())

# QA chain: the 4 most similar chunks are fed to gpt-3.5-turbo together with
# the running chat history; source documents are returned for traceability.
pdf_qa = ConversationalRetrievalChain.from_llm(
    ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo'),
    retriever=vectordb.as_retriever(search_type="similarity", search_kwargs={'k': 4}),
    return_generated_question=True,
    return_source_documents=True,
    verbose=False,
)
# Function to query Google if user selects allow internet access
def get_query_from_internet(user_query, temperature=0):
    """Answer a query with live web access via a ReAct agent.

    The query text is first screened by the OpenAI moderation endpoint;
    flagged input is rejected before any search or LLM call is made.

    Args:
        user_query: Mapping with at least a "question" key holding the user's
            text; the full mapping is forwarded to the agent as its input.
        temperature: Sampling temperature for the agent's LLM (0 keeps the
            answers deterministic).

    Returns:
        The agent's final answer string, or a refusal message when the input
        is flagged as inappropriate.
    """
    # Checking if user query is flagged as inappropriate
    response = openai.Moderation.create(input=user_query["question"])
    moderation_output = response["results"][0]
    if moderation_output["flagged"]:
        return "Your query was flagged as inappropriate. Please try again."

    # Google search plus generic HTTP request tools for the agent to use.
    search = GoogleSearchAPIWrapper()
    search_tool = Tool(
        name="Google Search",
        description="Search Google for recent results.",
        func=search.run,
    )
    tools = load_tools(["requests_all"]) + [search_tool]

    # Bug fix: `temperature` was previously accepted but ignored (the LLM was
    # hardcoded to 0); it now drives the agent's LLM as the signature implies.
    llm = ChatOpenAI(temperature=temperature, model_name='gpt-3.5-turbo')
    agent_chain = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it conforms!",
    )
    return agent_chain.run({'input': user_query})
# Front end web application using Gradio.
# Custom CSS: hide Gradio's footer, size the chat pane to most of the
# viewport, and brand the submit button with Striim's blue.
CSS = """
footer.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq { display: none; }
#chatbot { height: 70vh !important;}
#submit-button { background: #00A7E5; color: white; }
#submit-button:hover { background: #00A7E5; color: white; box-shadow: 0 8px 10px 1px #9d9ea124, 0 3px 14px 2px #9d9ea11f, 0 5px 5px -3px #9d9ea133; }
"""
with gr.Blocks(theme='samayg/StriimTheme', css=CSS) as demo:
    # image = gr.Image('striim-logo-light.png', height=68, width=200, show_label=False, show_download_button=False, show_share_button=False)
    chatbot = gr.Chatbot(show_label=False, elem_id="chatbot")
    msg = gr.Textbox(label="Question:")
    user = gr.State("gradio")
    examples = gr.Examples(
        examples=[
            ['What\'s new in Striim version 4.2.0?'],
            ['My Striim application keeps crashing. What should I do?'],
            ['How can I improve Striim performance?'],
            ['It says could not connect to source or target. What should I do?'],
        ],
        inputs=msg,
        label="Examples",
    )
    submit = gr.Button("Submit", elem_id="submit-button")
    # NOTE(review): an internet-access toggle routing through
    # get_query_from_internet() was sketched here but never wired in.

    def getResponse(query, history, userId):
        """Answer one chat turn with the retrieval QA chain.

        Args:
            query: The user's question from the textbox.
            history: The Chatbot component's current value — this session's
                list of (user, assistant) message pairs.
            userId: Opaque per-session state, passed through unchanged.

        Returns:
            A (textbox update, new chat history, userId) tuple: the textbox is
            cleared, the chatbot shows the appended exchange.
        """
        # Bug fix: a module-level `chat_history` global was previously shared
        # by every browser session, leaking conversations between users. The
        # Chatbot component's own value is the correct per-session state.
        chat_history = [tuple(pair) for pair in (history or [])]
        # Get response from QA chain.
        result = pdf_qa({"question": query, "chat_history": chat_history})
        answer = result["answer"]
        # Append user message and response to chat history.
        chat_history.append((query, answer))
        # Only keeps last 5 messages to not exceed tokens.
        chat_history = chat_history[-5:]
        return gr.update(value=""), chat_history, userId

    msg.submit(getResponse, [msg, chatbot, user], [msg, chatbot, user], queue=False)
    submit.click(getResponse, [msg, chatbot, user], [msg, chatbot, user], queue=False)
if __name__ == "__main__":
    # debug=True surfaces tracebacks in the console; share=True additionally
    # publishes a temporary public Gradio link alongside the local server.
    demo.launch(debug=True, share=True)