# HW4_nlp_arxiv / app.py
# Uploaded by DeniSSio (commit "Upload 4 files", revision 6addc6f, verified).
import streamlit as st
import torch
import numpy as np
from transformers import AutoConfig, AutoModel, AutoTokenizer
from model import DualDistilBERTClassifier # твоя модель
import pandas as pd
import os
# === Load the model and tokenizer ===
# @st.cache_resource
# def load_model():
# topic_labels_s = list(pd.read_json("label_list.json", typ="series"))
# NUM_LABELS = len(topic_labels_s)
# tokenizer_s = AutoTokenizer.from_pretrained("./best_model")
# model_s = DualDistilBERTClassifier("distilbert-base-cased", NUM_LABELS)
# model_s.load_state_dict(torch.load(os.path.join("./best_model", "pytorch_model.bin"), map_location="cpu"))
#
# model_s.eval()
# return topic_labels_s, tokenizer_s, model_s
@st.cache_resource
def load_model():
    """Load the topic labels, the classifier and its tokenizer.

    Decorated with ``st.cache_resource`` so the returned objects are created
    once and reused across Streamlit reruns instead of being reloaded on
    every widget interaction.

    Returns:
        tuple: ``(topic_labels, tokenizer, model)`` where ``topic_labels`` is
        the list read from ``label_list.json`` and ``tokenizer`` / ``model``
        are loaded from the ``DeniSSio/outputs`` checkpoint.
    """
    topic_labels_s = list(pd.read_json("label_list.json", typ="series"))
    model = DualDistilBERTClassifier.from_pretrained("DeniSSio/outputs")
    tokenizer = AutoTokenizer.from_pretrained("DeniSSio/outputs")
    # The app only runs inference, so put the model in evaluation mode
    # explicitly (disables dropout). The call is idempotent, so it is harmless
    # if ``from_pretrained`` already did it.
    model.eval()
    return topic_labels_s, tokenizer, model
topic_labels, tokenizer, model = load_model()

# Token budget applied to both the title and the abstract encodings.
MAX_LENGTH = 256
# Share of probability mass the displayed topics must cover.
PROB_MASS_THRESHOLD = 0.95


def _encode(text):
    """Tokenize ``text`` into fixed-length (``MAX_LENGTH``) PyTorch tensors."""
    return tokenizer(text, truncation=True, padding='max_length',
                     max_length=MAX_LENGTH, return_tensors='pt')


def _select_top_topics(labels, probs, threshold=PROB_MASS_THRESHOLD):
    """Return the most probable topics whose cumulative probability reaches ``threshold``.

    Builds a DataFrame with columns ``topic``, ``prob`` and ``cum_prob``
    sorted by descending probability, and keeps the shortest prefix whose
    cumulative sum is at least ``threshold``. If the cumulative sum never
    reaches the threshold (e.g. the probabilities contain NaN), every row is
    returned instead of raising ``IndexError``.
    """
    df = pd.DataFrame({'topic': labels, 'prob': probs})
    df = df.sort_values('prob', ascending=False).reset_index(drop=True)
    df['cum_prob'] = df['prob'].cumsum()
    reached = df.index[df['cum_prob'] >= threshold]
    if len(reached) == 0:
        return df
    return df.iloc[:reached[0] + 1]


# === Interface ===
st.title("Article Topic Classifier")
st.markdown("Введите **название** статьи и (опционально) **аннотацию**")
title = st.text_input("Title (обязательно)", placeholder="Quantum Entanglement in Neural Networks")
abstract = st.text_area("Abstract (опционально)", placeholder="This paper explores...")

if st.button("Классифицировать"):
    if not title.strip():
        # The title is mandatory; tell the user instead of silently ignoring
        # the click.
        st.warning("Введите название статьи.")
    else:
        with st.spinner("Анализируем..."):
            # Tokenization: title and abstract are encoded separately because
            # the model takes them as two distinct inputs.
            title_enc = _encode(title)
            abstract_enc = _encode(abstract or "")

            # Forward pass without gradient tracking (inference only).
            with torch.no_grad():
                outputs = model(
                    title_input_ids=title_enc['input_ids'],
                    title_attention_mask=title_enc['attention_mask'],
                    abstract_input_ids=abstract_enc['input_ids'],
                    abstract_attention_mask=abstract_enc['attention_mask']
                )
                logits = outputs['logits']
                # Single input, so take row 0 of the softmax output.
                probs = torch.softmax(logits, dim=1).numpy()[0]

            # Keep the minimal number of top topics covering >= 95% of the mass.
            df_filtered = _select_top_topics(topic_labels, probs)

        # Display
        st.subheader("Результаты (Softmax):")
        for _, row in df_filtered.iterrows():
            st.write(f"**{row['topic']}** — вероятность: `{row['prob']:.3f}`")