import streamlit as st import torch import torch.nn.functional as F from transformers import AutoTokenizer from model_SingleLabelClassifier import SingleLabelClassifier from safetensors.torch import load_file import json import re MODEL_NAME = "allenai/scibert_scivocab_uncased" CHECKPOINT_PATH = "checkpoint-23985" NUM_CLASSES = 65 MAX_LEN = 250 # Загрузка меток with open("label_mappings.json", "r") as f: mappings = json.load(f) label2id = mappings["label2id"] id2label = {int(k): v for k, v in mappings["id2label"].items()} # Загрузка модели и токенизатора @st.cache_resource def load_model_and_tokenizer(): tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH) model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES) state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors") model.load_state_dict(state_dict) model.eval() return model, tokenizer model, tokenizer = load_model_and_tokenizer() def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3): model.eval() title = re.sub(r"\.+$", "", title.strip()) summary = re.sub(r"\.+$", "", summary.strip()) text = title + ". " + summary inputs = tokenizer( text, return_tensors="pt", truncation=True, padding="max_length", max_length=max_length ) with torch.no_grad(): outputs = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], token_type_ids=inputs.get("token_type_ids") ) logits = outputs["logits"] probs = F.softmax(logits, dim=1).squeeze().numpy() top_indices = probs.argsort()[::-1][:top_k] return [(id2label[i], round(probs[i], 3)) for i in top_indices] # Интерфейс Streamlit st.title("ArXiv Tag Predictor") with st.expander("Описание модели"): st.markdown(""" Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для предсказаня первого тега статьей с сайта [arXiv.org](https://arxiv.org). - Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.), включая категорию other, которая объединяет редкие теги. - Модель обучена на **заголовках и аннотациях** научных публикаций - На вход принимает **англоязычный текст** - Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи Ниже вы можете посмотреть полный список возможных тегов """) with st.expander("Список всех тегов"): tag_list = sorted(label2id.keys()) st.markdown("\n".join([f"- `{tag}`" for tag in tag_list])) st.write("Введите заголовок и аннотацию научной статьи (на английском):") title = st.text_input("**Title**") summary = st.text_area("**Summary**", height=200) if st.button("Предсказать теги"): if not title or not summary: st.warning("Пожалуйста, введите и заголовок, и аннотацию!") else: preds = predict(title, summary, model, tokenizer, id2label) st.subheader("Предсказанные теги:") for tag, prob in preds: st.write(f"**{tag}** — вероятность: `{prob:.3f}`")