|
import streamlit as st |
|
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
|
from transformers import Trainer, TrainingArguments |
|
import torch |
|
import json |
|
from sklearn.preprocessing import MultiLabelBinarizer |
|
import numpy as np |
|
|
|
|
|
@st.cache_resource
def load_model(model_path):
    """Load the fine-tuned classification model and its tokenizer.

    Cached by Streamlit (``st.cache_resource``) so the model is loaded
    once per process, not on every script rerun.

    Args:
        model_path: Path to a saved DistilBERT sequence-classification
            checkpoint (e.g. a Trainer checkpoint directory).

    Returns:
        Tuple of (model, tokenizer); the model is moved to CPU for inference.
    """
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
    model.to("cpu")
    # Load the tokenizer once (the original code loaded it twice).
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    return model, tokenizer
|
|
|
|
|
def predict(text, model, tokenizer, threshold=0.5):
    """Predict tags for *text*, keeping the most probable classes.

    Classes are accumulated in descending probability order until their
    cumulative probability mass exceeds 0.95.

    Args:
        text: Input string (title + summary) to classify.
        model: A sequence-classification model returning ``.logits``.
        tokenizer: Matching tokenizer producing PyTorch tensors.
        threshold: Currently unused; kept for interface compatibility.
            TODO: either apply it as a per-class cutoff or remove it.

    Returns:
        List of predicted tag names (strings from classes.json).
    """
    # Class index -> tag name mapping, produced at training time.
    with open("classes.json", "r") as f:
        classes = json.load(f)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {k: v.to("cpu") for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # Explicit dim avoids PyTorch's implicit-dim softmax deprecation warning.
    # NOTE(review): softmax treats classes as mutually exclusive; for a true
    # multilabel setup a per-class sigmoid is standard — confirm against how
    # the model was trained before changing.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].numpy()

    # Class indices sorted by descending probability.
    idx = np.argsort(probs)[::-1]

    tags = []
    cumsum = 0.0
    ind = 0
    # Bounds guard: float accumulation could stay <= 0.95 after the last
    # class, which would otherwise index past the end of idx.
    while cumsum <= 0.95 and ind < len(idx):
        tags.append(classes[idx[ind]])
        cumsum += probs[idx[ind]]
        ind += 1

    return tags
|
|
|
|
|
|
|
# --- Streamlit UI: collect title/summary, show predicted tags -------------

model, tokenizer = load_model("./results/checkpoint-200")

st.title("Multilabel article classification")
st.header("Based on title and summary")

st.text_input("Input title", key="title")
st.text_input("Input summary", key="summary")

title = st.session_state["title"]
summary = st.session_state["summary"]

# Predict as soon as either field has content.
if title or summary:
    # Same prompt format the model saw during fine-tuning.
    query = f"TITLE:{title}, SUMMARY:{summary}"
    tags = predict(query, model, tokenizer)

    st.text("Predicted Tags:")
    st.text(", ".join(tags))
|
|