Spaces:
Running
Running
File size: 3,703 Bytes
106e870 0c47f30 180cc55 0c47f30 106e870 0c47f30 f2b6c86 106e870 0c47f30 180cc55 0c47f30 106e870 0c47f30 106e870 180cc55 d5fbae4 106e870 180cc55 106e870 180cc55 106e870 0c47f30 cb2309a 180cc55 cb2309a 180cc55 cb2309a 180cc55 cb2309a 180cc55 cb2309a 180cc55 556d513 180cc55 cb2309a 106e870 cb2309a 106e870 180cc55 106e870 cb2309a 106e870 180cc55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model_SingleLabelClassifier import SingleLabelClassifier
from safetensors.torch import load_file
import json
import re
MODEL_NAME = "allenai/scibert_scivocab_uncased"
CHECKPOINT_PATH = "checkpoint-23985"
NUM_CLASSES = 65
MAX_LEN = 250
# Загрузка меток
with open("label_mappings.json", "r") as f:
mappings = json.load(f)
label2id = mappings["label2id"]
id2label = {int(k): v for k, v in mappings["id2label"].items()}
# Загрузка модели и токенизатора
@st.cache_resource
def load_model_and_tokenizer():
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
model = SingleLabelClassifier(MODEL_NAME, num_labels=NUM_CLASSES)
state_dict = load_file(f"{CHECKPOINT_PATH}/model.safetensors")
model.load_state_dict(state_dict)
model.eval()
return model, tokenizer
model, tokenizer = load_model_and_tokenizer()
def predict(title, summary, model, tokenizer, id2label, max_length=MAX_LEN, top_k=3):
model.eval()
title = re.sub(r"\.+$", "", title.strip())
summary = re.sub(r"\.+$", "", summary.strip())
text = title + ". " + summary
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=max_length
)
with torch.no_grad():
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs.get("token_type_ids")
)
logits = outputs["logits"]
probs = F.softmax(logits, dim=1).squeeze().numpy()
top_indices = probs.argsort()[::-1][:top_k]
return [(id2label[i], round(probs[i], 3)) for i in top_indices]
# Интерфейс Streamlit
st.title("ArXiv Tag Predictor")
with st.expander("Описание модели"):
st.markdown("""
Данная модель обучена на основе [SciBERT](https://huggingface.co/allenai/scibert_scivocab_uncased) для предсказаня первого тега статьей с сайта [arXiv.org](https://arxiv.org).
- Использует **65 различных тегов** из тематик arXiv (например: `cs.CL`, `math.CO`, `stat.ML`, и т.д.), включая категорию other, которая объединяет редкие теги.
- Модель обучена на **заголовках и аннотациях** научных публикаций
- На вход принимает **англоязычный текст**
- Предсказывает **топ-3 наиболее вероятных тега** для каждой статьи
Ниже вы можете посмотреть полный список возможных тегов
""")
with st.expander("Список всех тегов"):
tag_list = sorted(label2id.keys())
st.markdown("\n".join([f"- `{tag}`" for tag in tag_list]))
st.write("Введите заголовок и аннотацию научной статьи (на английском):")
title = st.text_input("**Title**")
summary = st.text_area("**Summary**", height=200)
if st.button("Предсказать теги"):
if not title or not summary:
st.warning("Пожалуйста, введите и заголовок, и аннотацию!")
else:
preds = predict(title, summary, model, tokenizer, id2label)
st.subheader("Предсказанные теги:")
for tag, prob in preds:
st.write(f"**{tag}** — вероятность: `{prob:.3f}`")
|