Spaces:
Paused
Paused
debatelab-admin
commited on
Commit
•
c493b3a
1
Parent(s):
148cf69
session state
Browse files
app.py
CHANGED
@@ -86,31 +86,6 @@ CACHE_SIZE = 10000
|
|
86 |
def params(config):
|
87 |
pass
|
88 |
|
89 |
-
|
90 |
-
def build_inference_api():
|
91 |
-
"""HF inference api"""
|
92 |
-
API_URL = "https://api-inference.huggingface.co/models/debatelab/argument-analyst"
|
93 |
-
headers = {} # {"Authorization": f"Bearer {st.secrets['api_token']}"}
|
94 |
-
|
95 |
-
def query(inputs: str, parameters):
|
96 |
-
payload = {
|
97 |
-
"inputs": inputs,
|
98 |
-
"parameters": parameters,
|
99 |
-
"options": {"wait_for_model": True},
|
100 |
-
}
|
101 |
-
data = json.dumps(payload)
|
102 |
-
response = requests.request("POST", API_URL, headers=headers, data=data)
|
103 |
-
content = response.content.decode("utf-8")
|
104 |
-
try:
|
105 |
-
# as json
|
106 |
-
result_json = json.loads(content)
|
107 |
-
except Exception:
|
108 |
-
result_json = {"error": content}
|
109 |
-
|
110 |
-
return result_json
|
111 |
-
|
112 |
-
return query
|
113 |
-
|
114 |
|
115 |
@st.cache(allow_output_mutation=True)
|
116 |
def aaac_fields():
|
@@ -271,10 +246,9 @@ def run_model(mode_set, user_input):
|
|
271 |
:returns: output dict
|
272 |
"""
|
273 |
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
inference = pipeline(task="text2text-generation", model=MODEL)
|
278 |
|
279 |
current_input = user_input.copy()
|
280 |
output = []
|
@@ -289,26 +263,7 @@ def run_model(mode_set, user_input):
|
|
289 |
inquire_prompt = inquire_prompt + (f"{to_key}: {from_key}: {current_input[from_key]}")
|
290 |
# inquire model
|
291 |
inputs = inquire_prompt
|
292 |
-
|
293 |
-
out = None
|
294 |
-
while not out and attempts<MAX_API_CALLS:
|
295 |
-
attempts += 1
|
296 |
-
try:
|
297 |
-
# api call
|
298 |
-
out = inference(inputs, **INFERENCE_PARAMS)
|
299 |
-
if not isinstance(out, list):
|
300 |
-
raise ValueError('Response is not a list.')
|
301 |
-
except Exception:
|
302 |
-
if attempts < MAX_API_CALLS:
|
303 |
-
st.warning(
|
304 |
-
f"HF Inference API call (attempt {attempts} of {MAX_API_CALLS}) has failed. Response: {out}. Trying again..."
|
305 |
-
)
|
306 |
-
out = None
|
307 |
-
else:
|
308 |
-
st.warning(
|
309 |
-
f"HF Inference API call (attempt {attempts} of {MAX_API_CALLS}) has failed. Response: {out}. Stopping."
|
310 |
-
)
|
311 |
-
return None
|
312 |
out = out[0]['generated_text']
|
313 |
# cleanup formalization
|
314 |
if to_key in ['premises_formalized','conclusion_formalized']:
|
@@ -327,9 +282,7 @@ def run_model(mode_set, user_input):
|
|
327 |
current_input[to_key] = out
|
328 |
|
329 |
return output
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
|
334 |
def main():
|
335 |
|
|
|
86 |
def params(config):
|
87 |
pass
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
@st.cache(allow_output_mutation=True)
|
91 |
def aaac_fields():
|
|
|
246 |
:returns: output dict
|
247 |
"""
|
248 |
|
249 |
+
if "inference" not in st.session_state:
|
250 |
+
with st.spinner('Initializing pipeline'):
|
251 |
+
st.session_state.inference = pipeline(task="text2text-generation", model=MODEL)
|
|
|
252 |
|
253 |
current_input = user_input.copy()
|
254 |
output = []
|
|
|
263 |
inquire_prompt = inquire_prompt + (f"{to_key}: {from_key}: {current_input[from_key]}")
|
264 |
# inquire model
|
265 |
inputs = inquire_prompt
|
266 |
+
out = st.session_state.inference(inputs, **INFERENCE_PARAMS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
out = out[0]['generated_text']
|
268 |
# cleanup formalization
|
269 |
if to_key in ['premises_formalized','conclusion_formalized']:
|
|
|
282 |
current_input[to_key] = out
|
283 |
|
284 |
return output
|
285 |
+
|
|
|
|
|
286 |
|
287 |
def main():
|
288 |
|