Spaces:
Runtime error
Runtime error
File size: 1,571 Bytes
388a2fe 49e8106 b60d6e8 49e8106 388a2fe 49e8106 b60d6e8 49e8106 388a2fe 49e8106 b60d6e8 49e8106 b60d6e8 49e8106 b60d6e8 49e8106 b60d6e8 49e8106 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
import pandas as pd
# Path of the fine-tuned weights file that load_model() restores.
model_path = 'model'

# Inference runs on the CPU.
device = torch.device('cpu')

# Hub identifier of the base checkpoint (tokenizer + architecture).
model_name = 'distilbert-base-cased'

# Genre labels; position i corresponds to output logit i of the classifier,
# so the order must not change.
genres = np.array([
    'Animation',
    'Comedy',
    'Adult',
    'Adventure',
    'Musical',
    'History',
    'Reality-TV',
    'Film-Noir',
    'Sport',
    'Biography',
    'Drama',
    'Fantasy',
    'Romance',
    'Thriller',
    'News',
    'Documentary',
    'Sci-Fi',
    'Music',
    'Family',
    'Mystery',
    'Crime',
    'Horror',
    'War',
    'Action',
    'Western',
])
# ``st.cache`` is deprecated in recent Streamlit releases in favour of
# ``st.cache_resource``; prefer the new API when it exists and fall back to
# the legacy decorator so older installs keep working.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the fine-tuned genre classifier and its tokenizer.

    The base architecture and tokenizer come from ``model_name``; the
    fine-tuned weights are then restored from ``model_path``. The result is
    cached by Streamlit so the model is built once, not on every rerun of
    the script.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is placed on ``device``
        and switched to evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=len(genres)
    )
    # map_location remaps tensors saved on another device (e.g. a GPU used
    # for training) onto the inference device; without it, loading such a
    # checkpoint on a CPU-only host raises a RuntimeError.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference only: make sure dropout is disabled.
    model.eval()
    return model, tokenizer
def predict(text: str, tokenizer, model, top_k: int = 5) -> dict:
    """Return the most probable genres for a movie description.

    Args:
        text: Free-form movie description.
        tokenizer: Tokenizer matching ``model`` (as returned by
            ``load_model``).
        model: Sequence-classification model whose output logits line up
            with the module-level ``genres`` array.
        top_k: How many of the highest-probability genres to return.
            Defaults to 5, the original behaviour.

    Returns:
        A dict mapping genre name to probability, ordered from most to
        least probable.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be a positive integer")

    # Truncate to the model's positional limit; an over-long description
    # would otherwise index past the position-embedding table and crash.
    # When the config does not expose a limit, max_length=None lets the
    # tokenizer fall back to its own model_max_length.
    max_len = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    tokens = tokenizer.encode(text, truncation=True, max_length=max_len)
    batch = torch.as_tensor([tokens], device=device)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)[0]
    probas = torch.softmax(logits, dim=1)[0].cpu().numpy()

    # argsort is ascending, so reverse it and keep the first top_k entries.
    top_index = np.argsort(probas)[::-1][:top_k]
    return dict(zip(genres[top_index], probas[top_index]))
# --- Page layout -----------------------------------------------------------
st.markdown("### Movie genre classification")

# The model is cached by load_model(), so this is cheap after the first run.
model, tokenizer = load_model()

text = st.text_area("Write some movie description")

# --- Prediction on demand --------------------------------------------------
if st.button('Predict'):
    with st.spinner("Wait..."):
        if text:
            pred = predict(text, tokenizer, model)
            # One row per predicted genre, a single 'Probability' column.
            result = pd.DataFrame(
                data=list(pred.values()),
                index=list(pred.keys()),
                columns=['Probability'],
            )
            st.write(result)
        else:
            # Nothing was typed into the text area.
            st.error("Write something.")
|