|
import streamlit as st
|
|
import torch
|
|
import transformers
|
|
|
|
# Prefer st.cache_resource (Streamlit >= 1.18), which is meant for models and other
# global resources and never hashes/copies the returned objects. Older releases only
# have the legacy st.cache, which must be told not to hash the outputs on every call.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model_n_tokenizer():
    """Load the model checkpoint and the distilroberta-base tokenizer.

    The result is cached by Streamlit, so script reruns triggered by widget
    changes reuse the same objects instead of reloading them.

    Returns:
        A ``(model, tokenizer)`` tuple: the object unpickled from
        ``pytorch_model.bin`` and a ``transformers.RobertaTokenizer``.
    """
    # map_location="cpu": a checkpoint saved on GPU still loads on a CPU-only host,
    # and the model ends up on the same device as the CPU tensors the tokenizer emits.
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
    # On torch >= 2.6 (weights_only=True by default) loading a pickled full module
    # additionally needs weights_only=False; confirm how this file was saved.
    model = torch.load("pytorch_model.bin", map_location="cpu")
    if isinstance(model, torch.nn.Module):
        # Inference only: disable dropout / use running stats in normalization layers.
        model.eval()

    # The RobertaTokenizer constructor takes vocab/merges *file paths*;
    # from_pretrained is the API that resolves a model id such as this one.
    tknzer = transformers.RobertaTokenizer.from_pretrained("distilroberta-base")

    return model, tknzer
|
|
|
|
# Obtain the model and tokenizer through the cached loader so that the reruns
# Streamlit performs on every widget change do not reload them.
resources = load_model_n_tokenizer()
model, tokenizer = resources

st.markdown("### Hello, world!")

# The two free-text fields the app reads from the user.
title = st.text_input("Enter title")
desc = st.text_input("Enter description")

# Echo what the user entered back onto the page, title first.
for field_value in (title, desc):
    st.markdown(field_value)
|
|
|
|
# The model input is the title and description joined by a single space.
text = title + ' ' + desc

# Only run the model once the user has actually typed something; otherwise every
# rerun would push an effectively empty string through the network.
if text.strip():
    # return_tensors="pt" makes the tokenizer return torch tensors; without it the
    # encoding holds plain Python lists, which a torch module cannot consume.
    tokenized = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # NOTE(review): assumes `model` is a Hugging Face-style module whose forward
        # takes input_ids / attention_mask as keyword arguments - confirm against how
        # pytorch_model.bin was produced.
        output = model(**tokenized)

    # st.write renders arbitrary objects (tensors, model-output containers);
    # st.markdown is intended for markdown text.
    st.write(output)
|
|
|
|
|