# Clf / app.py
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizer
import torch
import matplotlib.image as mpimg
# Header image shown at the top of the page
img = mpimg.imread('./460.jpeg')

# Paths to the fine-tuned classifier weights and the saved tokenizer vocabulary
model_path = './models/pytorch_distilbert.bin'
vocab_path = './models/vocab_distilbert.bin'

device = torch.device('cpu')
MAX_LEN = 512
labels_description = {0: 'Computer Science',
                      1: 'Economics',
                      2: 'Electrical Engineering and Systems Science',
                      3: 'Mathematics',
                      4: 'Physics',
                      5: 'Quantitative Biology',
                      6: 'Quantitative Finance',
                      7: 'Statistics'}
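
# Note: the classifier head defined below outputs one logit per label, so its
# output dimension (8) must match the number of entries in labels_description.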

class DistillBERTClass(torch.nn.Module):
    """DistilBERT encoder with a linear classification head over the [CLS] token."""

    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        # Take the hidden state of the first ([CLS]) token as the pooled representation
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output
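
# Assumption: './models/pytorch_distilbert.bin' holds the full pickled module
# (i.e. it was saved in separate training code with torch.save(model_obj, path)),
# which is why torch.load() further down returns a ready DistillBERTClass instance
# rather than a state_dict.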

def predict(text, model, human_readable=True):
    """Return the most likely subject labels until their cumulative probability exceeds 0.95."""
    model.eval()
    text = " ".join(text.split())
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_token_type_ids=True,
        truncation=True
    )
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
    ids = torch.reshape(ids, (1, MAX_LEN))
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long)
    mask = torch.reshape(mask, (1, MAX_LEN))

    with torch.no_grad():
        outputs = torch.softmax(model(ids, mask), dim=-1)[0].tolist()

    # Sort labels by probability (descending), then keep adding labels until the
    # accumulated probability mass reaches 0.95
    result = []
    for i, v in enumerate(outputs):
        result.append((v, i))
    result.sort(reverse=True)

    pr = 0.0
    index = 0
    answer = []
    while pr < 0.95:
        pr += result[index][0]
        if not human_readable:
            answer.append(result[index][1])
        else:
            answer.append(labels_description[result[index][1]] + " {:.2f}%".format(100 * result[index][0]))
        index += 1
    return answer
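
# Illustrative usage (hypothetical input; the exact scores depend on the trained checkpoint):
#   predict("Attention Is All You Need\nWe propose a new network architecture ...", model)
# returns labels ordered from most to least probable, each annotated with its
# probability, stopping once 0.95 cumulative probability is covered.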

# Load the saved tokenizer vocabulary and the fine-tuned model (full pickled module)
tokenizer = DistilBertTokenizer.from_pretrained(vocab_path)
model = torch.load(model_path, map_location=device)

# Streamlit UI
st.markdown("### Hi! This is a service for determining the subject of an article.")
st.image(img)
st.markdown("### Just write the title and content in the areas below and click the \"Analyze\" button.")
text1 = st.text_area("Title")
text2 = st.text_area("Summary")
if st.button('Analyze'):
    with st.spinner("Wait..."):
        if text1 or text2:
            pred = predict(text1 + "\n" + text2, model.to(device))
            st.success("\n\n".join(pred))
        else:
            st.error("You haven't written anything.")
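
# To run the app locally (assuming Streamlit and the model files are available):
#   streamlit run app.py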