Update app.py
Browse files
app.py
CHANGED
@@ -140,33 +140,30 @@ def get_prompt(docs, input_query):
|
|
140 |
return(prompt)
|
141 |
|
142 |
def run_query(input_text, country, model_sel):
    """Run a retrieval-augmented query and stream the answer to the UI.

    Retrieves matching document passages for the query, streams a ChatGPT
    completion into the response box chunk by chunk (live-typing effect),
    then renders the references section.

    Args:
        input_text: the user's free-text query.
        country: country filter passed through to the retriever.
        model_sel: model selector; only "chatGPT" is currently handled
            (the HF-inference branch below is commented out).
    """
    # First call the retriever function using the selected filters.
    docs = get_docs(input_text, country=country, vulnerability_cat=vulnerabilities_cat)

    # Model selector (not currently offering alternatives).
    if model_sel == "chatGPT":
        # Instantiate ChatCompletion as a generator object (stream=True).
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": get_prompt(docs, input_text)}],
            stream=True,
        )

        # Iterate through the streamed output, accumulating the text.
        report = []
        # FIX: bind 'result' up front — if the stream produced no 'content'
        # chunks, the get_refs() call below would otherwise raise NameError.
        result = ""
        for chunk in response:
            # When streaming, the text lives in choices[0].delta — a totally
            # different structure from the non-streaming response object.
            chunk_message = chunk['choices'][0]['delta']
            # Some chunks (role header / final stop chunk) carry no text.
            if 'content' in chunk_message:
                report.append(chunk_message.content)  # extract the message
                # Merge the latest text with all previous chunks and refresh
                # the response box so the answer appears as it is generated.
                result = "".join(report).strip()
                res_box.success(result)

        # Extract references from the generated text.
        # NOTE(review): 'references' is not used in the visible span —
        # presumably consumed further down the file; verify.
        references = get_refs(docs, result)
    # else:
    #     res = client.text_generation(get_prompt(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
    #     output = res
    #     references = get_refs(docs, res)

    st.markdown("----")
    st.markdown('**REFERENCES:**')
    st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')