Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import numpy as np | |
import pandas as pd | |
import sys | |
import platform | |
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig, AutoModel, PreTrainedModel | |
from collections import OrderedDict | |
class BertClassifier(nn.Module):
    """Encoder followed by a linear + log-softmax classification head.

    The head is an ``nn.Sequential`` whose sub-modules are named ``hid2out``
    and ``log_softmax``; those names become the keys of
    ``self.head.state_dict()`` (``hid2out.weight`` / ``hid2out.bias``), which
    a saved head checkpoint must match.

    Args:
        bert_model: Encoder module, called as
            ``bert_model(input_ids=..., attention_mask=...)``; element ``[0]``
            of its output is taken as the hidden states.
        num_classes: Number of output classes (default 8).
    """

    def __init__(self, bert_model, num_classes=8):
        super().__init__()
        self.bert = bert_model
        # NOTE(review): 768 is hard-coded and presumably equals the encoder's
        # hidden width -- confirm before plugging in a different encoder.
        layers = OrderedDict()
        layers['hid2out'] = nn.Linear(768, num_classes)
        layers['log_softmax'] = nn.LogSoftmax(dim=-1)
        self.head = nn.Sequential(layers)

    def forward(self, input_ids, attention_mask):
        """Return per-class log-probabilities from the first token's hidden state.

        Assumes ``self.bert(...)[0]`` is shaped (batch, seq, 768) -- TODO
        confirm; the result is then shaped (batch, num_classes).
        """
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Position 0 of the sequence (presumably the [CLS] token) summarises the text.
        first_token = hidden_states[:, 0, :]
        return self.head(first_token)
@st.cache_resource
def loading_tokenizer_and_model():
    """Load the tokenizer and the fine-tuned classifier, cached across reruns.

    Streamlit re-executes this script on every interaction, and
    ``classify_article`` calls this loader on every click. Without caching,
    the tokenizer, the hub weights and the head checkpoint would be reloaded
    from scratch each time. ``st.cache_resource`` keeps one shared instance
    instead. The cached model is shared between sessions, so callers must
    treat it as read-only (inference only).

    Returns:
        tuple: ``(bert_tokenizer, model)``. ``model`` is a ``BertClassifier``
        whose head weights come from ``model_head.txt``; it is loaded onto
        CPU and already switched to eval mode.
    """
    bert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    bert_model = DistilBertModel.from_pretrained("semen15362/shad_bert_v1")
    model = BertClassifier(bert_model)
    # map_location keeps loading CPU-only even if the head was saved from a GPU run.
    checkpoint = torch.load('model_head.txt', map_location=torch.device('cpu'))
    model.head.load_state_dict(checkpoint)
    # Inference-only: disable dropout once here rather than on every call.
    model.eval()
    return bert_tokenizer, model
# Output labels, in the order of the classifier head's 8 outputs.
# NOTE(review): this order is assumed to match the one used when the head
# was trained -- confirm against the training code.
_CATEGORIES = (
    'Statistics',
    'Mathematics',
    'Computer Science',
    'Electrical Engineering and Systems Science',
    'Quantitative Finance',
    'Economics',
    'Quantitative Biology',
    'Physics',
)


def _top_probability_mass(results, threshold=0.95):
    """Return the shortest prefix of *results* whose probabilities sum to >= *threshold*.

    *results* must already be sorted by probability, descending, as
    ``(label, probability)`` pairs. If the whole list sums to less than
    *threshold*, every pair is returned; an empty list yields ``[]``.
    """
    selected = []
    total = 0.0
    for label, prob in results:
        if total >= threshold:
            break
        selected.append((label, prob))
        total += prob
    return selected


def classify_article(title: str, abstract: "str | None" = None, *, threshold: float = 0.95):
    """Classify an article and return its most likely categories.

    Args:
        title: Article title.
        abstract: Optional abstract; ``None`` is treated as an empty string.
        threshold: Cumulative-probability cut-off (keyword-only, default
            0.95). Categories are returned in descending probability until
            their probabilities sum to at least this value.

    Returns:
        list[tuple[str, numpy scalar]]: ``(category, probability)`` pairs,
        sorted by probability in descending order.
    """
    bert_tokenizer, model = loading_tokenizer_and_model()
    if abstract is None:
        abstract = ''
    texts = bert_tokenizer(
        [f"TITLE: {title} ABSTRACT: {abstract}"],
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    model.eval()
    with torch.no_grad():
        log_probs = model(texts.input_ids, texts.attention_mask)
        # The head ends in LogSoftmax, so exp() recovers probabilities.
        probs = torch.exp(log_probs)
    results = list(zip(_CATEGORIES, probs[0].numpy()))
    results.sort(key=lambda pair: pair[1], reverse=True)
    return _top_probability_mass(results, threshold)
# Streamlit page: collect a title (required) and abstract (optional), then
# show the categories returned by classify_article with their probabilities.
st.title("Article Classifier")
title = st.text_input("Enter the article title:")
abstract = st.text_area("Enter the article abstract (optional):")

if st.button("Classify"):
    # A whitespace-only title is truthy but carries nothing to classify.
    if title.strip():
        results = classify_article(title, abstract)
        st.header("Classification Results:")
        for topic, probability in results:
            st.write(f"{topic}: {probability:.2f}")
    else:
        st.warning("Please enter an article title.")
else:
    st.info("Enter the article title and abstract, then press the 'Classify' button.")