Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ tokenizer_name = "allenai/scibert_scivocab_uncased"
|
|
8 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
9 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
10 |
|
|
|
11 |
def inference(abstract: str):
|
12 |
"""
|
13 |
Split an abstract into sentences and perform claim identification.
|
@@ -20,14 +21,13 @@ def inference(abstract: str):
|
|
20 |
truncation=True,
|
21 |
padding="longest"
|
22 |
)
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
if pred:
|
27 |
-
claims.append(sents[idx])
|
28 |
if len(claims) > 0:
|
29 |
return ".\n".join(claims)
|
30 |
-
|
|
|
31 |
|
32 |
|
33 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
|
8 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
9 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
10 |
|
11 |
+
|
12 |
def inference(abstract: str):
|
13 |
"""
|
14 |
Split an abstract into sentences and perform claim identification.
|
|
|
21 |
truncation=True,
|
22 |
padding="longest"
|
23 |
)
|
24 |
+
logits = model(**inputs).logits
|
25 |
+
preds = logits.argmax(dim=1) # convert logits to predictions
|
26 |
+
claims = [sent for sent, pred in zip(sents, preds) if pred == 1]
|
|
|
|
|
27 |
if len(claims) > 0:
|
28 |
return ".\n".join(claims)
|
29 |
+
else:
|
30 |
+
return "No claims found from a given abstract."
|
31 |
|
32 |
|
33 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|