github-actions
commited on
Commit
•
7873f3c
1
Parent(s):
d4d8ea9
Sync updates from source repository
Browse files- app.py +84 -5
- query.py +8 -7
- requirements.txt +2 -0
app.py
CHANGED
@@ -1,13 +1,59 @@
|
|
1 |
from omegaconf import OmegaConf
|
2 |
from query import VectaraQuery
|
3 |
import os
|
|
|
|
|
|
|
4 |
|
5 |
import streamlit as st
|
6 |
from streamlit_pills import pills
|
|
|
7 |
|
8 |
from PIL import Image
|
9 |
|
10 |
max_examples = 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def isTrue(x) -> bool:
|
13 |
if isinstance(x, bool):
|
@@ -16,11 +62,11 @@ def isTrue(x) -> bool:
|
|
16 |
|
17 |
def launch_bot():
|
18 |
def generate_response(question):
|
19 |
-
response = vq.submit_query(question)
|
20 |
return response
|
21 |
|
22 |
def generate_streaming_response(question):
|
23 |
-
response = vq.submit_query_streaming(question)
|
24 |
return response
|
25 |
|
26 |
def show_example_questions():
|
@@ -41,11 +87,13 @@ def launch_bot():
|
|
41 |
'source_data_desc': os.environ['source_data_desc'],
|
42 |
'streaming': isTrue(os.environ.get('streaming', False)),
|
43 |
'prompt_name': os.environ.get('prompt_name', None),
|
44 |
-
'examples': os.environ.get('examples', None)
|
|
|
45 |
})
|
46 |
st.session_state.cfg = cfg
|
47 |
st.session_state.ex_prompt = None
|
48 |
-
st.session_state.first_turn = True
|
|
|
49 |
example_messages = [example.strip() for example in cfg.examples.split(",")]
|
50 |
st.session_state.example_messages = [em for em in example_messages if len(em)>0][:max_examples]
|
51 |
|
@@ -60,7 +108,13 @@ def launch_bot():
|
|
60 |
image = Image.open('Vectara-logo.png')
|
61 |
st.image(image, width=175)
|
62 |
st.markdown(f"## About\n\n"
|
63 |
-
f"This demo uses Retrieval Augmented Generation to ask questions about {cfg.source_data_desc}\n
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
st.markdown("---")
|
66 |
st.markdown(
|
@@ -111,7 +165,32 @@ def launch_bot():
|
|
111 |
st.write(response)
|
112 |
message = {"role": "assistant", "content": response}
|
113 |
st.session_state.messages.append(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
if __name__ == "__main__":
|
117 |
launch_bot()
|
|
|
1 |
from omegaconf import OmegaConf
|
2 |
from query import VectaraQuery
|
3 |
import os
|
4 |
+
import requests
|
5 |
+
import json
|
6 |
+
import uuid
|
7 |
|
8 |
import streamlit as st
|
9 |
from streamlit_pills import pills
|
10 |
+
from streamlit_feedback import streamlit_feedback
|
11 |
|
12 |
from PIL import Image
|
13 |
|
14 |
max_examples = 6
|
15 |
+
languages = {'English': 'eng', 'Spanish': 'spa', 'French': 'frs', 'Chinese': 'zho', 'German': 'deu', 'Hindi': 'hin', 'Arabic': 'ara',
|
16 |
+
'Portuguese': 'por', 'Italian': 'ita', 'Japanese': 'jpn', 'Korean': 'kor', 'Russian': 'rus', 'Turkish': 'tur', 'Persian (Farsi)': 'fas',
|
17 |
+
'Vietnamese': 'vie', 'Thai': 'tha', 'Hebrew': 'heb', 'Dutch': 'nld', 'Indonesian': 'ind', 'Polish': 'pol', 'Ukrainian': 'ukr',
|
18 |
+
'Romanian': 'ron', 'Swedish': 'swe', 'Czech': 'ces', 'Greek': 'ell', 'Bengali': 'ben', 'Malay (or Malaysian)': 'msa', 'Urdu': 'urd'}
|
19 |
+
|
20 |
+
# Setup for HTTP API Calls to Amplitude Analytics
|
21 |
+
if 'device_id' not in st.session_state:
|
22 |
+
st.session_state.device_id = str(uuid.uuid4())
|
23 |
+
|
24 |
+
headers = {
|
25 |
+
'Content-Type': 'application/json',
|
26 |
+
'Accept': '*/*'
|
27 |
+
}
|
28 |
+
amp_api_key = os.getenv('AMPLITUDE_TOKEN')
|
29 |
+
|
30 |
+
def thumbs_feedback(feedback, **kwargs):
|
31 |
+
"""
|
32 |
+
Sends feedback to Amplitude Analytics
|
33 |
+
"""
|
34 |
+
data = {
|
35 |
+
"api_key": amp_api_key,
|
36 |
+
"events": [{
|
37 |
+
"device_id": st.session_state.device_id,
|
38 |
+
"event_type": "provided_feedback",
|
39 |
+
"event_properties": {
|
40 |
+
"Space Name": kwargs.get("title", "Unknown Space Name"),
|
41 |
+
"Demo Type": "chatbot",
|
42 |
+
"query": kwargs.get("prompt", "No user input"),
|
43 |
+
"response": kwargs.get("response", "No chat response"),
|
44 |
+
"feedback": feedback["score"],
|
45 |
+
"Response Language": st.session_state.language
|
46 |
+
}
|
47 |
+
}]
|
48 |
+
}
|
49 |
+
response = requests.post('https://api2.amplitude.com/2/httpapi', headers=headers, data=json.dumps(data))
|
50 |
+
if response.status_code != 200:
|
51 |
+
print(f"Request failed with status code {response.status_code}. Response Text: {response.text}")
|
52 |
+
|
53 |
+
st.session_state.feedback_key += 1
|
54 |
+
|
55 |
+
if "feedback_key" not in st.session_state:
|
56 |
+
st.session_state.feedback_key = 0
|
57 |
|
58 |
def isTrue(x) -> bool:
|
59 |
if isinstance(x, bool):
|
|
|
62 |
|
63 |
def launch_bot():
|
64 |
def generate_response(question):
|
65 |
+
response = vq.submit_query(question, languages[st.session_state.language])
|
66 |
return response
|
67 |
|
68 |
def generate_streaming_response(question):
|
69 |
+
response = vq.submit_query_streaming(question, languages[st.session_state.language])
|
70 |
return response
|
71 |
|
72 |
def show_example_questions():
|
|
|
87 |
'source_data_desc': os.environ['source_data_desc'],
|
88 |
'streaming': isTrue(os.environ.get('streaming', False)),
|
89 |
'prompt_name': os.environ.get('prompt_name', None),
|
90 |
+
'examples': os.environ.get('examples', None),
|
91 |
+
'language': 'English'
|
92 |
})
|
93 |
st.session_state.cfg = cfg
|
94 |
st.session_state.ex_prompt = None
|
95 |
+
st.session_state.first_turn = True
|
96 |
+
st.session_state.language = cfg.language
|
97 |
example_messages = [example.strip() for example in cfg.examples.split(",")]
|
98 |
st.session_state.example_messages = [em for em in example_messages if len(em)>0][:max_examples]
|
99 |
|
|
|
108 |
image = Image.open('Vectara-logo.png')
|
109 |
st.image(image, width=175)
|
110 |
st.markdown(f"## About\n\n"
|
111 |
+
f"This demo uses Retrieval Augmented Generation to ask questions about {cfg.source_data_desc}\n")
|
112 |
+
|
113 |
+
cfg.language = st.selectbox('Language:', languages.keys())
|
114 |
+
if st.session_state.language != cfg.language:
|
115 |
+
st.session_state.language = cfg.language
|
116 |
+
print(f"DEBUG: Language changed to {st.session_state.language}")
|
117 |
+
st.rerun()
|
118 |
|
119 |
st.markdown("---")
|
120 |
st.markdown(
|
|
|
165 |
st.write(response)
|
166 |
message = {"role": "assistant", "content": response}
|
167 |
st.session_state.messages.append(message)
|
168 |
+
|
169 |
+
# Send query and response to Amplitude Analytics
|
170 |
+
data = {
|
171 |
+
"api_key": amp_api_key,
|
172 |
+
"events": [{
|
173 |
+
"device_id": st.session_state.device_id,
|
174 |
+
"event_type": "submitted_query",
|
175 |
+
"event_properties": {
|
176 |
+
"Space Name": cfg["title"],
|
177 |
+
"Demo Type": "chatbot",
|
178 |
+
"query": st.session_state.messages[-2]["content"],
|
179 |
+
"response": st.session_state.messages[-1]["content"],
|
180 |
+
"Response Language": st.session_state.language
|
181 |
+
}
|
182 |
+
}]
|
183 |
+
}
|
184 |
+
response = requests.post('https://api2.amplitude.com/2/httpapi', headers=headers, data=json.dumps(data))
|
185 |
+
if response.status_code != 200:
|
186 |
+
print(f"Amplitude request failed with status code {response.status_code}. Response Text: {response.text}")
|
187 |
st.rerun()
|
188 |
+
|
189 |
+
if (st.session_state.messages[-1]["role"] == "assistant") & (st.session_state.messages[-1]["content"] != "How may I help you?"):
|
190 |
+
streamlit_feedback(feedback_type="thumbs", on_submit = thumbs_feedback, key = st.session_state.feedback_key,
|
191 |
+
kwargs = {"prompt": st.session_state.messages[-2]["content"],
|
192 |
+
"response": st.session_state.messages[-1]["content"],
|
193 |
+
"title": cfg["title"]})
|
194 |
|
195 |
if __name__ == "__main__":
|
196 |
launch_bot()
|
query.py
CHANGED
@@ -10,7 +10,7 @@ class VectaraQuery():
|
|
10 |
self.conv_id = None
|
11 |
|
12 |
|
13 |
-
def get_body(self, query_str: str, stream: False):
|
14 |
corpora_list = [{
|
15 |
'corpus_key': corpus_key, 'lexical_interpolation': 0.005
|
16 |
} for corpus_key in self.corpus_keys
|
@@ -40,11 +40,12 @@ class VectaraQuery():
|
|
40 |
{
|
41 |
'prompt_name': self.prompt_name,
|
42 |
'max_used_search_results': 10,
|
43 |
-
'response_language':
|
44 |
'citations':
|
45 |
{
|
46 |
'style': 'none'
|
47 |
-
}
|
|
|
48 |
},
|
49 |
'chat':
|
50 |
{
|
@@ -70,14 +71,14 @@ class VectaraQuery():
|
|
70 |
"grpc-timeout": "60S"
|
71 |
}
|
72 |
|
73 |
-
def submit_query(self, query_str: str):
|
74 |
|
75 |
if self.conv_id:
|
76 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
77 |
else:
|
78 |
endpoint = "https://api.vectara.io/v2/chats"
|
79 |
|
80 |
-
body = self.get_body(query_str, stream=False)
|
81 |
|
82 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
|
83 |
|
@@ -96,14 +97,14 @@ class VectaraQuery():
|
|
96 |
|
97 |
return summary
|
98 |
|
99 |
-
def submit_query_streaming(self, query_str: str):
|
100 |
|
101 |
if self.conv_id:
|
102 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
103 |
else:
|
104 |
endpoint = "https://api.vectara.io/v2/chats"
|
105 |
|
106 |
-
body = self.get_body(query_str, stream=True)
|
107 |
|
108 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_stream_headers(), stream=True)
|
109 |
|
|
|
10 |
self.conv_id = None
|
11 |
|
12 |
|
13 |
+
def get_body(self, query_str: str, response_lang: str, stream: False):
|
14 |
corpora_list = [{
|
15 |
'corpus_key': corpus_key, 'lexical_interpolation': 0.005
|
16 |
} for corpus_key in self.corpus_keys
|
|
|
40 |
{
|
41 |
'prompt_name': self.prompt_name,
|
42 |
'max_used_search_results': 10,
|
43 |
+
'response_language': response_lang,
|
44 |
'citations':
|
45 |
{
|
46 |
'style': 'none'
|
47 |
+
},
|
48 |
+
'enable_factual_consistency_score': False
|
49 |
},
|
50 |
'chat':
|
51 |
{
|
|
|
71 |
"grpc-timeout": "60S"
|
72 |
}
|
73 |
|
74 |
+
def submit_query(self, query_str: str, language: str):
|
75 |
|
76 |
if self.conv_id:
|
77 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
78 |
else:
|
79 |
endpoint = "https://api.vectara.io/v2/chats"
|
80 |
|
81 |
+
body = self.get_body(query_str, language, stream=False)
|
82 |
|
83 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
|
84 |
|
|
|
97 |
|
98 |
return summary
|
99 |
|
100 |
+
def submit_query_streaming(self, query_str: str, language: str):
|
101 |
|
102 |
if self.conv_id:
|
103 |
endpoint = f"https://api.vectara.io/v2/chats/{self.conv_id}/turns"
|
104 |
else:
|
105 |
endpoint = "https://api.vectara.io/v2/chats"
|
106 |
|
107 |
+
body = self.get_body(query_str, language, stream=True)
|
108 |
|
109 |
response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_stream_headers(), stream=True)
|
110 |
|
requirements.txt
CHANGED
@@ -3,3 +3,5 @@ toml==0.10.2
|
|
3 |
omegaconf==2.3.0
|
4 |
syrupy==4.0.8
|
5 |
streamlit_pills==0.3.0
|
|
|
|
|
|
3 |
omegaconf==2.3.0
|
4 |
syrupy==4.0.8
|
5 |
streamlit_pills==0.3.0
|
6 |
+
streamlit-feedback==0.1.3
|
7 |
+
uuid==1.30
|