File size: 3,703 Bytes
106e870
 
 
 
 
 
0c47f30
180cc55
0c47f30
106e870
 
0c47f30
 
f2b6c86
106e870
0c47f30
 
 
180cc55
0c47f30
106e870
0c47f30
106e870
 
 
 
 
 
 
 
 
 
 
180cc55
d5fbae4
106e870
180cc55
 
 
106e870
 
 
 
 
 
 
 
 
 
 
180cc55
 
 
 
 
106e870
 
 
 
 
 
0c47f30
cb2309a
180cc55
cb2309a
180cc55
cb2309a
180cc55
cb2309a
180cc55
 
 
 
cb2309a
180cc55
 
556d513
180cc55
 
 
cb2309a
106e870
 
 
 
cb2309a
106e870
180cc55
106e870
 
cb2309a
106e870
180cc55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model_SingleLabelClassifier import SingleLabelClassifier
from safetensors.torch import load_file
import json
import re


MODEL_NAME = "allenai/scibert_scivocab_uncased"
CHECKPOINT_PATH = "checkpoint-23985"
NUM_CLASSES = 65
MAX_LEN = 250

# Загрузка меток
with open("label_mappings.json", "r") as f:
    mappings = json.load(f)
label2id = mappings["label2id"]
id2label = {int(k): v for k, v in mappings["id2label"].items()}

# Загрузка модели и токенизатора
@st.cache_resource
def load_model_and_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
    model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
    state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors")
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer()


def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3):
    model.eval()

    title = re.sub(r"\.+$", "", title.strip())
    summary = re.sub(r"\.+$", "", summary.strip())
    text = title + ". " + summary

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length
    )

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs.get("token_type_ids")
        )
        logits = outputs["logits"]
        probs = F.softmax(logits, dim=1).squeeze().numpy()

    top_indices = probs.argsort()[::-1][:top_k]
    return [(id2label[i], round(probs[i], 3)) for i in top_indices]

# Интерфейс Streamlit
st.title("ArXiv Tag Predictor")

with st.expander("Описание модели"):
    st.markdown("""
    Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для предсказаня первого тега статьей с сайта [arXiv.org](https://arxiv.org).
    
    - Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.), включая категорию other, которая объединяет редкие теги.
    - Модель обучена на **заголовках и аннотациях** научных публикаций
    - На вход принимает **англоязычный текст**
    - Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи
    
    Ниже вы можете посмотреть полный список возможных тегов
    """)

with st.expander("Список всех тегов"):
    tag_list = sorted(label2id.keys())
    st.markdown("\n".join([f"- `{tag}`" for tag in tag_list]))

st.write("Введите заголовок и аннотацию научной статьи (на английском):")

title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)

if st.button("Предсказать теги"):
    if not title or not summary:
        st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
    else:
        preds = predict(title, summary, model, tokenizer, id2label)
        st.subheader("Предсказанные теги:")
        for tag, prob in preds:
            st.write(f"**{tag}** — вероятность: `{prob:.3f}`")