# Streamlit demo app: arXiv topic classifier (fine-tuned DistilBERT).
# (Removed Hugging Face Spaces page chrome accidentally included in the paste.)
import json

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, DistilBertForSequenceClassification

# Fine-tuned weights and the label map produced at training time.
CHECKPOINT_PATH = "checkpoints/epoch_8.pt"
LABELS_PATH = "checkpoints/labels_info.json"
BASE_MODEL = "distilbert-base-cased"

# LABELS is used as a dict below (LABELS.values()); presumably it maps class
# index -> human-readable arXiv category, in logit order — TODO confirm.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = json.load(f)
@st.cache_resource
def load_model():
    """Load the tokenizer and the fine-tuned DistilBERT classifier on CPU.

    Decorated with ``st.cache_resource`` so the (expensive) model load runs
    once per server process instead of on every Streamlit script rerun.

    Returns:
        tuple: ``(tokenizer, model)`` — the ``BASE_MODEL`` tokenizer and a
        ``DistilBertForSequenceClassification`` head sized to ``len(LABELS)``,
        with ``CHECKPOINT_PATH`` weights loaded and switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Same architecture as at training time; classification head must match
    # the number of labels the checkpoint was trained with.
    model = DistilBertForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=len(LABELS)
    )
    # NOTE(review): torch.load unpickles the file — load only trusted
    # checkpoints (or pass weights_only=True on torch >= 1.13).
    state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model
tokenizer, model = load_model()

# --- UI widgets (user-facing Russian strings preserved verbatim) ---
st.title("Классификатор научных статей по заголовку и описанию")
st.write("Введите название и аннотацию статьи для предсказания её тематики по таксономии arxiv.org")

title = st.text_input("Название статьи:")
abstract = st.text_area("Аннотация (abstract):")
def _labels_covering_mass(label_probs, threshold=0.95):
    """Return the most probable (label, prob) pairs whose probabilities sum
    to at least ``threshold``, ordered most-probable first.

    Always returns at least one pair for non-empty input (the top label is
    appended before the cumulative check).
    """
    ranked = sorted(label_probs, key=lambda pair: pair[1], reverse=True)
    selected = []
    cumulative = 0.0
    for label, prob in ranked:
        cumulative += prob
        selected.append((label, prob))
        if cumulative >= threshold:
            break
    return selected


if st.button("Классифицировать"):
    if not title and not abstract:
        # Title alone is enough to classify; require at least that.
        st.warning("Введите хотя бы название статьи.")
    else:
        text = title if not abstract else f"{title} {abstract}"
        inputs = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=256
        )
        with torch.no_grad():
            outputs = model(**inputs)
        # Single input -> squeeze to a 1-D probability vector over classes.
        probs = F.softmax(outputs.logits, dim=1).squeeze()
        # Assumes iteration order of LABELS.values() matches the logit
        # order used at training time — TODO confirm against labels_info.json.
        label_probs = [(label, p.item()) for label, p in zip(LABELS.values(), probs)]
        top_labels = _labels_covering_mass(label_probs, threshold=0.95)

        st.subheader("Наиболее вероятные тематики (суммарно ≥95%):")
        for label, prob in top_labels:
            st.write(f"**{label}** — {prob * 100:.2f}%")