import streamlit as st
import torch
from transformers import DistilBertModel, DistilBertTokenizer

model_path = './models/pytorch_distilbert.bin'
vocab_path = './models/vocab_distilbert.bin'
device = torch.device('cpu')
MAX_LEN = 512

labels_description = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}


class DistillBERTClass(torch.nn.Module):
    """DistilBERT backbone with a small classification head over 8 arXiv topics."""

    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        # Use the [CLS] token representation as the pooled sequence embedding.
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output


def predict(text, model, human_readable=True):
    """Return the most likely topics whose cumulative probability reaches 95%."""
    model.eval()
    text = " ".join(text.split())
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
    )
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
    ids = torch.reshape(ids, (1, MAX_LEN))
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long)
    mask = torch.reshape(mask, (1, MAX_LEN))
    with torch.no_grad():
        outputs = torch.softmax(model(ids, mask), dim=-1)[0].tolist()
    # Sort classes by predicted probability, highest first.
    result = sorted(((v, i) for i, v in enumerate(outputs)), reverse=True)
    pr = 0.0
    index = 0
    answer = []
    # Collect labels until their cumulative probability exceeds 0.95.
    while pr < 0.95:
        pr += result[index][0]
        if not human_readable:
            answer.append(result[index][1])
        else:
            answer.append(labels_description[result[index][1]]
                          + " - {:.2f}%".format(100 * result[index][0]))
        index += 1
    return answer


@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model_and_tokenizer():
    # Load the pickled model and the saved tokenizer only once per session.
    return (torch.load(model_path, map_location=device),
            DistilBertTokenizer.from_pretrained(vocab_path))


model, tokenizer = load_model_and_tokenizer()

st.markdown("### Hi! This is a service for determining the subject of an article.")
st.markdown("It can predict the following topics:\n"
            "* Computer Science\n"
            "* Economics\n"
            "* Electrical Engineering and Systems Science\n"
            "* Mathematics\n"
            "* Physics\n"
            "* Quantitative Biology\n"
            "* Quantitative Finance\n"
            "* Statistics\n")
st.markdown("#### Just write the title and abstract in the areas below and click the \"Analyze\" button.")

title = st.text_area("Title")
abstract = st.text_area("Abstract")

if st.button('Analyze'):
    with st.spinner("Wait..."):
        if not title and not abstract:
            st.error("You haven't written anything.")
        elif not title:
            st.error("You haven't written a title.")
        else:
            pred = predict(title + "\n" + abstract, model.to(device))
            st.success("\n\n".join(pred))
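
# Usage note: this script assumes ./models/pytorch_distilbert.bin is a full
# DistillBERTClass instance saved with torch.save(model, ...) (the class above
# must be importable for unpickling), and that ./models/vocab_distilbert.bin is
# a saved vocabulary/tokenizer that DistilBertTokenizer.from_pretrained can load.
# The app is started with Streamlit's CLI, for example:
#   streamlit run app.py   # "app.py" is a placeholder for this file's actual name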