"""Streamlit app that predicts article themes from a title and an abstract."""

import streamlit as st
import torch
from torch import nn
import csv
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from model import ClassificationModel

st.set_page_config(page_title="Article Theme Classifier", layout="centered")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LENGTH = 512


@st.cache_resource
def get_model():
    """Build the classifier, load the fine-tuned weights, and return it.

    The DistilBERT backbone is wrapped in the project's
    ``ClassificationModel``; the weights are downloaded from the Hugging Face
    Hub. The returned model is moved to ``device`` and put in eval mode.
    Cached by Streamlit so this runs once per server process.
    """
    base_model = AutoModel.from_pretrained("distilbert-base-cased")
    class_model = ClassificationModel(base_model)

    weights_path = hf_hub_download(
        repo_id="MostoHF/TunedDistillBertCased",
        filename="pytorch_model.bin",
    )
    # NOTE(review): torch.load unpickles a file fetched from the network.
    # If the installed torch version supports it, consider passing
    # weights_only=True, since only a state dict is expected here.
    state_dict = torch.load(weights_path, map_location=device)
    class_model.load_state_dict(state_dict)
    class_model.to(device)
    class_model.eval()
    return class_model


@st.cache_resource
def get_tokenizer():
    """Return the cached tokenizer matching the DistilBERT backbone."""
    return AutoTokenizer.from_pretrained("distilbert-base-cased")


@st.cache_resource
def get_ind_to_cat():
    """Load the class-index -> category-name mapping from ``ind_to_category.csv``.

    The first row is treated as a header and skipped. Every remaining
    non-blank row must have exactly two columns: an integer index and the
    category name.

    Returns:
        dict mapping ``int`` index to category name.

    Raises:
        ValueError: if a row does not have exactly two columns or the index
            is not an integer.
    """
    ind_to_category_copy = {}
    # Explicit encoding so category names decode the same on every platform.
    with open('ind_to_category.csv', mode='r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header (tolerate an empty file)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; ignore them.
                continue
            key, value = row
            ind_to_category_copy[int(key)] = value  # keys are ints
    return ind_to_category_copy


class_model = get_model()
tokenizer = get_tokenizer()
ind_to_category = get_ind_to_cat()


def inference(title, abstract, threshold=0.95):
    """Predict the most probable themes for an article.

    The title and abstract are joined with ``'@'``, tokenized (padded and
    truncated to ``MAX_LENGTH``), and passed through the model. Classes are
    then taken in order of decreasing probability until their cumulative
    probability reaches ``threshold``.

    Args:
        title: article title.
        abstract: article abstract.
        threshold: cumulative probability at which selection stops. At least
            one class is always returned when the model outputs any.

    Returns:
        ``(themes, probs)``: category names and their probabilities, both
        ordered from most to least probable.

    Raises:
        KeyError: if a predicted index is missing from ``ind_to_category``.
    """
    cur_elem = title + '@' + abstract
    encoding = tokenizer(
        cur_elem,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        # The output is exponentiated, so the model presumably returns
        # log-probabilities — confirm against ClassificationModel.
        res_probs = torch.exp(class_model(input_ids, attention_mask))

    # Drop the batch dimension (original note: resulting shape is (8,)).
    probs = res_probs.squeeze(0)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    total = 0.0
    selected_indices = []
    selected_probs = []
    # Convert to Python scalars once instead of calling .item() per element.
    for prob, idx in zip(sorted_probs.tolist(), sorted_indices.tolist()):
        total += prob
        selected_indices.append(idx)
        selected_probs.append(prob)
        if total >= threshold:
            break

    ans_themes = [ind_to_category[idx] for idx in selected_indices]
    return ans_themes, selected_probs


# ------------------- Streamlit UI -------------------

st.title("📄 Article Theme Classifier")

# The prompts are placeholders, not initial values: as values they made the
# fields permanently non-empty and were themselves fed to the model.
title = st.text_input("Title", placeholder="Введите title...")
abstract = st.text_input("Abstract", placeholder="Введите abstract...")
threshold = st.slider(
    "Выберите cumulative probability threshold",
    0.0,
    1.0,
    step=0.01,
    value=0.95,
)

if st.button("Submit"):
    # Whitespace-only input counts as empty.
    if title.strip() or abstract.strip():
        st.success("✅ Title")
        st.info("📑 Abstract")

        themes, probs = inference(title, abstract, threshold)

        st.subheader("Predicted Themes:")
        for theme, prob in zip(themes, probs):
            st.write(f"**{theme}** — {prob:.4f}")
    else:
        st.warning("❌ Please fill in at least one of the fields.")