File size: 2,074 Bytes
0efe602 78f93e5 fffa35e 78f93e5 2083f73 78f93e5 2083f73 0efe602 2083f73 78f93e5 b8ff6ff 78f93e5 0efe602 78f93e5 b8ff6ff 78f93e5 aa31461 b8ff6ff 78f93e5 b8ff6ff aa31461 b8ff6ff 78f93e5 b8ff6ff 78f93e5 0efe602 78f93e5 0efe602 78f93e5 0efe602 78f93e5 0efe602 aa31461 0efe602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def combine_title_summary(title, summary):
    """Build the single text string that is fed to the tokenizer.

    Args:
        title: Paper title (str).
        summary: Paper abstract/summary (str).

    Returns:
        The string ``"title: <title> summary: <summary>"``.
    """
    pieces = ("title: ", title, " summary: ", summary)
    return "".join(pieces)
# Theme name -> class index. run_model() uses this table to translate the
# classifier's output indices back into human-readable theme names, so the
# order of the labels below must stay in sync with the trained checkpoint.
tag2ind = {
    tag: ind
    for ind, tag in enumerate(
        ("Biology", "Physics", "Math", "Computer Science")
    )
}
@st.cache_resource
def load_model():
    """Load the tokenizer and the sequence-classification model from disk.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the loaded objects instead of reloading them on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is moved to CUDA when
        ``torch.cuda.is_available()`` is true, otherwise it stays on the CPU.
    """
    # Checkpoint currently served by the app. A DistilBERT checkpoint
    # ("./distilbert/distilbert-base-cased/checkpoint-738") was the
    # previously configured alternative.
    checkpoint_dir = "./microsoft/deberta-v3-small/checkpoint-4915"

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(device_name)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    model = model.to(target_device)
    return tokenizer, model
def run_model(title, summary, *, min_total_prob=0.95):
    """Classify a paper and return its most probable themes.

    The title and summary are merged with ``combine_title_summary``,
    tokenized (truncated to 512 tokens) and passed through the model returned
    by ``load_model``. Themes are then taken in order of decreasing
    probability until their cumulative probability reaches
    ``min_total_prob``.

    Args:
        title: Paper title (str).
        summary: Paper abstract/summary (str).
        min_total_prob: Cumulative probability at which the ranked list of
            themes is cut off. Defaults to 0.95.

    Returns:
        A ``(best_tags, best_probs)`` tuple of equal-length lists:
        theme names ordered from most to least probable, and the matching
        probabilities as plain Python floats.
    """
    tokenizer, model = load_model()
    text = combine_title_summary(title, summary)
    tokens_info = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )

    # Move the inputs to the model's device rather than forcing the model
    # onto the CPU: the model object is shared through st.cache_resource, so
    # relocating it here would mutate state seen by every caller and discard
    # the device placement chosen in load_model().
    tokens_info = tokens_info.to(model.device)

    model.eval()
    with torch.no_grad():
        out = model(**tokens_info)

    # Bring the result back to the CPU and convert to plain Python numbers so
    # callers get values that are independent of torch and of the device.
    probs = torch.nn.functional.softmax(out.logits, dim=-1)[0].cpu()
    ranked_inds = torch.argsort(probs, descending=True).tolist()
    prob_values = probs.tolist()

    # Invert the table once; looking up by index value does not depend on
    # the insertion order of tag2ind.
    ind2tag = {ind: tag for tag, ind in tag2ind.items()}

    total = 0.0
    best_tags, best_probs = [], []
    for ind in ranked_inds:
        total += prob_values[ind]
        best_tags.append(ind2tag[ind])
        best_probs.append(prob_values[ind])
        if total >= min_total_prob:
            break
    return best_tags, best_probs
def main():
    """Render the Streamlit form and display the predicted themes.

    Shows a title field, an abstract field and a "Classify" button. When the
    button is pressed with both fields empty an error is displayed; otherwise
    the text is classified with ``run_model`` and the resulting themes and
    probabilities are shown as a table.
    """
    title = st.text_input(label="Title", value="")
    abstract = st.text_area(label="Abstract", value="", height=200)

    # Nothing to do until the user presses the button.
    if not st.button("Classify"):
        return

    if title == "" and abstract == "":
        st.error("At least one of title or abstract must be provided")
        return

    best_tags, best_probs = run_model(title, abstract)
    theme_to_prob = {tag: prob for tag, prob in zip(best_tags, best_probs)}
    df = pd.DataFrame(
        theme_to_prob.items(),
        columns=["Theme", "Probability"],
    )
    st.table(df)
# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|