HW-abstracts / app.py
dashasabirova's picture
Update app.py
2334258
raw history blame
No virus
1.53 kB
from scipy.special import softmax
import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
# Display names of the article subject areas, keyed by a 1-based class id
# (ids run from 1 to 8 in the order listed below).
labels_articles = dict(enumerate(
    (
        'Computer Science',
        'Economics',
        "Electrical Engineering And Systems Science",
        "Mathematics",
        "Physics",
        "Quantitative Biology",
        "Quantitative Finance",
        "Statistics",
    ),
    start=1,
))
class Net(nn.Module):
    """Small feed-forward classification head.

    Maps a 768-feature input through two ReLU-activated hidden layers
    (256 and 128 units) to 8 raw output scores. No activation is applied
    to the final layer.
    """

    def __init__(self):
        super().__init__()
        # Keep the stages in this exact order: their positions inside the
        # Sequential container (0, 2, 4 for the Linear layers) determine
        # the parameter names in the state dict.
        stages = [
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 8),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the 8 scores."""
        scores = self.layer(x)
        return scores
# ---------------------------------------------------------------------------
# Inference script: embed "title. abstract" with BERT, score the pooled
# embedding with the trained classification head, and show the most probable
# subject areas until their cumulative probability exceeds the threshold.
# ---------------------------------------------------------------------------

# Cumulative probability mass the displayed labels must exceed.
PROB_THRESHOLD = 0.95

# Trained classification head. map_location='cpu' lets a checkpoint that was
# saved on a GPU load on a CPU-only host.
model_second = Net()
model_second.load_state_dict(torch.load('model.txt', map_location='cpu'))
model_second.eval()

# Pretrained encoder that produces the 768-feature pooled embedding.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_first = AutoModel.from_pretrained(model_name)

title = st.text_area("Write the title of your article, please")
abstract = st.text_area("Write the abstract")

# Skip inference until the user has typed something; otherwise the model
# would be asked to classify the bare separator ". ".
if title.strip() or abstract.strip():
    text = title + '. ' + abstract
    tokens_info = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out_first = model_first(**tokens_info).pooler_output
        out_second = model_second(out_first).numpy()

    # A single text gives a batch of one row. Normalise along the class axis
    # and take row 0 so that `probs` is a 1-D vector indexable by class index
    # (indexing the 2-D array by class index was the original failure).
    probs = softmax(out_second, axis=-1)[0]
    indices = np.argsort(probs)[::-1]  # class indices, most probable first

    sum_prob = 0.0
    for i in indices:
        # Labels are keyed from 1, model outputs are indexed from 0.
        # st.write renders on the page; print() would only reach the log.
        st.write(labels_articles[int(i) + 1])
        sum_prob += probs[i]
        if sum_prob > PROB_THRESHOLD:
            break