# Source: uploaded by hellohugg ("Upload 5 files", commit 907d348, verified).
import streamlit as st
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import Trainer, TrainingArguments
import torch
import json
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
@st.cache_resource
def load_model(model_path, tokenizer_name="distilbert-base-cased"):
    """Load the fine-tuned classifier and its tokenizer once per server.

    ``st.cache_resource`` keeps the returned objects alive across Streamlit
    reruns, so the slow model load only happens once per distinct set of
    arguments.

    Args:
        model_path: Directory of the fine-tuned
            ``DistilBertForSequenceClassification`` checkpoint.
        tokenizer_name: Hub id or local path of the tokenizer. Defaults to
            the base ``distilbert-base-cased`` tokenizer, which is what this
            function always used (presumably because the checkpoint
            directory does not ship tokenizer files -- TODO confirm).

    Returns:
        ``(model, tokenizer)`` with the model placed on the CPU.
    """
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    model.to("cpu")
    # from_pretrained already returns the model in eval mode; calling eval()
    # makes the inference-only intent explicit (dropout disabled).
    model.eval()
    # Loaded exactly once (the previous version loaded it twice).
    tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer
def predict(text, model, tokenizer, threshold=0.5, *, top_p=0.95,
            classes_path="classes.json"):
    """Return the most probable tags for *text*, covering ``top_p`` of the mass.

    Classes are ranked by softmax probability and taken greedily, most likely
    first, until their cumulative probability exceeds ``top_p``. For
    ``top_p >= 0`` and a non-empty class list, at least one tag is always
    returned.

    Args:
        text: Raw input string to classify.
        model: Sequence-classification model. It is called as
            ``model(**inputs)`` and must return an object with a ``logits``
            tensor. Only the first row (batch item) of the logits is used.
        tokenizer: Tokenizer callable returning a mapping of tensors.
        threshold: Unused. Kept only so existing callers that pass it keep
            working; the cut-off is controlled by ``top_p``.
        top_p: Cumulative-probability mass at which selection stops.
        classes_path: JSON file holding the list of class names, indexed by
            logit position.

    Returns:
        List of class names, most probable first.
    """
    with open(classes_path, "r", encoding="utf-8") as f:
        classes = json.load(f)

    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       padding=True, max_length=512)
    # The model lives on the CPU, so keep the input tensors there too.
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax (not sigmoid): the classes share a single probability mass,
    # which is what the cumulative top-p cut-off below relies on. dim is
    # given explicitly; implicit-dim softmax is deprecated in PyTorch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()
    order = np.argsort(probs)[::-1]

    tags = []
    cumulative = 0.0
    # Bounded by the number of classes. Float rounding can leave the total
    # mass at or below top_p, which previously made the unbounded while-loop
    # index past the last class.
    for class_idx in order:
        if not cumulative <= top_p:
            break
        tags.append(classes[class_idx])
        cumulative += float(probs[class_idx])
    return tags
# Model and tokenizer come from the st.cache_resource-decorated loader, so
# Streamlit reruns reuse the same objects instead of reloading them.
model, tokenizer = load_model("./results/checkpoint-200")

st.title("Multilabel article classification")
st.header("Based on title and summary")

# Each widget returns its current value (also stored in st.session_state
# under the given key).
title = st.text_input("Input title", key="title")
summary = st.text_input("Input summary", key="summary")

# Classify only once the user has entered at least one field.
if title or summary:
    # Presumably mirrors the text format used during fine-tuning -- keep the
    # exact layout unless the model is retrained (TODO confirm).
    query = f"TITLE:{title}, SUMMARY:{summary}"
    tags = predict(query, model, tokenizer)
    st.text("Predicted Tags:")
    st.text(", ".join(tags))