mtyrrell commited on
Commit
fba6e77
1 Parent(s): 269dd20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -109,7 +109,7 @@ def get_docs(input_query, country = [], vulnerability_cat = []):
109
  filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
110
  else:
111
  filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
112
- docs = retriever.retrieve(query=query, filters = filters, top_k = 10)
113
  # Break out the key fields and convert to pandas for filtering
114
  docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
115
  df_docs = pd.DataFrame(docs)
@@ -154,11 +154,11 @@ def get_refs(docs, res):
154
  return result_str
155
 
156
  # define a special function for putting the prompt together (as we can't use haystack)
157
- def get_prompt(docs, query):
158
  base_prompt=prompt_template
159
  # Add the meta data for references
160
  context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document']+' &&&: '+d.content for d in docs])
161
- prompt = base_prompt+"; Context: "+context+"; Question: "+query+"; Answer:"
162
  return(prompt)
163
 
164
  def run_query(input_query, country, model_sel):
@@ -167,13 +167,13 @@ def run_query(input_query, country, model_sel):
167
  # st.write('Selected country: ', country) # Debugging country
168
  if model_sel == "chatGPT":
169
  # res = pipe.run(query=input_text, documents=docs)
170
- res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, query=input_query)}])
171
  output = res["results"][0]
172
  references = get_refs(docs, output)
173
- else:
174
- res = client.text_generation(get_prompt_llama2(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
175
- output = res
176
- references = get_refs(docs, res)
177
  st.write('Response')
178
  st.success(output)
179
  st.write('References')
 
109
  filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
110
  else:
111
  filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
112
+ docs = retriever.retrieve(query=input_query, filters = filters, top_k = 10)
113
  # Break out the key fields and convert to pandas for filtering
114
  docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
115
  df_docs = pd.DataFrame(docs)
 
154
  return result_str
155
 
156
  # define a special function for putting the prompt together (as we can't use haystack)
157
+ def get_prompt(docs, input_query):
158
  base_prompt=prompt_template
159
  # Add the meta data for references
160
  context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document']+' &&&: '+d.content for d in docs])
161
+ prompt = base_prompt+"; Context: "+context+"; Question: "+input_query+"; Answer:"
162
  return(prompt)
163
 
164
  def run_query(input_query, country, model_sel):
 
167
  # st.write('Selected country: ', country) # Debugging country
168
  if model_sel == "chatGPT":
169
  # res = pipe.run(query=input_text, documents=docs)
170
+ res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, input_query)}])
171
  output = res["results"][0]
172
  references = get_refs(docs, output)
173
+ # else:
174
+ # res = client.text_generation(get_prompt_llama2(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
175
+ # output = res
176
+ # references = get_refs(docs, res)
177
  st.write('Response')
178
  st.success(output)
179
  st.write('References')