"""Model definition, cached loading and inference for the article classifier."""

import platform
import sys
from collections import OrderedDict
from typing import Optional

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import (
    AutoModel,
    DistilBertConfig,
    DistilBertModel,
    DistilBertTokenizer,
    PreTrainedModel,
)

# Labels reported to the user, zipped position-by-position with the model's
# output probabilities.
# NOTE(review): presumably this order matches the class indices used when the
# head was trained -- confirm against the training code.
CATEGORIES = (
    'Statistics',
    'Mathematics',
    'Computer Science',
    'Electrical Engineering and Systems Science',
    'Quantitative Finance',
    'Economics',
    'Quantitative Biology',
    'Physics',
)

# Input width of the classification head's linear layer.
_HIDDEN_SIZE = 768


class BertClassifier(nn.Module):
    """Encoder followed by a linear + log-softmax classification head.

    Args:
        bert_model: Encoder module called with ``input_ids`` and
            ``attention_mask`` keyword arguments; the first element of its
            output is used as the per-token representation.
        num_classes: Number of output classes of the head.
    """

    def __init__(self, bert_model, num_classes=8):
        super().__init__()
        self.bert = bert_model
        # Layer names are part of the checkpoint format: the saved head
        # state dict is loaded into this Sequential by key.
        self.head = nn.Sequential(OrderedDict([
            ('hid2out', nn.Linear(_HIDDEN_SIZE, num_classes)),
            ('log_softmax', nn.LogSoftmax(dim=-1)),
        ]))

    def forward(self, input_ids, attention_mask):
        """Return log-probabilities over classes for each input sequence."""
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )[0]
        # Classify from the representation at sequence position 0 only.
        return self.head(bert_output[:, 0, :])


def _cache_resource(func):
    """Cache *func*'s result across Streamlit reruns without hashing it.

    ``st.cache`` is deprecated and, by default, hashes the returned object on
    every call to detect mutation, which is slow or fails for torch models.
    Prefer ``st.cache_resource`` when the installed Streamlit provides it and
    fall back to ``st.cache(allow_output_mutation=True)`` otherwise.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def loading_tokenizer_and_model():
    """Load the tokenizer and the classifier (encoder + trained head).

    The head weights are read from ``model_head.txt`` relative to the current
    working directory and mapped onto the CPU.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is already in eval mode.
    """
    bert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    bert_model = DistilBertModel.from_pretrained("semen15362/shad_bert_v1")
    model = BertClassifier(bert_model)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load('model_head.txt', map_location=torch.device('cpu'))
    model.head.load_state_dict(checkpoint)
    # Switch to inference mode once here instead of mutating the shared,
    # cached model on every request.
    model.eval()
    return bert_tokenizer, model


def _top_categories(scored, threshold=0.95):
    """Return the highest-scoring pairs whose cumulative score reaches *threshold*.

    Args:
        scored: Iterable of ``(label, probability)`` pairs.
        threshold: Cumulative probability at which to stop adding pairs.

    Returns:
        Pairs sorted by descending probability, cut right after the pair that
        brings the running total to at least *threshold* (all pairs if the
        total never reaches it).
    """
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    selected = []
    cumulative = 0.0
    for pair in ranked:
        if cumulative >= threshold:
            break
        selected.append(pair)
        cumulative += pair[1]
    return selected


def classify_article(title: str, abstract: Optional[str] = None,
                     threshold: float = 0.95):
    """Classify an article and return its most probable categories.

    Args:
        title: Article title.
        abstract: Article abstract; ``None`` is treated as an empty string.
        threshold: Cumulative probability mass the returned categories
            should cover.

    Returns:
        A list of ``(category, probability)`` tuples sorted by descending
        probability, with probabilities as plain Python floats.
    """
    bert_tokenizer, model = loading_tokenizer_and_model()

    if abstract is None:
        abstract = ''

    texts = bert_tokenizer(
        [f"TITLE: {title} ABSTRACT: {abstract}"],
        padding=True,
        truncation=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        log_probs = model(texts.input_ids, texts.attention_mask)

    # The head emits log-probabilities; exponentiate and convert to plain
    # floats so the cumulative sum does not depend on NumPy scalar promotion.
    probs = torch.exp(log_probs)[0].tolist()

    return _top_categories(zip(CATEGORIES, probs), threshold)
# --- Streamlit page: collect inputs and display the predicted categories ---
st.title("Article Classifier")

title = st.text_input("Enter the article title:")
abstract = st.text_area("Enter the article abstract (optional):")

if not st.button("Classify"):
    # Nothing submitted on this run: show usage instructions.
    st.info("Enter the article title and abstract, then press the 'Classify' button.")
elif not title:
    # The button was pressed, but the required title is empty.
    st.warning("Please enter an article title.")
else:
    results = classify_article(title, abstract)
    st.header("Classification Results:")
    for topic, probability in results:
        st.write(f"{topic}: {probability:.2f}")