SHAD / app.py
3v324v23's picture
fix
60e27ac
raw
history blame contribute delete
No virus
2.92 kB
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import sys
import platform
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig, AutoModel, PreTrainedModel
from collections import OrderedDict
class BertClassifier(nn.Module):
    """Sequence classifier: an encoder followed by a linear log-softmax head.

    The head maps the encoder's hidden state at sequence position 0 to
    ``num_classes`` log-probabilities.

    Args:
        bert_model: Callable encoder. It is invoked with ``input_ids`` and
            ``attention_mask`` keyword arguments; element ``[0]`` of its
            result is sliced as ``[:, 0, :]``.
        num_classes: Number of output classes of the head.
        hidden_size: Input width of the head. When ``None`` it is read from
            ``bert_model.config.hidden_size`` and falls back to 768 when the
            encoder exposes no such attribute.
    """

    def __init__(self, bert_model, num_classes=8, hidden_size=None):
        super().__init__()
        self.bert = bert_model

        if hidden_size is None:
            # Prefer the width advertised by the encoder itself so the head
            # is not tied to one specific model size.
            config = getattr(bert_model, "config", None)
            hidden_size = getattr(config, "hidden_size", None) or 768

        # The submodule names become the head's state-dict key prefixes
        # ("hid2out.weight", "hid2out.bias"), so they must stay stable for
        # previously saved head weights to load.
        self.head = nn.Sequential(OrderedDict([
            ("hid2out", nn.Linear(hidden_size, num_classes)),
            ("log_softmax", nn.LogSoftmax(dim=-1)),
        ]))

    def forward(self, input_ids, attention_mask):
        """Return per-class log-probabilities for each input sequence.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Attention mask, forwarded to the encoder unchanged.

        Returns:
            The head's output for the encoder state at position 0 of every
            sequence (log-softmax over the last dimension).
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Only the first position of each sequence feeds the classifier.
        return self.head(bert_output[:, 0, :])
def _cache_resource(func):
    """Cache ``func`` with the Streamlit API meant for unhashable resources.

    Uses ``st.cache_resource`` when the installed Streamlit provides it.
    Otherwise falls back to the legacy ``st.cache`` with
    ``allow_output_mutation=True``, so Streamlit does not try to hash the
    returned model on every call to detect mutation.
    """
    decorator = getattr(st, "cache_resource", None)
    if decorator is None:
        decorator = st.cache(allow_output_mutation=True)
    return decorator(func)


@_cache_resource
def loading_tokenizer_and_model():
    """Load the tokenizer and the classifier once and cache them.

    The tokenizer comes from ``'distilbert-base-uncased'`` and the encoder
    from ``"semen15362/shad_bert_v1"``. The classification head weights are
    read from the local file ``'model_head.txt'`` onto the CPU.

    Returns:
        A ``(bert_tokenizer, model)`` tuple; ``model`` is a
        ``BertClassifier`` already switched to evaluation mode.
    """
    bert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    bert_model = DistilBertModel.from_pretrained("semen15362/shad_bert_v1")
    model = BertClassifier(bert_model)

    # NOTE(review): torch.load unpickles the file, so 'model_head.txt' must
    # come from a trusted source. Consider weights_only=True if the minimum
    # supported torch version allows it.
    checkpoint = torch.load('model_head.txt', map_location=torch.device('cpu'))
    model.head.load_state_dict(checkpoint)

    # Switch to inference mode here so the cached object is not mutated by
    # callers after it has been stored in the cache.
    model.eval()
    return bert_tokenizer, model
# Labels paired positionally with the model's output probabilities.
# NOTE(review): presumably this is the class order used at training time;
# verify against the training code before reordering.
_CATEGORY_LABELS = (
    'Statistics',
    'Mathematics',
    'Computer Science',
    'Electrical Engineering and Systems Science',
    'Quantitative Finance',
    'Economics',
    'Quantitative Biology',
    'Physics',
)


def _take_until_mass(ranked, threshold):
    """Return the shortest prefix of ``ranked`` whose scores reach ``threshold``.

    ``ranked`` is a sequence of ``(label, score)`` pairs. Pairs are taken in
    order until the running sum of scores is no longer below ``threshold``;
    the whole sequence is returned if the sum never gets there.
    """
    selected = []
    total = 0.0
    for pair in ranked:
        # "not <" rather than ">=" so a NaN running sum also ends the loop.
        if not total < threshold:
            break
        total += pair[1]
        selected.append(pair)
    return selected


def classify_article(title: str, abstract: str = None, *, threshold: float = 0.95):
    """Classify an article and return its most probable categories.

    The title and abstract are joined into a single
    ``"TITLE: ... ABSTRACT: ..."`` prompt, tokenized, and passed through the
    cached classifier. Categories are sorted by descending probability and
    the smallest leading group whose cumulative probability reaches
    ``threshold`` is returned.

    Args:
        title: Article title.
        abstract: Article abstract; ``None`` is treated as an empty string.
        threshold: Cumulative probability mass to cover, in ``(0, 1]``.

    Returns:
        A list of ``(category, probability)`` tuples, most probable first.

    Raises:
        ValueError: If ``threshold`` is outside ``(0, 1]``.
    """
    if not 0.0 < threshold <= 1.0:
        raise ValueError(f"threshold must be in (0, 1], got {threshold!r}")

    bert_tokenizer, model = loading_tokenizer_and_model()

    if abstract is None:
        abstract = ''

    texts = bert_tokenizer(
        [f"TITLE: {title} ABSTRACT: {abstract}"],
        padding=True,
        truncation=True,
        return_tensors='pt'
    )

    model.eval()
    with torch.no_grad():
        log_probs = model(texts.input_ids, texts.attention_mask)
        # The classifier emits log-probabilities; exponentiate to get
        # probabilities.
        probs = torch.exp(log_probs)

    # A single text was submitted, so only row 0 of the batch is read.
    ranked = sorted(
        zip(_CATEGORY_LABELS, probs[0].numpy()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return _take_until_mass(ranked, threshold)
# Page layout: heading, the two text inputs, then the trigger button.
st.title("Article Classifier")
title = st.text_input("Enter the article title:")
abstract = st.text_area("Enter the article abstract (optional):")

classify_pressed = st.button("Classify")

if not classify_pressed:
    # Nothing requested on this run: show the usage hint.
    st.info("Enter the article title and abstract, then press the 'Classify' button.")
elif not title:
    # The button was pressed but the title field is empty.
    st.warning("Please enter an article title.")
else:
    results = classify_article(title, abstract)
    st.header("Classification Results:")
    # One line per returned category, probability shown with two decimals.
    for topic, probability in results:
        st.write(f"{topic}: {probability:.2f}")