hw4 / app.py
MikhailPugachev's picture
Исправлен путь к модели
556d513
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model_SingleLabelClassifier import SingleLabelClassifier
from safetensors.torch import load_file
import json
import re
MODEL_NAME = "allenai/scibert_scivocab_uncased"
CHECKPOINT_PATH = "checkpoint-23985"
NUM_CLASSES = 65
MAX_LEN = 250
# Загрузка меток
with open("label_mappings.json", "r") as f:
mappings = json.load(f)
label2id = mappings["label2id"]
id2label = {int(k): v for k, v in mappings["id2label"].items()}
# Загрузка модели и токенизатора
@st.cache_resource
def load_model_and_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors")
model.load_state_dict(state_dict)
model.eval()
return model, tokenizer
model, tokenizer = load_model_and_tokenizer()
def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3):
model.eval()
title = re.sub(r"\.+$", "", title.strip())
summary = re.sub(r"\.+$", "", summary.strip())
text = title + ". " + summary
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=max_length
)
with torch.no_grad():
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs.get("token_type_ids")
)
logits = outputs["logits"]
probs = F.softmax(logits, dim=1).squeeze().numpy()
top_indices = probs.argsort()[::-1][:top_k]
return [(id2label[i], round(probs[i], 3)) for i in top_indices]
# Интерфейс Streamlit
st.title("ArXiv Tag Predictor")
with st.expander("Описание модели"):
st.markdown("""
Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для предсказаня первого тега статьей с сайта [arXiv.org](https://arxiv.org).
- Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.), включая категорию other, которая объединяет редкие теги.
- Модель обучена на **заголовках и аннотациях** научных публикаций
- На вход принимает **англоязычный текст**
- Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи
Ниже вы можете посмотреть полный список возможных тегов
""")
with st.expander("Список всех тегов"):
tag_list = sorted(label2id.keys())
st.markdown("\n".join([f"- `{tag}`" for tag in tag_list]))
st.write("Введите заголовок и аннотацию научной статьи (на английском):")
title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)
if st.button("Предсказать теги"):
if not title or not summary:
st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
else:
preds = predict(title, summary, model, tokenizer, id2label)
st.subheader("Предсказанные теги:")
for tag, prob in preds:
st.write(f"**{tag}** — вероятность: `{prob:.3f}`")