theresatvan committed on
Commit
e4296b4
1 Parent(s): ad42ce4

Fix model combination

Files changed (1)
  1. app.py +13 -7
app.py CHANGED
@@ -1,3 +1,4 @@
+import torch
 import streamlit as st
 from datasets import load_dataset
 from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
@@ -34,15 +35,20 @@ def predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims,
 
     abstract, claims = input['abstract'], input['claims']
 
-    input_abstract = tokenizer_abstract(abstract, return_tensors='pt')
-    input_claims = tokenizer_claims(claims, return_tensors='pt')
+    encoding_abstract = tokenizer_abstract(abstract, return_tensors='pt', truncation=True, padding='max_length')
+    encoding_claims = tokenizer_claims(claims, return_tensors='pt', truncation=True, padding='max_length')
+
+    input_abstract = encoding_abstract['input_ids'].to(device)
+    attention_mask_abstract = encoding_abstract['attention_mask'].to(device)
+    input_claims = encoding_claims['input_ids'].to(device)
+    attention_mask_claims = encoding_claims['attention_mask'].to(device)
 
     with torch.no_grad():
-        outputs_abstract = model_abstract(**input_abstract)
-        outputs_claims = model_claims(**input_claims)
+        outputs_abstract = model_abstract(input_ids=input_abstract, attention_mask=attention_mask_abstract)
+        outputs_claims = model_claims(input_ids=input_claims, attention_mask=attention_mask_claims)
 
     combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
-    label = torch.argmax(combined_prob, dim=1)
+    label = torch.argmax(combined_prob, axis=1).flatten()
 
     return label, combined_prob
 
@@ -53,7 +59,7 @@ if __name__ == '__main__':
     form = st.form('patent-prediction-form')
    dropdown = [example['patent_number'] for example in dataset]
 
-    input_application = form.selectbox('Select a patent\'s application number', patents_dropdown)
+    input_application = form.selectbox('Select a patent\'s application number', dropdown)
     submit = form.form_submit_button("Submit")
 
     if submit:
@@ -62,6 +68,6 @@ if __name__ == '__main__':
         label, prob = predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims, input)
 
         st.write(label)
-        st.write(predict)
+        st.write(prob)
         st.write(input['decision'])
 
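
For reference, the combination step this commit fixes averages the per-class softmax probabilities of the abstract classifier and the claims classifier, then takes the argmax of the averaged distribution. Below is a minimal standalone sketch of that step; the checkpoint name, device setup, and example texts are placeholders rather than the app's actual configuration, which loads its own fine-tuned models elsewhere in app.py.

import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Placeholder checkpoint; app.py loads separate fine-tuned abstract/claims models.
model_abstract = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model_claims = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_abstract.to(device).eval()
model_claims.to(device).eval()

abstract = "A method for ..."           # placeholder patent abstract
claims = "1. A device comprising ..."   # placeholder patent claims

# Tokenize each section separately, truncating/padding to the model's max length.
enc_abstract = tokenizer(abstract, return_tensors='pt', truncation=True, padding='max_length')
enc_claims = tokenizer(claims, return_tensors='pt', truncation=True, padding='max_length')

with torch.no_grad():
    out_abstract = model_abstract(input_ids=enc_abstract['input_ids'].to(device),
                                  attention_mask=enc_abstract['attention_mask'].to(device))
    out_claims = model_claims(input_ids=enc_claims['input_ids'].to(device),
                              attention_mask=enc_claims['attention_mask'].to(device))

# Average the two heads' class probabilities, then pick the top class.
combined_prob = (out_abstract.logits.softmax(dim=1) + out_claims.logits.softmax(dim=1)) / 2
label = torch.argmax(combined_prob, dim=1).flatten()

print(label, combined_prob)

Averaging softmax probabilities rather than raw logits keeps the two classifiers on a comparable scale even if their logit magnitudes differ.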