Spaces:
Runtime error
Runtime error
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import torch | |
import numpy as np | |
import pandas as pd | |
# Local file holding the fine-tuned weights; load_model() reads it with
# torch.load and applies it via load_state_dict.
model_path = 'model'
# Inference is pinned to the CPU.
device = torch.device('cpu')
# Hugging Face checkpoint supplying both the tokenizer and the base architecture.
model_name = 'distilbert-base-cased'

# Classifier output labels: logit/probability index i corresponds to genres[i].
# NOTE(review): presumably this order must match the label order used when the
# weights in model_path were fine-tuned — confirm against the training code.
_GENRE_LABELS = (
    'Animation', 'Comedy', 'Adult', 'Adventure', 'Musical', 'History',
    'Reality-TV', 'Film-Noir', 'Sport', 'Biography', 'Drama', 'Fantasy',
    'Romance', 'Thriller', 'News', 'Documentary', 'Sci-Fi', 'Music',
    'Family', 'Mystery', 'Crime', 'Horror', 'War', 'Action', 'Western',
)
# Stored as a numpy array so it can be fancy-indexed with an index array.
genres = np.array(_GENRE_LABELS)
def load_model():
    """Build the tokenizer and the fine-tuned genre classifier.

    The base architecture comes from ``model_name`` with a classification head
    sized to ``len(genres)``; its weights are then replaced by the state dict
    stored at ``model_path``.

    Returns:
        tuple: ``(model, tokenizer)`` — note the model comes first.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(genres))
    # map_location remaps every tensor onto ``device``. Without it, a
    # checkpoint saved from a GPU would try to restore CUDA tensors and fail
    # on a CPU-only host.
    # NOTE: torch.load unpickles the file — only point model_path at a
    # trusted checkpoint.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference only: make sure dropout is disabled so predictions are
    # deterministic (from_pretrained normally does this already; be explicit).
    model.eval()
    return model, tokenizer
def predict(text: str, tokenizer, model, top_k: int = 5):
    """Return the ``top_k`` most probable genres for a movie description.

    Args:
        text: Free-form description to classify.
        tokenizer: Tokenizer matching ``model`` (e.g. from ``load_model``).
        model: Sequence-classification model whose output head is ordered
            like ``genres``.
        top_k: How many genres to return (default 5, the original behavior).

    Returns:
        dict: genre name -> softmax probability (numpy scalars), ordered from
        most to least probable.
    """
    # Truncate to the tokenizer's maximum length (512 for DistilBERT); longer
    # inputs would otherwise overflow the position embeddings and raise.
    tokens = tokenizer.encode(text, truncation=True)
    input_ids = torch.as_tensor([tokens], device=device)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = model(input_ids)[0]
    # Batch of one: softmax across the label axis, then take the single row.
    probas = torch.softmax(logits, dim=1).cpu().numpy()[0]
    # argsort is ascending; reverse it so the most probable genre comes first.
    top_index = probas.argsort()[::-1][:top_k]
    return dict(zip(genres[top_index], probas[top_index]))
st.markdown("### Movie genre classification")

# Streamlit re-executes this whole script on every widget interaction, so an
# uncached call would reload the tokenizer and weights on every click.
# st.cache_resource keeps a single loaded instance per server process; on
# Streamlit versions that predate it, fall back to loading without a cache.
if hasattr(st, "cache_resource"):
    model, tokenizer = st.cache_resource(load_model)()
else:
    model, tokenizer = load_model()

text = st.text_area("Write some movie description")

if st.button('Predict'):
    # Treat whitespace-only input the same as an empty box.
    if not text.strip():
        st.error("Write something.")
    else:
        with st.spinner("Wait..."):
            pred = predict(text, tokenizer, model)
        # One row per genre, already sorted by descending probability.
        result = pd.DataFrame(list(pred.values()), index=list(pred.keys()), columns=['Probability'])
        st.write(result)