HW-abstracts / app.py
dashasabirova's picture
Update app.py
d93b6b1
raw
history blame
1.73 kB
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
from scipy.special import softmax
# Human-readable subject names keyed by 1-based class id (1..8).
# The prediction loop looks a name up as ``labels_articles[i + 1]``, where
# ``i`` is the 0-based index of a classifier output, so the order of the
# names below must match the order of the classifier's outputs.
labels_articles = dict(enumerate(
    (
        'Computer Science',
        'Economics',
        "Electrical Engineering And Systems Science",
        "Mathematics",
        "Physics",
        "Quantitative Biology",
        "Quantitative Finance",
        "Statistics",
    ),
    start=1,
))
def models():
    """Load the encoder, the classification head and the tokenizer.

    Returns:
        A ``(model_first, model_second, tokenizer)`` tuple where

        * ``model_first`` is the pretrained ``bert-base-uncased`` model
          obtained through ``AutoModel.from_pretrained``;
        * ``model_second`` is a small feed-forward head
          (768 -> 256 -> 128 -> 8) whose weights are restored from the local
          file ``model.txt`` and which is switched to eval mode;
        * ``tokenizer`` is the ``AutoTokenizer`` for the same model name.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` / ``from_pretrained``
        raise when the checkpoint or the pretrained files are unavailable or
        do not match the architecture below.
    """
    # NOTE(review): the script calls this at module top level, and Streamlit
    # re-executes the script on every widget interaction, so both models are
    # reloaded each time. Consider wrapping this function in Streamlit's
    # resource cache if the installed Streamlit version provides one.

    class Net(nn.Module):
        """Feed-forward head mapping a 768-dim vector to 8 class scores."""

        def __init__(self):
            super().__init__()
            # The attribute must stay named ``layer``: the parameter names
            # stored in the checkpoint are derived from it.
            self.layer = nn.Sequential(
                nn.Linear(768, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, 8),
            )

        def forward(self, x):
            return self.layer(x)

    model_second = Net()
    # map_location='cpu' lets a checkpoint saved from a GPU be restored on a
    # CPU-only host; the rest of the app works on CPU tensors anyway.
    # NOTE(review): torch.load unpickles the file -- only ship checkpoints
    # that come from a trusted source.
    state_dict = torch.load('model.txt', map_location='cpu')
    model_second.load_state_dict(state_dict)
    model_second.eval()

    model_name = 'bert-base-uncased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_first = AutoModel.from_pretrained(model_name)
    return (model_first, model_second, tokenizer)
# --- Streamlit page: read a title + abstract and list the likely subjects ---
model_first, model_second, tokenizer = models()

title = st.text_area("Write the title of your article, please")
abstract = st.text_area("Write the abstract")

# Classify only once the user has typed something; otherwise the page would
# show categories predicted for the bare separator string '. '.
if title.strip() or abstract.strip():
    # Same concatenation as before, so the text fed to the model is unchanged.
    text = title + '. ' + abstract
    tokens_info = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Pure inference: no_grad avoids building an autograd graph, and makes the
    # result directly convertible to NumPy without an explicit detach().
    with torch.no_grad():
        out_first = model_first(**tokens_info).pooler_output
        out_second = model_second(out_first).numpy()

    # Normalise over the class axis explicitly; without ``axis`` SciPy's
    # softmax normalises over the whole array, which is only equivalent while
    # there is exactly one row.
    out_second = softmax(out_second, axis=-1)

    # Class indices of the single input row, most probable first.
    indices = np.argsort(out_second)[0][::-1]

    # Print labels in descending probability until they jointly cover at
    # least 95% of the probability mass.
    sum_prob = 0
    for i in indices:
        st.write(labels_articles[i + 1])  # labels are keyed from 1, indices from 0
        sum_prob += out_second[0][i]
        if sum_prob >= 0.95:
            break