"""Streamlit app that classifies a movie description into genres.

A DistilBERT sequence-classification head is fine-tuned elsewhere. Its
weights are loaded from ``model_path``, and the app shows the most probable
genres for the text the user enters.
"""
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
import pandas as pd

model_path = 'model'
device = torch.device('cpu')
model_name = 'distilbert-base-cased'
# Label order must match the order used when the classification head was trained.
genres = np.array(['Animation', 'Comedy', 'Adult', 'Adventure', 'Musical', 'History', 'Reality-TV', 'Film-Noir',
                   'Sport', 'Biography', 'Drama', 'Fantasy', 'Romance', 'Thriller', 'News', 'Documentary',
                   'Sci-Fi', 'Music', 'Family', 'Mystery', 'Crime', 'Horror', 'War', 'Action', 'Western'])

# `st.cache` is deprecated (and removed in recent Streamlit releases) in favour of
# `st.cache_resource`, which is meant for unhashable singletons such as models.
# Fall back to the old decorator on installs that predate `cache_resource`.
if hasattr(st, 'cache_resource'):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Build the tokenizer and classifier and load the fine-tuned weights.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is on ``device`` and in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(genres))
    # map_location lets a checkpoint saved on a GPU be loaded on this CPU-only app.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # Disable dropout so predictions are deterministic.
    model.eval()
    return model, tokenizer


def predict(text: str, tokenizer, model, top_k: int = 5) -> dict:
    """Return the ``top_k`` most probable genres for ``text``.

    Args:
        text: Free-form movie description.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Sequence-classification model returned by ``load_model``.
        top_k: Number of genres to return (default 5).

    Returns:
        Dict mapping genre name to softmax probability, ordered from most to
        least probable.
    """
    # Truncate to the tokenizer's model_max_length; longer inputs would
    # overflow the model's position embeddings and raise.
    tokens = tokenizer.encode(text, truncation=True)
    input_ids = torch.as_tensor([tokens], device=device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids)[0]
    probas = torch.softmax(logits, dim=1)[0].cpu().numpy()
    top_index = probas.argsort()[::-1][:top_k]
    return dict(zip(genres[top_index], probas[top_index]))


st.markdown("### Movie genre classification")
model, tokenizer = load_model()
text = st.text_area("Write some movie description")
if st.button('Predict'):
    with st.spinner("Wait..."):
        # Treat whitespace-only input the same as empty input.
        if not text.strip():
            st.error("Write something.")
        else:
            pred = predict(text, tokenizer, model)
            result = pd.DataFrame(list(pred.values()), index=list(pred.keys()), columns=['Probability'])
            st.write(result)