Spaces:
Runtime error
Runtime error
theresatvan
committed on
Commit
•
e4296b4
1
Parent(s):
ad42ce4
Fix model combination
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import streamlit as st
|
2 |
from datasets import load_dataset
|
3 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
|
@@ -34,15 +35,20 @@ def predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims,
|
|
34 |
|
35 |
abstract, claims = input['abstract'], input['claims']
|
36 |
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
with torch.no_grad():
|
41 |
-
outputs_abstract = model_abstract(
|
42 |
-
outputs_claims = model_claims(
|
43 |
|
44 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
45 |
-
label = torch.argmax(combined_prob,
|
46 |
|
47 |
return label, combined_prob
|
48 |
|
@@ -53,7 +59,7 @@ if __name__ == '__main__':
|
|
53 |
form = st.form('patent-prediction-form')
|
54 |
dropdown = [example['patent_number'] for example in dataset]
|
55 |
|
56 |
-
input_application = form.selectbox('Select a patent\'s application number',
|
57 |
submit = form.form_submit_button("Submit")
|
58 |
|
59 |
if submit:
|
@@ -62,6 +68,6 @@ if __name__ == '__main__':
|
|
62 |
label, prob = predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims, input)
|
63 |
|
64 |
st.write(label)
|
65 |
-
st.write(
|
66 |
st.write(input['decision'])
|
67 |
|
|
|
1 |
+
import torch
|
2 |
import streamlit as st
|
3 |
from datasets import load_dataset
|
4 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
|
|
|
35 |
|
36 |
abstract, claims = input['abstract'], input['claims']
|
37 |
|
38 |
+
encoding_abstract = tokenizer_abstract(abstract, return_tensors='pt', truncation=True, padding='max_length')
|
39 |
+
encoding_claims = tokenizer_claims(claims, return_tensors='pt', truncation=True, padding='max_length')
|
40 |
+
|
41 |
+
input_abstract = encoding_abstract['input_ids'].to(device)
|
42 |
+
attention_mask_abstract = encoding_abstract['attention_mask'].to(device)
|
43 |
+
input_claims = encoding_claims['input_ids'].to(device)
|
44 |
+
attention_mask_claims = encoding_claims['attention_mask'].to(device)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
+
outputs_abstract = model_abstract(input_ids=input_abstract, attention_mask=attention_mask_abstract)
|
48 |
+
outputs_claims = model_claims(input_ids=input_claims, attention_mask=attention_mask_claims)
|
49 |
|
50 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
51 |
+
label = torch.argmax(combined_prob, axis=1).flatten()
|
52 |
|
53 |
return label, combined_prob
|
54 |
|
|
|
59 |
form = st.form('patent-prediction-form')
|
60 |
dropdown = [example['patent_number'] for example in dataset]
|
61 |
|
62 |
+
input_application = form.selectbox('Select a patent\'s application number', dropdown)
|
63 |
submit = form.form_submit_button("Submit")
|
64 |
|
65 |
if submit:
|
|
|
68 |
label, prob = predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims, input)
|
69 |
|
70 |
st.write(label)
|
71 |
+
st.write(prob)
|
72 |
st.write(input['decision'])
|
73 |
|