"""Streamlit service that predicts the subject of an article from its title and summary."""

import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizer
import torch
import matplotlib.image as mpimg

# Banner image displayed at the top of the page.
img = mpimg.imread('./460.jpeg')

model_path = './models/pytorch_distilbert.bin'
vocab_path = './models/vocab_distilbert.bin'
device = torch.device('cpu')
MAX_LEN = 512  # every input is padded/truncated to exactly this many tokens

# Maps the classifier's output index to a human-readable subject name.
labels_description = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}


class DistillBERTClass(torch.nn.Module):
    """DistilBERT encoder followed by a small feed-forward classification head.

    NOTE(review): the checkpoint at ``model_path`` is loaded with
    ``torch.load`` as a whole pickled object, presumably an instance of this
    class. If so, the class name and the attribute names ``l1``,
    ``pre_classifier``, ``dropout`` and ``classifier`` must not be renamed,
    and no new attributes may be required by ``forward`` -- confirm against
    how the checkpoint was saved.
    """

    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        # One output per entry in ``labels_description``.
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class scores for a batch of token ids.

        Args:
            input_ids: token id tensor passed to the DistilBERT encoder.
            attention_mask: mask tensor passed to the DistilBERT encoder.

        Returns:
            The output of the final linear layer (softmax is NOT applied here).
        """
        encoder_output = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = encoder_output[0]
        # Use the hidden state of the first token as the sequence summary.
        pooled = hidden_state[:, 0]
        pooled = self.pre_classifier(pooled)
        # Functional ReLU avoids building a throwaway nn.ReLU module per call.
        pooled = torch.relu(pooled)
        pooled = self.dropout(pooled)
        return self.classifier(pooled)


def predict(text, model, human_readable=True, threshold=0.95):
    """Return the most probable subjects for ``text``.

    Classes are ranked by softmax probability (highest first) and collected
    until their cumulative probability reaches ``threshold``.

    Args:
        text: raw article text; runs of whitespace are collapsed to one space.
        model: callable module taking ``(input_ids, attention_mask)`` and
            returning class scores of shape ``(1, num_classes)``.
        human_readable: if true, return strings such as ``"Physics 87.12%"``;
            otherwise return the integer class indices.
        threshold: cumulative probability at which collection stops.

    Returns:
        A list of label strings or class indices, most probable first.
    """
    model.eval()
    text = " ".join(text.split())
    # ``padding="max_length"`` replaces the deprecated ``pad_to_max_length=True``.
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
    )
    # Wrapping in a list adds the batch dimension: shape (1, MAX_LEN).
    ids = torch.tensor([inputs['input_ids']], dtype=torch.long, device=device)
    mask = torch.tensor([inputs['attention_mask']], dtype=torch.long, device=device)

    with torch.no_grad():
        probabilities = torch.softmax(model(ids, mask), dim=-1)[0].tolist()

    # Sort (probability, class index) pairs so the most likely class is first.
    ranked = sorted(
        ((prob, label_id) for label_id, prob in enumerate(probabilities)),
        reverse=True,
    )

    answer = []
    cumulative = 0.0
    # Iterating over ``ranked`` bounds the loop, so it can never index past
    # the last class even if the probabilities fail to reach ``threshold``.
    for prob, label_id in ranked:
        if cumulative >= threshold:
            break
        cumulative += prob
        if human_readable:
            answer.append(f"{labels_description[label_id]} {100 * prob:.2f}%")
        else:
            answer.append(label_id)
    return answer


# Streamlit re-executes this script on every interaction; cache the heavy
# objects when the installed Streamlit provides ``st.cache_resource`` and
# fall back to plain loading on older releases.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_resources():
    """Load the tokenizer and the classification model from disk."""
    loaded_tokenizer = DistilBertTokenizer.from_pretrained(vocab_path)
    # NOTE(review): ``torch.load`` unpickles arbitrary objects -- only load
    # checkpoints from a trusted source. Newer PyTorch releases default to
    # ``weights_only=True``, which rejects whole-model pickles; confirm the
    # pinned PyTorch version or switch the checkpoint to a ``state_dict``.
    loaded_model = torch.load(model_path, map_location=device)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_resources()

st.markdown("### Hi! This is a service for determining the subject of an article.")
st.image(img)
st.markdown("### Just write the title and content in the areas below and click the \"Analyze\" button.")

text1 = st.text_area("Title")
text2 = st.text_area("Summary")

# The label matches the "Analyze" wording used in the instructions above.
if st.button('Analyze'):
    # Treat whitespace-only input the same as no input at all.
    if text1.strip() or text2.strip():
        with st.spinner("Wait..."):
            pred = predict(text1 + "\n" + text2, model.to(device))
        st.success("\n\n".join(pred))
    else:
        st.error("You haven't written anything.")