Everyamans-ai commited on
Commit
b667e96
1 Parent(s): 1d232d0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from langchain.agents import load_tools
4
+ from langchain.agents import initialize_agent
5
+ from langchain import PromptTemplate, HuggingFaceHub, LLMChain, ConversationChain
6
+
7
+ from langchain.llms import OpenAI
8
+ from langchain.chains.conversation.memory import ConversationBufferMemory
9
+ from threading import Lock
10
+ import openai
11
+
12
+ from openai.error import AuthenticationError, InvalidRequestError, RateLimitError
13
+ from typing import Optional, Tuple
14
+
15
+ TOOLS_DEFAULT_LIST = ['serpapi', 'news-api', 'pal-math']
16
+ MAX_TOKENS = 512
17
+ PROMPT_TEMPLATE = PromptTemplate(
18
+ input_variables=["original_words"],
19
+ template="Restate the following: \n{original_words}\n",
20
+ )
21
+
22
+ BUG_FOUND_MSG = "Congratulations, you've found a bug in this application!"
23
+ AUTH_ERR_MSG = "Please paste your OpenAI key."
24
+
25
+ news_api_key = os.environ["NEWS_API_KEY"]
26
+
27
+ def run_chain(chain, inp, capture_hidden_text):
28
+ output = ""
29
+ hidden_text = None
30
+ try:
31
+ output = chain.run(input=inp)
32
+ except AuthenticationError as ae:
33
+ output = AUTH_ERR_MSG
34
+ except RateLimitError as rle:
35
+ output = "\n\nRateLimitError: " + str(rle)
36
+ except ValueError as ve:
37
+ output = "\n\nValueError: " + str(ve)
38
+ except InvalidRequestError as ire:
39
+ output = "\n\nInvalidRequestError: " + str(ire)
40
+ except Exception as e:
41
+ output = "\n\n" + BUG_FOUND_MSG + ":\n\n" + str(e)
42
+
43
+ return output, hidden_text
44
+
45
+ def transform_text(desc, express_chain):
46
+
47
+ formatted_prompt = PROMPT_TEMPLATE.format(
48
+ original_words=desc
49
+ )
50
+ generated_text = desc
51
+
52
+ # replace all newlines with <br> in generated_text
53
+ generated_text = generated_text.replace("\n", "\n\n")
54
+
55
+ return generated_text
56
+
57
+ class ChatWrapper:
58
+
59
+ def __init__(self):
60
+ self.lock = Lock()
61
+
62
+ def __call__(
63
+ self, api_key: str, inp: str, history: Optional[Tuple[str, str]], chain: Optional[ConversationChain], express_chain: Optional[LLMChain]):
64
+ """Execute the chat functionality."""
65
+ self.lock.acquire()
66
+ try:
67
+ history = history or []
68
+ # If chain is None, that is because no API key was provided.
69
+ output = "Please paste your OpenAI key to use this application."
70
+ hidden_text = output
71
+
72
+ if chain and chain != "":
73
+ # Set OpenAI key
74
+ openai.api_key = api_key
75
+ output, hidden_text = run_chain(chain, inp, capture_hidden_text=False)
76
+ print('output1', output)
77
+
78
+ output = transform_text(output, express_chain)
79
+ print('output2', output)
80
+ text_to_display = output
81
+ history.append((inp, text_to_display))
82
+
83
+ except Exception as e:
84
+ raise e
85
+ finally:
86
+ self.lock.release()
87
+ # return history, history, html_video, temp_file, ""
88
+ return history, history
89
+
90
+
91
+ chat = ChatWrapper()
92
+
93
+ def load_chain(tools_list, llm):
94
+ chain = None
95
+ express_chain = None
96
+ print("\ntools_list", tools_list)
97
+ tool_names = tools_list
98
+ tools = load_tools(tool_names, llm=llm, news_api_key=news_api_key)
99
+
100
+ memory = ConversationBufferMemory(memory_key="chat_history")
101
+
102
+ chain = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=True, memory=memory)
103
+ express_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
104
+ return chain, express_chain
105
+
106
+
107
+ def set_openai_api_key(api_key):
108
+ """Set the api key and return chain.
109
+ If no api_key, then None is returned.
110
+ """
111
+
112
+ os.environ["OPENAI_API_KEY"] = api_key
113
+ llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
114
+ chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
115
+ os.environ["OPENAI_API_KEY"] = ""
116
+ return chain, express_chain, llm
117
+
118
+ with gr.Blocks() as app:
119
+ llm_state = gr.State()
120
+ history_state = gr.State()
121
+ chain_state = gr.State()
122
+ express_chain_state = gr.State()
123
+
124
+ with gr.Row():
125
+ with gr.Column():
126
+ gr.HTML(
127
+ """<b><center>GPT + Google</center></b>""")
128
+
129
+ openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...)",
130
+ show_label=False, lines=1, type='password')
131
+ with gr.Row():
132
+
133
+ with gr.Column(scale=3):
134
+ chatbot = gr.Chatbot()
135
+ with gr.Row():
136
+ message = gr.Textbox(label="What's on your mind??",
137
+ placeholder="What's the answer to life, the universe, and everything?",
138
+ lines=1)
139
+ submit = gr.Button(value="Send", variant="secondary").style(full_width=False)
140
+
141
+ gr.Examples(
142
+ examples=["How many people live in Canada?",
143
+ "What is 2 to the 30th power?",
144
+ "If x+y=10 and x-y=4, what are x and y?",
145
+ "How much did it rain in SF today?",
146
+ "Get me information about the movie 'Avatar'",
147
+ "What are the top tech headlines in the US?",
148
+ "On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses - "
149
+ "if I remove all the pairs of sunglasses from the desk, how many purple items remain on it?"],
150
+ inputs=message
151
+ )
152
+ message.submit(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
153
+ express_chain_state], outputs=[chatbot, history_state])
154
+
155
+ submit.click(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state,
156
+ express_chain_state], outputs=[chatbot, history_state])
157
+
158
+ openai_api_key_textbox.change(set_openai_api_key,
159
+ inputs=[openai_api_key_textbox],
160
+ outputs=[chain_state, express_chain_state, llm_state])
161
+
162
+ app.launch(debug=True)