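# HW-abstracts / app.py
# Author: dashasabirova — commit 348a64a ("Update app.py"), 1.79 kB.
# (Header copied from the repository web viewer; "raw / history / blame" and
# "No virus" were viewer UI labels. Kept as comments so the module parses.)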
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
from scipy.special import softmax
# Display names for the classifier's 8 output classes, keyed 1..8.
# The prediction code maps output position i to key i + 1.
labels_articles = dict(
    enumerate(
        [
            'Computer Science',
            'Economics',
            "Electrical Engineering And Systems Science",
            "Mathematics",
            "Physics",
            "Quantitative Biology",
            "Quantitative Finance",
            "Statistics",
        ],
        start=1,
    )
)
# Prefer the dedicated resource cache (st.cache is deprecated); fall back to the
# legacy st.cache on older Streamlit releases that do not provide cache_resource.
_cache_models = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_models
def models():
    """Load the encoder, the classifier head and the tokenizer used by the app.

    Decorated with a Streamlit cache so the models are not reloaded on every
    script rerun.

    Returns:
        tuple: ``(model_first, model_second, tokenizer)``:
            * ``model_first`` -- the pretrained ``bert-base-uncased`` encoder;
            * ``model_second`` -- the ``Net`` MLP head with weights read from
              ``model.txt``, in eval mode;
            * ``tokenizer`` -- the tokenizer matching ``model_first``.
    """

    class Net(nn.Module):
        """MLP head mapping a 768-dim input vector to 8 class logits."""

        def __init__(self):
            super().__init__()
            # Attribute name and layer order must match the keys stored in
            # model.txt ("layer.0.weight", ...), so they are kept unchanged.
            self.layer = nn.Sequential(
                nn.Linear(768, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, 8),
            )

        def forward(self, x):
            return self.layer(x)

    model_second = Net()
    # map_location='cpu' lets weights saved from GPU tensors load on CPU-only hosts.
    model_second.load_state_dict(torch.load('model.txt', map_location='cpu'))
    model_second.eval()

    model_name = 'bert-base-uncased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_first = AutoModel.from_pretrained(model_name)
    # Explicit eval mode: the encoder is used for inference only.
    model_first.eval()
    return (model_first, model_second, tokenizer)
def _labels_until_mass(probs, labels, threshold=0.95):
    """Return label names from most to least likely until their mass reaches *threshold*.

    Args:
        probs: 1-D array of class probabilities; position ``i`` corresponds to
            ``labels[i + 1]``.
        labels: mapping from 1-based class id to display name.
        threshold: cumulative-probability cutoff; the label whose probability
            makes the running sum reach it is still included.

    Returns:
        list: the selected label names, most likely first.
    """
    selected = []
    cumulative = 0.0
    for idx in np.argsort(probs)[::-1]:
        selected.append(labels[idx + 1])
        cumulative += probs[idx]
        if cumulative >= threshold:
            break
    return selected


model_first, model_second, tokenizer = models()

title = st.text_area("Write the title of your article, please")
abstract = st.text_area("Write the abstract")

# Streamlit reruns this script on every interaction; with both text areas empty
# the model input would just be ". ", so skip inference until something is typed.
if title.strip() or abstract.strip():
    text = title + '. ' + abstract
    tokens_info = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        out_first = model_first(**tokens_info).pooler_output
        logits = model_second(out_first).numpy()
    # Normalize along the class axis (per row), then take the single row.
    probs = softmax(logits, axis=-1)[0]
    for label in _labels_until_mass(probs, labels_articles):
        st.write(label)