Spaces:
Running
Running
import streamlit as st | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer | |
from model_SingleLabelClassifier import SingleLabelClassifier | |
from safetensors.torch import load_file | |
import json | |
import re | |
MODEL_NAME = "allenai/scibert_scivocab_uncased" | |
CHECKPOINT_PATH = "checkpoint-23985" | |
NUM_CLASSES = 65 | |
MAX_LEN = 250 | |
# Загрузка меток | |
with open("label_mappings.json", "r") as f: | |
mappings = json.load(f) | |
label2id = mappings["label2id"] | |
id2label = {int(k): v for k, v in mappings["id2label"].items()} | |
# Загрузка модели и токенизатора | |
def load_model_and_tokenizer(): | |
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH) | |
model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES) | |
state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors") | |
model.load_state_dict(state_dict) | |
model.eval() | |
return model, tokenizer | |
model, tokenizer = load_model_and_tokenizer() | |
def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3): | |
model.eval() | |
title = re.sub(r"\.+$", "", title.strip()) | |
summary = re.sub(r"\.+$", "", summary.strip()) | |
text = title + ". " + summary | |
inputs = tokenizer( | |
text, | |
return_tensors="pt", | |
truncation=True, | |
padding="max_length", | |
max_length=max_length | |
) | |
with torch.no_grad(): | |
outputs = model( | |
input_ids=inputs["input_ids"], | |
attention_mask=inputs["attention_mask"], | |
token_type_ids=inputs.get("token_type_ids") | |
) | |
logits = outputs["logits"] | |
probs = F.softmax(logits, dim=1).squeeze().numpy() | |
top_indices = probs.argsort()[::-1][:top_k] | |
return [(id2label[i], round(probs[i], 3)) for i in top_indices] | |
# Интерфейс Streamlit | |
st.title("ArXiv Tag Predictor") | |
with st.expander("Описание модели"): | |
st.markdown(""" | |
Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для предсказаня первого тега статьей с сайта [arXiv.org](https://arxiv.org). | |
- Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.), включая категорию other, которая объединяет редкие теги. | |
- Модель обучена на **заголовках и аннотациях** научных публикаций | |
- На вход принимает **англоязычный текст** | |
- Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи | |
Ниже вы можете посмотреть полный список возможных тегов | |
""") | |
with st.expander("Список всех тегов"): | |
tag_list = sorted(label2id.keys()) | |
st.markdown("\n".join([f"- `{tag}`" for tag in tag_list])) | |
st.write("Введите заголовок и аннотацию научной статьи (на английском):") | |
title = st.text_input("**Title**") | |
summary = st.text_area("**Summary**", height=200) | |
if st.button("Предсказать теги"): | |
if not title or not summary: | |
st.warning("Пожалуйста, введите и заголовок, и аннотацию!") | |
else: | |
preds = predict(title, summary, model, tokenizer, id2label) | |
st.subheader("Предсказанные теги:") | |
for tag, prob in preds: | |
st.write(f"**{tag}** — вероятность: `{prob:.3f}`") | |