Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import numpy as np | |
| import pandas as pd | |
| import sys | |
| import platform | |
| from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig, AutoModel, PreTrainedModel | |
| from collections import OrderedDict | |
class BertClassifier(nn.Module):
    """Encoder plus a linear log-softmax classification head.

    The head reads the encoder's hidden state at sequence position 0 (the
    first token) and returns log-probabilities over ``num_classes`` classes.

    Args:
        bert_model: encoder module, called as
            ``bert_model(input_ids=..., attention_mask=...)``. Element ``[0]``
            of its output is indexed ``[:, 0, :]``, so it is presumably
            (batch, seq_len, hidden) — confirm when swapping encoders.
        num_classes: number of output classes.
        hidden_size: width of the encoder output fed to the head. When
            ``None`` it is taken from ``bert_model.config.hidden_size`` if the
            encoder has such a config, otherwise 768.
    """

    def __init__(self, bert_model, num_classes=8, hidden_size=None):
        super().__init__()
        self.bert = bert_model
        if hidden_size is None:
            # Derive the width from the encoder instead of hard-coding it;
            # 768 remains the fallback for encoders without a config.
            hidden_size = getattr(
                getattr(bert_model, "config", None), "hidden_size", 768
            )
        # The layer names become the state-dict keys ("hid2out.weight",
        # "hid2out.bias") that saved head checkpoints are loaded against.
        head = [
            ("hid2out", nn.Linear(hidden_size, num_classes)),
            ("log_softmax", nn.LogSoftmax(dim=-1)),
        ]
        self.head = nn.Sequential(OrderedDict(head))

    def forward(self, input_ids, attention_mask):
        """Return log-probabilities, one row of ``num_classes`` per sequence."""
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Classify from the hidden state of the first token of each sequence.
        return self.head(bert_output[:, 0, :])
@st.cache_resource
def loading_tokenizer_and_model():
    """Load the tokenizer and the classifier with its trained head.

    Wrapped in ``st.cache_resource`` so the encoder weights and the head
    checkpoint are loaded once and then reused. Streamlit re-executes this
    script on every interaction, so without the cache the whole model would
    be rebuilt on each "Classify" click.

    Returns:
        ``(bert_tokenizer, model)``. The model's tensors are mapped to CPU
        and the model is switched to eval mode.
    """
    # NOTE(review): the tokenizer comes from the base checkpoint while the
    # encoder weights come from a separate repo — presumably a fine-tune that
    # kept the base vocabulary; confirm.
    bert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    bert_model = DistilBertModel.from_pretrained("semen15362/shad_bert_v1")
    model = BertClassifier(bert_model)
    # Only the classification head is stored locally. map_location forces the
    # tensors onto the CPU. The path is resolved relative to the working
    # directory.
    checkpoint = torch.load('model_head.txt', map_location=torch.device('cpu'))
    model.head.load_state_dict(checkpoint)
    model.eval()
    return bert_tokenizer, model
# Category labels in classifier output order: output index i is CATEGORY_LIST[i].
CATEGORY_LIST = (
    'Statistics',
    'Mathematics',
    'Computer Science',
    'Electrical Engineering and Systems Science',
    'Quantitative Finance',
    'Economics',
    'Quantitative Biology',
    'Physics',
)


def _covering_prefix(ranked, threshold):
    """Return the shortest prefix of ``ranked`` whose probabilities sum to >= ``threshold``.

    ``ranked`` is a sequence of (label, probability) pairs sorted by
    descending probability. If the probabilities never reach ``threshold``,
    every pair is returned. A ``threshold`` <= 0 yields an empty list.
    """
    selected = []
    total = 0.0
    for label, prob in ranked:
        if total >= threshold:
            break
        selected.append((label, prob))
        total += prob
    return selected


def classify_article(title: str, abstract: "str | None" = None, *, threshold: float = 0.95):
    """Classify an article from its title and optional abstract.

    Args:
        title: article title.
        abstract: optional abstract; ``None`` is treated as an empty string.
        threshold: cumulative-probability cutoff deciding how many of the most
            likely categories are returned (default 0.95).

    Returns:
        List of ``(category, probability)`` pairs, where each probability is a
        Python float, in descending probability order. The list holds the
        fewest top categories whose probabilities sum to at least
        ``threshold``, or all categories if the sum never reaches it.
    """
    bert_tokenizer, model = loading_tokenizer_and_model()
    if abstract is None:
        abstract = ''
    texts = bert_tokenizer(
        [f"TITLE: {title} ABSTRACT: {abstract}"],
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    model.eval()
    with torch.no_grad():
        log_probs = model(texts.input_ids, texts.attention_mask)
    # The head ends in LogSoftmax, so exp() turns its output into probabilities.
    probs = torch.exp(log_probs)[0].tolist()
    ranked = sorted(zip(CATEGORY_LIST, probs), key=lambda pair: pair[1], reverse=True)
    return _covering_prefix(ranked, threshold)
# Streamlit page: collect a title and an optional abstract, then show the
# most likely categories with their probabilities.
st.title("Article Classifier")
title = st.text_input("Enter the article title:")
abstract = st.text_area("Enter the article abstract (optional):")
if st.button("Classify"):
    # A whitespace-only title counts as missing.
    if title.strip():
        # classify_article loads the model and runs inference; show a spinner
        # while it works.
        with st.spinner("Classifying..."):
            results = classify_article(title.strip(), abstract)
        st.header("Classification Results:")
        for topic, probability in results:
            st.write(f"{topic}: {probability:.2f}")
    else:
        st.warning("Please enter an article title.")
else:
    st.info("Enter the article title and abstract, then press the 'Classify' button.")