ggbetz committed on
Commit
c493b3a
1 Parent(s): 148cf69

session state

Browse files
Files changed (1) hide show
  1. app.py +5 -52
app.py CHANGED
@@ -86,31 +86,6 @@ CACHE_SIZE = 10000
86
  def params(config):
87
  pass
88
 
89
-
90
- def build_inference_api():
91
- """HF inference api"""
92
- API_URL = "https://api-inference.huggingface.co/models/debatelab/argument-analyst"
93
- headers = {} # {"Authorization": f"Bearer {st.secrets['api_token']}"}
94
-
95
- def query(inputs: str, parameters):
96
- payload = {
97
- "inputs": inputs,
98
- "parameters": parameters,
99
- "options": {"wait_for_model": True},
100
- }
101
- data = json.dumps(payload)
102
- response = requests.request("POST", API_URL, headers=headers, data=data)
103
- content = response.content.decode("utf-8")
104
- try:
105
- # as json
106
- result_json = json.loads(content)
107
- except Exception:
108
- result_json = {"error": content}
109
-
110
- return result_json
111
-
112
- return query
113
-
114
 
115
  @st.cache(allow_output_mutation=True)
116
  def aaac_fields():
@@ -271,10 +246,9 @@ def run_model(mode_set, user_input):
271
  :returns: output dict
272
  """
273
 
274
-
275
- #inference = build_inference_api()
276
- with st.spinner('Initializing pipeline'):
277
- inference = pipeline(task="text2text-generation", model=MODEL)
278
 
279
  current_input = user_input.copy()
280
  output = []
@@ -289,26 +263,7 @@ def run_model(mode_set, user_input):
289
  inquire_prompt = inquire_prompt + (f"{to_key}: {from_key}: {current_input[from_key]}")
290
  # inquire model
291
  inputs = inquire_prompt
292
- attempts = 0
293
- out = None
294
- while not out and attempts<MAX_API_CALLS:
295
- attempts += 1
296
- try:
297
- # api call
298
- out = inference(inputs, **INFERENCE_PARAMS)
299
- if not isinstance(out, list):
300
- raise ValueError('Response is not a list.')
301
- except Exception:
302
- if attempts < MAX_API_CALLS:
303
- st.warning(
304
- f"HF Inference API call (attempt {attempts} of {MAX_API_CALLS}) has failed. Response: {out}. Trying again..."
305
- )
306
- out = None
307
- else:
308
- st.warning(
309
- f"HF Inference API call (attempt {attempts} of {MAX_API_CALLS}) has failed. Response: {out}. Stopping."
310
- )
311
- return None
312
  out = out[0]['generated_text']
313
  # cleanup formalization
314
  if to_key in ['premises_formalized','conclusion_formalized']:
@@ -327,9 +282,7 @@ def run_model(mode_set, user_input):
327
  current_input[to_key] = out
328
 
329
  return output
330
-
331
-
332
-
333
 
334
  def main():
335
 
86
  def params(config):
87
  pass
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  @st.cache(allow_output_mutation=True)
91
  def aaac_fields():
246
  :returns: output dict
247
  """
248
 
249
+ if "inference" not in st.session_state:
250
+ with st.spinner('Initializing pipeline'):
251
+ st.session_state.inference = pipeline(task="text2text-generation", model=MODEL)
 
252
 
253
  current_input = user_input.copy()
254
  output = []
263
  inquire_prompt = inquire_prompt + (f"{to_key}: {from_key}: {current_input[from_key]}")
264
  # inquire model
265
  inputs = inquire_prompt
266
+ out = st.session_state.inference(inputs, **INFERENCE_PARAMS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  out = out[0]['generated_text']
268
  # cleanup formalization
269
  if to_key in ['premises_formalized','conclusion_formalized']:
282
  current_input[to_key] = out
283
 
284
  return output
285
+
 
 
286
 
287
  def main():
288