Spaces:
Runtime error
Runtime error
File size: 935 Bytes
02ea321 362dc22 89899dd 362dc22 02ea321 362dc22 89899dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import streamlit as st
from transformers import DistilBertTokenizer, DistilBertModel
import torch
import torch.nn as nn
from utils import get_answer_with_desc
# Path of the saved state_dict for the linear classification head.
MY_LINEAR_NAME = "my_linear_logits_3"
# Pretrained checkpoint used for both the tokenizer and the encoder.
BACKBONE_NAME = "distilbert-base-cased"

# Streamlit re-executes this whole script on every widget interaction, so the
# heavy objects below are loaded through a cached loader instead of being
# re-read from disk on each keystroke. st.cache_resource is the current API;
# older Streamlit releases only offer st.cache, which needs
# allow_output_mutation=True to hold unhashable model objects.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def _load_backbone(name):
    """Load and return the ``(tokenizer, model)`` pair for checkpoint *name*."""
    return (
        DistilBertTokenizer.from_pretrained(name),
        DistilBertModel.from_pretrained(name),
    )


@_cache_resource
def _load_classifier_head(path, in_features, out_features):
    """Build the linear head and load its saved weights from *path*.

    The weights are mapped onto the CPU so that a checkpoint saved from a GPU
    session still loads on a CPU-only host. ``load_state_dict`` is strict, so
    it raises if ``in_features``/``out_features`` do not match the saved
    tensor shapes.
    """
    head = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
    head.load_state_dict(torch.load(path, map_location="cpu"))
    head.eval()  # inference only; makes the intended mode explicit
    return head


st.markdown("## Classifying articles on computer science!")
st.markdown("<img width=400px src='https://img.freepik.com/free-photo/young-pretty-student-overwhelmed-with-books_272645-183.jpg?size=626&ext=jpg'>", unsafe_allow_html=True)
title = st.text_area("Enter the title of the article")
abstract = st.text_area("Enter the abstract of the article")
text = title + " " + abstract

tokenizer, model = _load_backbone(BACKBONE_NAME)
# Number of output classes; must equal the saved head's out_features.
n_classes = 40
# The head's input width is the encoder's hidden size (768 for this checkpoint).
my_linear = _load_classifier_head(MY_LINEAR_NAME, model.config.dim, n_classes)

# Only classify once the user has typed something; otherwise the model would
# be asked to label a blank string.
if title.strip() or abstract.strip():
    st.markdown(get_answer_with_desc(text, model, tokenizer, my_linear))
else:
    st.info("Enter a title and/or an abstract to classify the article.")
|