Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import DistilBertTokenizer, DistilBertModel | |
import torch | |
import torch.nn as nn | |
from utils import get_answer_with_desc | |
# Path to the saved state dict of the trained linear classification head.
MY_LINEAR_NAME = "my_linear_logits_3"

st.markdown("## Classifying articles on computer science!")
st.markdown("<img width=400px src='https://img.freepik.com/free-photo/young-pretty-student-overwhelmed-with-books_272645-183.jpg?size=626&ext=jpg'>", unsafe_allow_html=True)

# Recent Streamlit releases reject st.text_area heights below 68 px with a
# StreamlitAPIException, so the title box uses that minimum instead of 20 px.
title = st.text_area("Enter the title of the article", height=68)
abstract = st.text_area("Enter the abstract of the article", height=350)
top_k = st.slider('How many topics from top to show?', 1, 10, 3)

# Title and abstract are fed to the classifier as one space-separated string.
text = title + " " + abstract
# Streamlit re-runs this whole script on every widget interaction. Without a
# cache, the DistilBERT weights and the linear head would be reloaded from disk
# on each rerun. `st.cache_resource` exists from Streamlit 1.18 on; older
# releases only provide `st.cache`, which needs allow_output_mutation=True
# because torch modules are not hashable/immutable.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def get_model_tokenizer_linear():
    """Load the DistilBERT encoder, its tokenizer and the trained linear head.

    The result is cached across Streamlit reruns, so loading happens once per
    server process rather than on every interaction.

    Returns:
        dict with keys ``"model"``, ``"tokenizer"`` and ``"my_linear"``; the
        script unpacks it as keyword arguments to ``get_answer_with_desc``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
    model = DistilBertModel.from_pretrained("distilbert-base-cased")
    # Number of topic classes the saved head was trained to predict.
    n_classes = 40
    # 768 is the hidden size of distilbert-base-cased embeddings.
    my_linear = nn.Linear(in_features=768, out_features=n_classes, bias=True)
    # map_location forces the weights onto CPU regardless of where they were saved.
    my_linear.load_state_dict(
        torch.load(MY_LINEAR_NAME, map_location=torch.device("cpu"))
    )
    return {"model": model, "tokenizer": tokenizer, "my_linear": my_linear}
# Run the classifier only when there is real input. Stripping catches both
# fields left empty (which still leaves the single separator space) and
# input made only of whitespace.
if not text.strip():
    st.markdown("Input is empty, write something!")
else:
    resources = get_model_tokenizer_linear()
    # Each returned item is rendered as a level-4 markdown heading.
    for ms in get_answer_with_desc(text, top_k=top_k, **resources):
        st.markdown("#### " + ms)