# app.py - uploaded by ItsNikolor
# Commit aa31461 (verified): "Update app.py"
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def combine_title_summary(title, summary):
    """Build the single text string that is fed to the tokenizer.

    The result has the form ``"title: <title> summary: <summary>"``.

    Args:
        title: Paper title as a string.
        summary: Paper abstract/summary as a string.

    Returns:
        The labelled title and summary joined into one string.
    """
    pieces = ("title: ", title, " summary: ", summary)
    return "".join(pieces)
# Subject tag -> integer class index; used to translate classifier output
# indices back into human-readable tag names.
tag2ind = {
    tag: index
    for index, tag in enumerate(
        ("Biology", "Physics", "Math", "Computer Science")
    )
}
@st.cache_resource
def load_model():
    """Load the tokenizer and sequence classifier from the local checkpoint.

    The function is wrapped in ``st.cache_resource`` so the loaded objects can
    be reused instead of being re-created on every call.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is placed on CUDA when it is
        available and on the CPU otherwise.
    """
    # Alternative checkpoint: "./distilbert/distilbert-base-cased/checkpoint-738"
    checkpoint_dir = "./microsoft/deberta-v3-small/checkpoint-4915"

    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    # use_fast=False asks for the slow tokenizer implementation.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, use_fast=False)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    classifier = classifier.to(target_device)
    return tokenizer, classifier
def run_model(title, summary, threshold=0.95):
    """Classify a paper and return its most likely tags with probabilities.

    Tags are taken in order of decreasing probability until their cumulative
    probability reaches ``threshold``.

    Args:
        title: Paper title (may be an empty string).
        summary: Paper abstract (may be an empty string).
        threshold: Cumulative probability at which to stop adding tags.

    Returns:
        A ``(best_tags, best_probs)`` tuple of equal-length lists: the tag
        names and their probabilities as plain Python floats, most likely
        first.
    """
    tokenizer, model = load_model()
    text = combine_title_summary(title, summary)
    tokens_info = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    model.eval()
    # Move the inputs to the model's device instead of forcing the shared,
    # cached model back onto the CPU on every call.
    tokens_info = tokens_info.to(model.device)
    with torch.no_grad():
        out = model(**tokens_info)
        probs = torch.nn.functional.softmax(out.logits, dim=-1)[0]

    # Work with plain Python numbers from here on so callers never receive
    # 0-dim tensors.
    order = torch.argsort(probs, descending=True).tolist()
    probs = probs.tolist()

    # Look tags up by their declared index rather than by dict position.
    ind2tag = {ind: tag for tag, ind in tag2ind.items()}

    cumulative = 0.0
    best_tags, best_probs = [], []
    for ind in order:
        cumulative += probs[ind]
        best_tags.append(ind2tag[ind])
        best_probs.append(probs[ind])
        if cumulative >= threshold:
            break
    return best_tags, best_probs
def main():
    """Render the page: collect a title/abstract and show the predicted tags.

    When the "Classify" button is pressed, the inputs are validated, passed to
    ``run_model``, and the returned tags and probabilities are shown as a
    two-column table ("Theme", "Probability").
    """
    title = st.text_input(label="Title", value="")
    abstract = st.text_area(label="Abstract", value="", height=200)

    if not st.button("Classify"):
        return

    # Whitespace-only input carries no information, so treat it as empty.
    if not title.strip() and not abstract.strip():
        st.error("At least one of title or abstract must be provided")
        return

    best_tags, best_probs = run_model(title, abstract)
    df = pd.DataFrame(
        {
            "Theme": best_tags,
            # float() turns 0-dim tensors into plain numbers so the table
            # shows values rather than tensor reprs.
            "Probability": [float(prob) for prob in best_probs],
        }
    )
    st.table(df)
# Build the page only when this file is executed as a script.
if __name__ == "__main__":
    main()