Konstantin Gordeev
Update model
49e8106
raw
history blame contribute delete
No virus
1.57 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
import pandas as pd
# Location of the fine-tuned weights file read by torch.load in load_model().
model_path = "model"

# All tensors are created on the CPU.
device = torch.device("cpu")

# Hub identifier used for both the tokenizer and the base architecture.
model_name = "distilbert-base-cased"

# Genre labels. predict() indexes this array with positions taken from the
# model's output probabilities, so the order here is significant and must not
# be changed independently of the trained weights.
genres = np.array(
    [
        "Animation",
        "Comedy",
        "Adult",
        "Adventure",
        "Musical",
        "History",
        "Reality-TV",
        "Film-Noir",
        "Sport",
        "Biography",
        "Drama",
        "Fantasy",
        "Romance",
        "Thriller",
        "News",
        "Documentary",
        "Sci-Fi",
        "Music",
        "Family",
        "Mystery",
        "Crime",
        "Horror",
        "War",
        "Action",
        "Western",
    ]
)
@st.cache(allow_output_mutation=True)
def load_model():
    """Build the genre classifier and its tokenizer with fine-tuned weights.

    The tokenizer and the base architecture are created from ``model_name``,
    with a classification head sized to ``len(genres)``. The state dict stored
    at ``model_path`` is then loaded on top of it. The function is wrapped in
    Streamlit's cache so the result can be reused instead of being rebuilt on
    every call.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is on ``device`` and in
        evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=len(genres)
    )

    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # that was saved from a GPU still loads on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    # Inference only: make sure dropout and similar layers are disabled.
    model.eval()
    return model, tokenizer
def predict(text: str, tokenizer, model, top_k: int = 5):
    """Return the most probable genres for a movie description.

    Args:
        text: Description to classify.
        tokenizer: Object whose ``encode`` method turns ``text`` into a list
            of token ids.
        model: Callable that takes a batch of token ids and returns a
            sequence whose first element holds the logits, one row per input.
        top_k: Number of genres to return. Defaults to 5.

    Returns:
        A dict mapping genre name to probability, ordered from the most to
        the least probable genre.
    """
    # Truncate to the tokenizer's maximum length: without this, a description
    # longer than the model's maximum sequence length makes the forward pass
    # fail.
    tokens = tokenizer.encode(text, truncation=True)
    input_ids = torch.as_tensor([tokens], device=device)

    # No gradients are needed for inference, so skip building the autograd
    # graph.
    with torch.no_grad():
        logits = model(input_ids)[0]

    # Softmax over the label axis; [0] selects the single item in the batch.
    probas = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # argsort is ascending, so reverse it to get the best labels first.
    top_index = probas.argsort()[::-1][:top_k]
    return dict(zip(genres[top_index], probas[top_index]))
# --- Streamlit page: one text box, one button, a table of predictions. ---
st.markdown("### Movie genre classification")

model, tokenizer = load_model()
text = st.text_area("Write some movie description")

if st.button('Predict'):
    with st.spinner("Wait..."):
        # Whitespace-only input has nothing to classify, so it is rejected the
        # same way as an empty box instead of being sent to the model.
        if not text or not text.strip():
            st.error("Write something.")
        else:
            pred = predict(text, tokenizer, model)
            # One row per predicted genre, labelled by genre name.
            result = pd.DataFrame(
                list(pred.values()),
                index=list(pred.keys()),
                columns=['Probability'],
            )
            st.write(result)