# app.py — Streamlit arxiv-topic classifier (HF Spaces: Mirageinv, "Update app.py", commit 3805e6f, verified)
import json
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, DistilBertForSequenceClassification

# Paths to the fine-tuned checkpoint and the label mapping produced at train time.
CHECKPOINT_PATH = "checkpoints/epoch_8.pt"
LABELS_PATH = "checkpoints/labels_info.json"

# LABELS maps class index -> human-readable arxiv topic name.
# NOTE(review): the inference code below relies on dict insertion order
# matching the model's logit order — confirm the JSON is written in index order.
# Explicit UTF-8 so non-ASCII label names load correctly on any platform.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    LABELS = json.load(f)

BASE_MODEL = "distilbert-base-cased"
@st.cache_resource
def load_model():
    """Load the tokenizer and fine-tuned classifier (cached per server process).

    Returns:
        tuple: ``(tokenizer, model)`` — the BASE_MODEL tokenizer and a
        DistilBertForSequenceClassification with ``len(LABELS)`` output
        classes, fine-tuned weights loaded, in eval mode on CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # Same architecture as training: the classification head must have
    # len(LABELS) outputs or load_state_dict below will fail.
    model = DistilBertForSequenceClassification.from_pretrained(
        BASE_MODEL, num_labels=len(LABELS)
    )
    # weights_only=True restricts torch.load to tensor data — plain
    # torch.load is a pickle deserialization and can execute arbitrary
    # code from a tampered checkpoint file.
    state_dict = torch.load(
        CHECKPOINT_PATH, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for deterministic inference
    return tokenizer, model
# Loaded once per process; st.cache_resource memoizes across Streamlit reruns.
tokenizer, model = load_model()
# Page header and the two free-text inputs (Russian strings are user-facing UI text).
st.title("Классификатор научных статей по заголовку и описанию")
st.write("Введите название и аннотацию статьи для предсказания её тематики по таксономии arxiv.org")
title = st.text_input("Название статьи:")
abstract = st.text_area("Аннотация (abstract):")
if st.button("Классифицировать"):
    if not title and not abstract:
        # Nothing to classify — ask for at least a title.
        st.warning("Введите хотя бы название статьи.")
    else:
        # Join title and abstract when both are present.
        text = f"{title} {abstract}" if abstract else title
        encoded = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=256
        )
        with torch.no_grad():
            logits = model(**encoded).logits
        probabilities = F.softmax(logits, dim=1).squeeze()

        # Pair every label name with its probability, most likely first.
        ranked = sorted(
            zip(list(LABELS.values()), (p.item() for p in probabilities)),
            key=lambda pair: pair[1],
            reverse=True,
        )

        # Take the shortest prefix whose probabilities sum to at least 95%.
        top_labels = []
        total = 0.0
        for name, prob in ranked:
            total += prob
            top_labels.append((name, prob))
            if total >= 0.95:
                break

        st.subheader("Наиболее вероятные тематики (суммарно ≥95%):")
        for name, prob in top_labels:
            st.write(f"**{name}** — {prob * 100:.2f}%")