titipata commited on
Commit
e50255f
1 Parent(s): bdfc761

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -8,6 +8,7 @@ tokenizer_name = "allenai/scibert_scivocab_uncased"
8
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
9
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
 
 
11
  def inference(abstract: str):
12
  """
13
  Split an abstract into sentences and perform claim identification.
@@ -20,14 +21,13 @@ def inference(abstract: str):
20
  truncation=True,
21
  padding="longest"
22
  )
23
- output = model(**inputs).logits
24
- for (idx, out) in enumerate(output):
25
- pred = out.argmax().item()
26
- if pred:
27
- claims.append(sents[idx])
28
  if len(claims) > 0:
29
  return ".\n".join(claims)
30
- return "No claims were made here"
 
31
 
32
 
33
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
8
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
9
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
 
11
+
12
  def inference(abstract: str):
13
  """
14
  Split an abstract into sentences and perform claim identification.
 
21
  truncation=True,
22
  padding="longest"
23
  )
24
+ logits = model(**inputs).logits
25
+ preds = logits.argmax(dim=1) # convert logits to predictions
26
+ claims = [sent for sent, pred in zip(sents, preds) if pred == 1]
 
 
27
  if len(claims) > 0:
28
  return ".\n".join(claims)
29
+ else:
30
+ return "No claims found from a given abstract."
31
 
32
 
33
  with gr.Blocks(theme=gr.themes.Soft()) as demo: