File size: 1,571 Bytes
388a2fe
49e8106
b60d6e8
49e8106
 
388a2fe
49e8106
b60d6e8
49e8106
388a2fe
49e8106
 
 
b60d6e8
 
 
 
49e8106
 
 
b60d6e8
 
 
49e8106
 
 
 
 
 
b60d6e8
 
 
49e8106
 
b60d6e8
 
 
 
 
 
 
49e8106
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
import pandas as pd

# Hugging Face checkpoint that supplies the tokenizer and the model architecture.
model_name = 'distilbert-base-cased'
# Path of the saved state_dict that load_model() loads on top of the base model.
model_path = 'model'
# All tensors are created on, and weights mapped to, the CPU.
device = torch.device('cpu')

# Classifier output labels: index i names logit i of the model head,
# so this order must match the label order the weights were trained with.
genres = np.array((
    'Animation',
    'Comedy',
    'Adult',
    'Adventure',
    'Musical',
    'History',
    'Reality-TV',
    'Film-Noir',
    'Sport',
    'Biography',
    'Drama',
    'Fantasy',
    'Romance',
    'Thriller',
    'News',
    'Documentary',
    'Sci-Fi',
    'Music',
    'Family',
    'Mystery',
    'Crime',
    'Horror',
    'War',
    'Action',
    'Western',
))


@st.cache(allow_output_mutation=True)
def load_model():
    """Build the genre classifier and load its saved weights.

    The base architecture and tokenizer come from ``model_name``. The
    classification head is sized to ``len(genres)``, and the weights
    stored at ``model_path`` are loaded on top. The result is cached with
    ``st.cache`` so Streamlit reruns reuse the same objects.
    ``allow_output_mutation=True`` stops Streamlit from hashing the
    returned model.

    Returns:
        ``(model, tokenizer)`` with the model on ``device`` in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(genres))
    # map_location remaps tensors saved on another device (e.g. a GPU used
    # for training) onto ``device``; without it a CUDA checkpoint fails to
    # load on a CPU-only host.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference only: make sure dropout and other training-time layers are off.
    model.eval()
    return model, tokenizer


def predict(text: str, tokenizer, model, top_k: int = 5):
    """Return the ``top_k`` most probable genres for a movie description.

    Args:
        text: Free-form movie description.
        tokenizer: Tokenizer paired with ``model`` (as returned by ``load_model``).
        model: Sequence-classification model whose logits have one column
            per entry of ``genres``.
        top_k: Number of genres to return. Defaults to 5.

    Returns:
        A dict mapping genre name to softmax probability. Entries are in
        descending probability order.
    """
    # Truncate to the tokenizer's maximum length; longer descriptions would
    # otherwise exceed the model's maximum sequence length and fail.
    tokens = tokenizer.encode(text, truncation=True)
    input_ids = torch.as_tensor([tokens], device=device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(input_ids)[0]
    # Batch of one, so take row 0 after softmax over the label dimension.
    probas = torch.softmax(logits, dim=1).cpu().numpy()[0]
    # Indices sorted by descending probability, truncated to top_k.
    top_index = probas.argsort()[::-1][:top_k]
    return dict(zip(genres[top_index], probas[top_index]))


st.markdown("### Movie genre classification")

# load_model is wrapped in st.cache, so script reruns reuse the same objects.
model, tokenizer = load_model()

text = st.text_area("Write some movie description")

if st.button('Predict'):
    # Reject empty and whitespace-only input before invoking the model.
    if not text.strip():
        st.error("Write something.")
    else:
        with st.spinner("Wait..."):
            pred = predict(text, tokenizer, model)
        # One row per genre, indexed by genre name; predict already returns
        # them highest-probability first.
        result = pd.DataFrame.from_dict(pred, orient='index', columns=['Probability'])
        st.write(result)