Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
from torch import nn | |
from transformers import BertModel, AutoTokenizer, AutoModel, pipeline | |
from time import time | |
import matplotlib.pyplot as plt | |
# All inference runs on the CPU; automatic CUDA selection via
# torch.cuda.is_available() is currently disabled.
device = 'cpu'

# Label code -> position of that label in the classifier's output vector.
labels = {
    code: index
    for index, code in enumerate(('cs.NE', 'cs.CL', 'cs.AI', 'stat.ML', 'cs.CV', 'cs.LG'))
}

# Label code -> human-readable topic name shown in the UI.
labels_decoder = {
    'cs.NE': 'Neural and Evolutionary Computing',
    'cs.CL': 'Computation and Language',
    'cs.AI': 'Artificial Intelligence',
    'stat.ML': 'Machine Learning (stat)',
    'cs.CV': 'Computer Vision',
    'cs.LG': 'Machine Learning',
}

# Pretrained checkpoint whose tokenizer inference() uses to encode text.
model_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
class BertClassifier(nn.Module):
    """Text classifier: pretrained BERT encoder + dropout + linear head + ReLU.

    The head is applied to BERT's pooled output. Because the last layer is a
    ReLU, every returned score is non-negative.

    Args:
        n_classes: number of output classes (width of the linear head).
        dropout: dropout probability applied to the pooled output.
        model_name: name/path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, n_classes, dropout=0.5, model_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder config instead of hard-coding 768 so
        # checkpoints with a different hidden size (e.g. BERT-large) also work;
        # for bert-base-uncased this is still 768, so saved weights still load.
        self.linear = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Return non-negative class scores, last dimension ``n_classes``.

        Args:
            input_id: token id tensor, presumably (batch, seq_len) — TODO confirm.
            mask: attention mask matching ``input_id``'s shape.
        """
        outputs = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        # With return_dict=False the pooled output is the second element;
        # indexing (not 2-way unpacking) tolerates extra returned items.
        pooled_output = outputs[1]
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer
def build_model():
    """Create the classifier, load its saved weights and switch to eval mode.

    Reads ``model_weights_1.pt`` (resolved relative to the working directory)
    and reports progress on the Streamlit page.

    Returns:
        A ``BertClassifier`` in evaluation mode, placed on ``device``.
    """
    model = BertClassifier(n_classes=len(labels))
    st.markdown("Model created")
    # Load and place the model on the same ``device`` that inference() moves
    # its input tensors to, instead of a hard-coded CPU; otherwise changing
    # ``device`` would leave model and inputs on different devices.
    state_dict = torch.load('model_weights_1.pt', map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # eval() disables dropout so predictions are deterministic.
    model.eval()
    st.markdown("Model weights loaded")
    return model
def inference(txt, mode=None):
    """Score ``txt`` against every label in ``labels`` using the global model.

    Args:
        txt: free text to classify.
        mode: kept for backward compatibility; currently unused.

    Returns:
        List of ``(label_code, percentage)`` tuples in ``labels`` order. The
        percentages are the model's (non-negative) outputs rescaled to sum to
        100. If every output is 0, every percentage is 0.0 instead of NaN.
    """
    # Turn line breaks into spaces: deleting them would glue the last word of
    # one line onto the first word of the next.
    clean_text = txt.lower().replace('\n', ' ')
    # A single text needs no padding; padding to 512 only adds masked tokens
    # (and compute). Truncation still caps the input at BERT's 512 tokens.
    encoded = tokenizer(clean_text, max_length=512, truncation=True,
                        return_tensors="pt")
    input_ids = encoded['input_ids'].to(device)
    # Already batched as (1, seq_len), same as input_ids: no extra unsqueeze.
    attention_mask = encoded['attention_mask'].to(device)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        out = model(input_ids, attention_mask)

    scores = out.cpu().numpy().reshape(-1)
    total = scores.sum()
    # The classifier ends in a ReLU, so all outputs can be 0; avoid 0/0.
    if total > 0:
        scores = scores / total * 100
    else:
        scores = scores * 0.0
    return list(zip(labels.keys(), scores.tolist()))
model = build_model()
st.markdown("### Privet, mir!")
st.markdown("<img width=200px src='https://i.pinimg.com/736x/11/33/19/113319f0ffe91f4bb0f468914b9916da.jpg'>", unsafe_allow_html=True)
text = st.text_area("ENTER TEXT HERE")

# Streamlit re-executes this script on every interaction and the text area
# starts empty, so only run the model once the user has entered some text.
if text.strip():
    start_time = time()
    st.markdown("INFERENCE STARTS ...")
    res = inference(text, mode=None)
    # Highest-scoring labels first.
    res.sort(key=lambda x: -x[1])
    # HTML tags need unsafe_allow_html to be rendered (as for the image above).
    st.markdown("<b>INFERENCE RESULT:</b>", unsafe_allow_html=True)
    # List every label holding at least 1% of the score.
    for lbl, score in res:
        if score >= 1:
            st.markdown(f"[ {lbl:<7}] {labels_decoder[lbl]:<35} {score:.1f}%")

    # Keep adding top labels to the chart until they cover 95% of the score.
    res_plot = []
    total = 0
    for lbl, score in res:
        if total >= 95:
            break
        res_plot.append((lbl, score))
        total += score

    fig, ax = plt.subplots(figsize=(10, len(res_plot) + 1))
    for lbl, score in res_plot:
        ax.barh(lbl, score)
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; without this, each rerun
    # would leak one figure.
    plt.close(fig)
    st.markdown(f"cycle time = {time() - start_time:.2f} s.")