File size: 5,683 Bytes
0e6228a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
859a762
d0c9bed
0e6228a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a99442
b79eb10
698e298
 
0e6228a
 
3034fdd
e67f916
b79eb10
0e6228a
f21ebd0
0e6228a
e037adc
0e6228a
a4951ab
 
 
0e6228a
d3532d8
f21ebd0
3aae5df
a4951ab
0e6228a
a4951ab
 
 
0e6228a
a4951ab
 
0e6228a
 
8a40227
e80f35b
 
f21ebd0
0e6228a
 
f21ebd0
 
0e6228a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import json
import gradio as gr
import openai

from typing import Iterable
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.agents import load_tools, initialize_agent
from langchain.agents import AgentType
from langchain.tools import Tool
from langchain.utilities import GoogleSearchAPIWrapper


# Configure the OpenAI client from the environment. Indexing os.environ
# directly means a missing OPENAI_API_KEY raises KeyError at import time.
openai.api_key = os.environ['OPENAI_API_KEY']

def save_docs_to_jsonl(array: Iterable[Document], file_path: str) -> None:
    """Write documents to ``file_path`` in JSON Lines form (one per line).

    Args:
        array: Documents to serialize. Each item must provide a ``json()``
            method returning its JSON text.
        file_path: Destination path; an existing file is overwritten.
    """
    # Pin the encoding so the output does not depend on the platform's
    # default locale encoding.
    with open(file_path, 'w', encoding='utf-8') as jsonl_file:
        jsonl_file.writelines(doc.json() + '\n' for doc in array)

def load_docs_from_jsonl(file_path) -> Iterable[Document]:
    """Read documents back from a JSON Lines file.

    Every non-blank line is parsed as a JSON object whose keys are passed to
    ``Document`` as keyword arguments.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A list of documents in file order. If ``file_path`` does not exist,
        a notice is printed and an empty list is returned.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not os.path.exists(file_path):
        print("Invalid file path.")
        return []
    # Read as UTF-8 regardless of the platform default, and skip blank lines
    # (e.g. a trailing empty line) instead of failing in json.loads.
    with open(file_path, 'r', encoding='utf-8') as jsonl_file:
        return [
            Document(**json.loads(line))
            for line in jsonl_file
            if line.strip()
        ]

# Read the documentation corpus from the local JSONL dump. The loader
# returns an empty list (after printing a notice) when the file is missing.
documents = load_docs_from_jsonl('striim_docs.jsonl')

# Cut the documents into overlapping chunks (size 1000, overlap 500)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
docs = text_splitter.split_documents(documents)

# Embed every chunk with OpenAI embeddings and index them in FAISS
vectordb = FAISS.from_documents(docs, embedding=OpenAIEmbeddings())

# Build the conversational Q&A chain: a temperature-0 chat model on top of a
# similarity retriever that fetches the 4 closest chunks per question.
_qa_llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')
_qa_retriever = vectordb.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 4},
)
pdf_qa = ConversationalRetrievalChain.from_llm(
    _qa_llm,
    retriever=_qa_retriever,
    return_generated_question=True,
    return_source_documents=True,
    verbose=False,
)

# Function to query Google if user selects allow internet access
def get_query_from_internet(user_query, temperature=0):
    """Answer a question with a web-enabled agent after a moderation check.

    Args:
        user_query: Mapping whose ``"question"`` entry holds the user's text.
            The whole mapping is handed to the agent as its input.
        temperature: Sampling temperature for the agent's chat model.

    Returns:
        The agent's answer, or a fixed refusal message when the OpenAI
        moderation endpoint flags the question.
    """
    # Checking if user query is flagged as inappropriate
    response = openai.Moderation.create(input=user_query["question"])
    if response["results"][0]["flagged"]:
        return "Your query was flagged as inappropriate. Please try again."

    search = GoogleSearchAPIWrapper()
    google_tool = Tool(
        name="Google Search",
        description="Search Google for recent results.",
        func=search.run,
    )

    # Honor the caller's temperature; it was previously accepted but ignored
    # in favor of a hard-coded 0 (the default keeps the old behavior).
    llm = ChatOpenAI(temperature=temperature, model_name='gpt-3.5-turbo')
    tools = load_tools(["requests_all"]) + [google_tool]
    agent_chain = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it conforms!"
    )
    # NOTE(review): the full user_query mapping, not just the question text,
    # is passed as the agent input -- confirm this is intended.
    return agent_chain.run({'input': user_query})


# Front end web application using Gradio
# Custom stylesheet handed to gr.Blocks below: hides the footer element
# matched by the svelte class selector, makes #chatbot 70vh tall and colors
# the #submit-button (including its hover state).
CSS ="""
footer.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq { display: none; }
#chatbot { height: 70vh !important;}
#submit-button { background: #00A7E5; color: white; }
#submit-button:hover { background: #00A7E5; color: white; box-shadow: 0 8px 10px 1px #9d9ea124, 0 3px 14px 2px #9d9ea11f, 0 5px 5px -3px #9d9ea133; }
"""

with gr.Blocks(theme='samayg/StriimTheme', css=CSS) as demo:
    # image = gr.Image('striim-logo-light.png', height=68, width=200, show_label=False, show_download_button=False, show_share_button=False)
    chatbot = gr.Chatbot(show_label=False, elem_id="chatbot")
    msg = gr.Textbox(label="Question:")
    user = gr.State("gradio")
    # Per-session list of (question, answer) turns. Keeping it in gr.State
    # instead of a module-level list stops visitors from sharing (and seeing)
    # each other's conversation.
    history_state = gr.State([])
    examples = gr.Examples(
        examples=[
            ["What's new in Striim version 4.2.0?"],
            ['My Striim application keeps crashing. What should I do?'],
            ['How can I improve Striim performance?'],
            ['It says could not connect to source or target. What should I do?'],
        ],
        inputs=msg,
        label="Examples",
    )
    submit = gr.Button("Submit", elem_id="submit-button")

    #with gr.Accordion(label="Advanced options", open=False):
        #slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0, label="Temperature", info="The temperature of StriimGPT, default at 0. Higher values may allow for better inference but may fabricate false information.")
        #internet_access = gr.Checkbox(value=False, label="Allow Internet Access?", info="If the chatbot cannot answer your question, this setting allows for internet access. Warning: this may take longer and produce inaccurate results.")

    # Number of most recent turns kept, to stay within the model's token limit
    MAX_HISTORY_TURNS = 5

    def getResponse(query, history, userId, turns=None):
        """Answer ``query`` with the Q&A chain and update the conversation.

        Args:
            query: Text from the question box.
            history: Current value of the chatbot component (unused; the
                conversation is tracked in ``turns``).
            userId: Value of the ``user`` state, returned unchanged.
            turns: This session's previous (question, answer) pairs.

        Returns:
            A 4-tuple: an update clearing the textbox, the turns to display
            in the chatbot, ``userId``, and the turns to store for the
            session.
        """
        # Work on a copy so the stored session state is never mutated in place
        turns = [tuple(turn) for turn in (turns or [])]

        # Ignore blank submissions instead of sending them to the model
        if not query or not query.strip():
            return gr.update(value=""), turns, userId, turns

        # The internet-based path (get_query_from_internet) is disabled along
        # with the "Advanced options" controls above; only the QA chain runs.
        result = pdf_qa({"question": query, "chat_history": turns})
        answer = result["answer"]

        # Append user message and response, keeping only the latest turns
        turns = (turns + [(query, answer)])[-MAX_HISTORY_TURNS:]
        return gr.update(value=""), turns, userId, turns

    # Pressing Enter in the textbox and clicking Submit run the same handler
    handler_io = [msg, chatbot, user, history_state]
    msg.submit(getResponse, handler_io, handler_io, queue=False)
    submit.click(getResponse, handler_io, handler_io, queue=False)

if __name__ == "__main__":
    # Launch the app only when this file is run as a script.
    # Variant without the share flag, kept for reference:
    # demo.launch(debug=True)
    # NOTE(review): share=True presumably exposes the app through a public
    # Gradio link -- confirm that is intended for this deployment.
    demo.launch(debug=True, share=True)