|
import streamlit as st |
|
import torch |
|
from torch import nn |
|
import csv |
|
from transformers import AutoModel, AutoTokenizer |
|
from huggingface_hub import hf_hub_download |
|
from model import ClassificationModel |
|
|
|
# Configure the browser tab title and the page layout for the app.
st.set_page_config(page_title="Article Theme Classifier", layout="centered")

# Use CUDA when PyTorch reports an available GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sequence length (in tokens) that inputs are padded/truncated to when encoded.
MAX_LENGTH = 512
|
|
|
@st.cache_resource
def get_model():
    """Build the classifier, load its fine-tuned weights and ready it for inference.

    The ``distilbert-base-cased`` encoder is wrapped in ``ClassificationModel``,
    the weights file ``pytorch_model.bin`` is downloaded from the
    ``MostoHF/TunedDistillBertCased`` Hugging Face repository and loaded into
    the wrapper, and the result is moved to ``device`` and put in eval mode.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        The ``ClassificationModel`` instance, on ``device``, in eval mode.
    """
    base_model = AutoModel.from_pretrained("distilbert-base-cased")
    class_model = ClassificationModel(base_model)

    weights_path = hf_hub_download(
        repo_id="MostoHF/TunedDistillBertCased",
        filename="pytorch_model.bin",
    )

    # SECURITY: torch.load unpickles the file, and the file comes from a remote
    # repository. weights_only=True restricts unpickling to tensors and plain
    # containers so a tampered checkpoint cannot execute arbitrary code.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    class_model.load_state_dict(state_dict)
    class_model.to(device)
    class_model.eval()  # switch to evaluation mode before serving predictions

    return class_model
|
|
|
@st.cache_resource
def get_tokenizer():
    """Return the tokenizer loaded from the ``distilbert-base-cased`` checkpoint.

    The function is wrapped in ``st.cache_resource``.
    """
    checkpoint_name = "distilbert-base-cased"
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
    return loaded_tokenizer
|
|
|
@st.cache_resource
def get_ind_to_cat():
    """Load the class-index -> category-name mapping from ``ind_to_category.csv``.

    The first row of the file is treated as a header and skipped. Every
    following non-empty row must hold exactly two columns: an integer index
    and the category name.

    The function is wrapped in ``st.cache_resource``.

    Returns:
        dict[int, str]: category name keyed by integer class index. Empty if
        the file contains no data rows.

    Raises:
        OSError: if the CSV file cannot be opened.
        ValueError: if a row does not have exactly two columns or its first
            column is not an integer.
    """
    ind_to_category_copy = {}
    # Explicit encoding keeps decoding independent of the platform locale.
    with open('ind_to_category.csv', mode='r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        # Skip the header; the default avoids StopIteration on an empty file.
        next(reader, None)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; ignore them.
                continue
            key, value = row
            ind_to_category_copy[int(key)] = value
    return ind_to_category_copy
|
|
|
# Instantiate the shared resources used by inference(). Each loader is wrapped
# in st.cache_resource; they are evaluated left to right.
class_model, tokenizer, ind_to_category = get_model(), get_tokenizer(), get_ind_to_cat()
|
|
|
def inference(title, abstract, threshold=0.95):
    """Predict the most probable themes for an article.

    The title and abstract are joined with ``'@'``, tokenized (padded and
    truncated to ``MAX_LENGTH``) and passed through ``class_model``. The model
    output is exponentiated and treated as per-class probabilities
    (presumably the model emits log-probabilities; confirm against
    ``ClassificationModel``). Classes are then taken in descending probability
    order until their running sum reaches ``threshold``; at least one class is
    returned whenever the model produces any.

    Args:
        title: Article title text.
        abstract: Article abstract text.
        threshold: Cumulative probability at which selection stops.

    Returns:
        tuple[list, list[float]]: the selected category names (looked up in
        ``ind_to_category``) and their probabilities, both ordered from most
        to least probable.
    """
    text = title + '@' + abstract

    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        model_output = class_model(ids, mask)
        res_probs = torch.exp(model_output)

    # Drop the leading batch dimension and rank classes from most to least likely.
    ranked_probs, ranked_indices = torch.sort(res_probs.squeeze(0), descending=True)

    selected_indices = []
    selected_probs = []
    running_total = 0.0
    # Convert to plain Python numbers once instead of per element.
    for class_prob, class_idx in zip(ranked_probs.tolist(), ranked_indices.tolist()):
        running_total += class_prob
        selected_indices.append(class_idx)
        selected_probs.append(class_prob)
        if running_total >= threshold:
            break

    ans_themes = [ind_to_category[class_idx] for class_idx in selected_indices]
    return ans_themes, selected_probs
|
|
|
|
|
|
|
|
|
st.title("📄 Article Theme Classifier")

# The hint text is passed as `placeholder`, not `value`: as a value it would be
# real field content, so untouched fields would never be empty and the hint
# itself would be sent to the classifier.
title = st.text_input("Title", placeholder="Введите title...")
abstract = st.text_input("Abstract", placeholder="Введите abstract...")
threshold = st.slider("Выберите cumulative probability threshold", 0.0, 1.0, step=0.01, value=0.95)

if st.button("Submit"):
    # Whitespace-only input counts as empty.
    if title.strip() or abstract.strip():
        st.success("✅ Title")
        st.info("📑 Abstract")
        themes, probs = inference(title, abstract, threshold)
        st.subheader("Predicted Themes:")
        # One line per selected theme with its probability, most probable first.
        for theme, prob in zip(themes, probs):
            st.write(f"**{theme}** — {prob:.4f}")
    else:
        st.warning("❌ Please fill in at least one of the fields.")
|
|