mtyrrell commited on
Commit
4b85076
1 Parent(s): d4a78d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -7
app.py CHANGED
@@ -144,17 +144,31 @@ def run_query(input_text, country, model_sel):
144
  docs = get_docs(input_text, country=country,vulnerability_cat=vulnerabilities_cat)
145
  # st.write('Selected country: ', country) # Debugging country
146
  if model_sel == "chatGPT":
147
- # res = pipe.run(query=input_text, documents=docs)
148
- res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, input_text)}])
149
- output = res.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  references = get_refs(docs, output)
151
  # else:
152
  # res = client.text_generation(get_prompt_llama2(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
153
  # output = res
154
  # references = get_refs(docs, res)
155
- st.write('Response')
156
- st.success(output)
157
- st.write('References')
 
158
  st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
159
  st.markdown(references, unsafe_allow_html=True)
160
 
@@ -225,7 +239,9 @@ else:
225
  text = st.text_area('Enter your question in the text box below using natural language or select an example from above:', value=selected_example)
226
 
227
  if st.button('Submit'):
 
 
 
228
  run_query(text, country=country, model_sel=model_sel)
229
 
230
 
231
-
 
144
  docs = get_docs(input_text, country=country,vulnerability_cat=vulnerabilities_cat)
145
  # st.write('Selected country: ', country) # Debugging country
146
  if model_sel == "chatGPT":
147
+
148
+ response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, input_text)}], stream=True)
149
+ # iterate through the stream of events
150
+
151
+ report = []
152
+ for chunk in response:
153
+ chunk_message = chunk['choices'][0]['delta']
154
+ # collected_chunks.append(chunk) # save the event response
155
+ if 'content' in chunk_message:
156
+ report.append(chunk_message.content) # extract the message
157
+ result = "".join(report).strip()
158
+ result = result.replace("\n", "")
159
+ # res_box.markdown(f'{result}')
160
+ res_box.success(result)
161
+
162
+ output = result
163
  references = get_refs(docs, output)
164
  # else:
165
  # res = client.text_generation(get_prompt_llama2(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
166
  # output = res
167
  # references = get_refs(docs, res)
168
+ # st.write('Response')
169
+ # st.success(output)
170
+ st.markdown("----")
171
+ st.markdown('**REFERENCES:**')
172
  st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
173
  st.markdown(references, unsafe_allow_html=True)
174
 
 
239
  text = st.text_area('Enter your question in the text box below using natural language or select an example from above:', value=selected_example)
240
 
241
  if st.button('Submit'):
242
+ st.markdown("----")
243
+ st.markdown('**RESPONSE:**')
244
+ res_box = st.empty()
245
  run_query(text, country=country, model_sel=model_sel)
246
 
247