File size: 2,074 Bytes
0efe602
78f93e5
fffa35e
 
78f93e5
 
2083f73
 
78f93e5
 
2083f73
0efe602
 
 
 
2083f73
78f93e5
 
b8ff6ff
 
 
78f93e5
0efe602
 
 
 
78f93e5
b8ff6ff
78f93e5
 
 
 
aa31461
 
 
 
b8ff6ff
78f93e5
b8ff6ff
 
 
 
 
aa31461
b8ff6ff
78f93e5
b8ff6ff
 
 
 
 
78f93e5
0efe602
 
 
 
 
78f93e5
0efe602
 
78f93e5
0efe602
 
78f93e5
0efe602
 
 
 
 
 
 
 
 
 
aa31461
0efe602
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def combine_title_summary(title, summary):
    """Join a title and a summary into one labelled string for the model.

    Args:
        title: Paper title text.
        summary: Paper abstract / summary text.

    Returns:
        The string ``"title: <title> summary: <summary>"``.
    """
    pieces = ("title: ", title, " summary: ", summary)
    return "".join(pieces)


# Theme name -> index into the classifier's output probabilities.
# Insertion order matches the index values (0, 1, 2, ...); presumably this is
# the label order used when the checkpoint was fine-tuned — confirm if retraining.
tag2ind = {
    tag: index
    for index, tag in enumerate(("Biology", "Physics", "Math", "Computer Science"))
}


@st.cache_resource
def load_model():
    """Load the fine-tuned tokenizer and sequence-classification model.

    Decorated with ``st.cache_resource`` so repeated calls reuse the same
    tokenizer/model objects instead of reloading the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is moved to CUDA when a GPU
        is available, otherwise it stays on the CPU.
    """
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Alternative checkpoint previously used:
    # "./distilbert/distilbert-base-cased/checkpoint-738"
    checkpoint_dir = "./microsoft/deberta-v3-small/checkpoint-4915"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    model = model.to(target)

    return tokenizer, model





def run_model(title, summary, threshold=0.95):
    """Classify a paper and return its most likely themes.

    Themes are taken in descending order of probability until their cumulative
    probability reaches ``threshold``. At least one theme is always returned.

    Args:
        title: Paper title (may be empty).
        summary: Paper abstract (may be empty).
        threshold: Cumulative-probability cutoff for the returned themes.
            Defaults to 0.95, the value that was previously hard-coded.

    Returns:
        ``(best_tags, best_probs)``: parallel lists of theme names and their
        probabilities (0-d CPU tensors), most likely first.
    """
    tokenizer, model = load_model()

    text = combine_title_summary(title, summary)
    tokens_info = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )

    # Run inference on whatever device load_model() put the model on. The old
    # code instead moved the shared, cached model back to the CPU on every call,
    # which silently disabled GPU inference.
    device = next(model.parameters()).device
    tokens_info = tokens_info.to(device)

    model.eval()
    with torch.no_grad():
        out = model(**tokens_info)
    # Single input -> row 0. Move to CPU so callers keep receiving CPU tensors.
    probs = torch.nn.functional.softmax(out.logits, dim=-1)[0].cpu()

    # Invert the label mapping once instead of rebuilding a key list per step.
    ind2tag = {ind: tag for tag, ind in tag2ind.items()}

    cumulative = 0.0
    best_tags, best_probs = [], []
    for ind in torch.argsort(probs, descending=True).tolist():
        cumulative += probs[ind].item()
        best_tags.append(ind2tag[ind])
        best_probs.append(probs[ind])
        if cumulative >= threshold:
            break

    return best_tags, best_probs


def main():
    """Render the Streamlit UI: read a title/abstract, classify, show a table."""
    title = st.text_input(label="Title", value="")
    abstract = st.text_area(label="Abstract", value="", height=200)
    if not st.button("Classify"):
        return

    # Whitespace-only input carries no text to classify; treat it as empty.
    if not title.strip() and not abstract.strip():
        st.error("At least one of title or abstract must be provided")
        return

    best_tags, best_probs = run_model(title, abstract)

    # run_model returns tensor probabilities. Convert them to plain floats so
    # the table shows numbers rather than tensor objects.
    df = pd.DataFrame(
        {
            "Theme": best_tags,
            "Probability": [float(prob) for prob in best_probs],
        }
    )
    st.table(df)


# Build the UI when this file is executed as a script.
if __name__ == "__main__":
    main()