|
import streamlit as st |
|
from transformers import DistilBertModel, DistilBertTokenizer |
|
import torch |
|
import matplotlib.image as mpimg |
|
|
|
# Token sequence length used when encoding text for the model.
MAX_LEN = 512

# Inference device.
device = torch.device('cpu')

# Locations of the serialized model and of the tokenizer vocabulary.
model_path = './models/pytorch_distilbert.bin'
vocab_path = './models/vocab_distilbert.bin'

# Image rendered on the page with st.image.
img = mpimg.imread('./460.jpeg')
|
|
|
# Class index -> human-readable subject name. The position of each name in
# the tuple is its class index.
labels_description = dict(enumerate((
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
)))
|
|
|
|
|
class DistillBERTClass(torch.nn.Module):
    """DistilBERT encoder followed by a small classification head.

    The head is Linear(768, 768) -> ReLU -> Dropout(0.3) -> Linear(768, 8),
    so the module returns 8 raw (un-normalized) scores per input row.
    """

    def __init__(self):
        super().__init__()
        # Attribute names define the module's parameter names, so they are
        # kept exactly as they were.
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask):
        """Return the classification scores for the encoded input.

        Args:
            input_ids: token ids, forwarded to the encoder unchanged.
            attention_mask: attention mask, forwarded to the encoder unchanged.

        Returns:
            The output of the final linear layer (8 values per row).
        """
        encoder_out = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        # Index 0 along the second axis of the encoder's first output
        # (presumably the [CLS] position -- confirm against the tokenizer).
        first_token = encoder_out[0][:, 0]
        # ReLU is applied functionally; no extra sub-module is created.
        features = torch.relu(self.pre_classifier(first_token))
        return self.classifier(self.dropout(features))
|
|
|
|
|
def _select_labels(probabilities, threshold, human_readable):
    """Pick the most probable classes until their total reaches *threshold*.

    Classes are visited in descending (probability, index) order. Iteration
    is bounded by the number of classes, so rounding in the running sum can
    never cause an out-of-range lookup.
    """
    ranked = sorted(((p, i) for i, p in enumerate(probabilities)), reverse=True)
    picked = []
    cumulative = 0.0
    for prob, idx in ranked:
        # Written as "not <" so that a NaN running sum also stops the loop.
        if not (cumulative < threshold):
            break
        cumulative += prob
        if human_readable:
            picked.append(labels_description[idx] + " {:.2f}%".format(100 * prob))
        else:
            picked.append(idx)
    return picked


def predict(text, model, human_readable=True, threshold=0.95):
    """Classify *text* and return the classes covering *threshold* of the probability.

    Args:
        text: raw input string; runs of whitespace are collapsed to single
            spaces before tokenization.
        model: callable module invoked as ``model(input_ids, attention_mask)``;
            its output is soft-maxed over the last axis. It is switched to
            eval mode here.
        human_readable: when true, each entry is ``"<label> <percent>%"``
            built from ``labels_description``; otherwise each entry is the
            integer class index.
        threshold: cumulative probability at which selection stops
            (defaults to 0.95).

    Returns:
        A list of labels or class indices, most probable first.
    """
    model.eval()
    text = " ".join(text.split())
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        # Replaces the deprecated ``pad_to_max_length=True``.
        padding='max_length',
        return_token_type_ids=True,
        truncation=True,
    )
    # Padding + truncation yield exactly MAX_LEN tokens; add the batch axis.
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long).reshape(1, MAX_LEN)
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).reshape(1, MAX_LEN)
    with torch.no_grad():
        probabilities = torch.softmax(model(ids, mask), dim=-1)[0].tolist()
    return _select_labels(probabilities, threshold, human_readable)
|
|
|
|
|
# Tokenizer built from the vocabulary stored at vocab_path.
tokenizer = DistilBertTokenizer.from_pretrained(vocab_path)

# `device` is already a torch.device, so it is passed as map_location as-is.
# NOTE(review): the loaded object is later called and moved with .to(), so the
# file presumably holds a whole pickled module rather than a state dict -- confirm.
model = torch.load(model_path, map_location=device)
|
|
|
# Page header, image and usage instructions.
st.markdown("### Hi! This is a service for determining the subject of an article.")
st.image(img)
st.markdown("### Just write the title and content in the areas below and click the \"Analyze\" button.")

# Two free-text inputs; they are joined with a newline before classification.
text1 = st.text_area("Title")
text2 = st.text_area("Summary")

# The button label uses the same spelling as the instructions ("Analyze").
if st.button('Analyze'):
    with st.spinner("Wait..."):
        # Whitespace-only input carries no content, so it is rejected too.
        if text1.strip() or text2.strip():
            pred = predict(text1 + "\n" + text2, model.to(device))
            # One predicted subject per paragraph.
            st.success("\n\n".join(pred))
        else:
            st.error("You haven't written anything.")