# movie-genre / app.py
# A-M-S's picture
# Refactored the code
# 2eba150
# raw history blame
# No virus
# 3.35 kB
import streamlit as st
from transformers import AutoModelForSequenceClassification
import pandas as pd
import numpy as np
import torch
import pickle
import wikipedia
from preprocess import Preprocess
from utility import Utility
def _load_pickle(path):
    """Load one pickled artifact from *path*, closing the file afterwards."""
    # NOTE: unpickling executes arbitrary code; only ship trusted artifacts.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# --- Page header and user inputs -------------------------------------------
st.title("Movie Genre Predictor")
st.subheader("Enter the text you'd like to analyze.")
text = st.text_input('Enter plot of the movie')
wiki_url = st.text_input("Enter wikipedia url of the movie (Needed for fetching the cast information)")

# --- Model artifacts -------------------------------------------------------
# NOTE: Streamlit re-executes this script on every interaction, so everything
# below is reloaded each time; wrapping the loads in a cached function would
# avoid that if startup latency becomes a problem.

# Base model 1: fine-tuned sequence classifier loaded from a local checkpoint.
model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-36819")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # inference only: make sure dropout is disabled

# Base model 2 (cast-based classifier) and the binarizer that encodes its input.
lr_model = _load_pickle("models/cast_plot_lr")
cast_mlb = _load_pickle("models/cast_mlb")

# Feature column names; entries from index 11 onward are used as the known
# actor ids (presumably the first 11 are non-actor columns - TODO confirm).
column_names = _load_pickle("models/column_names")
top_actors = list(column_names)[11:]

# Meta model that combines the outputs of the two base models.
meta_model = _load_pickle("models/meta_model")

# Project-local helpers for tokenization and plot clean-up.
utility = Utility()
preprocess = Preprocess()
# Label order passed to the tokenizer helper for the DistilBERT base model.
DISTILBERT_LABELS = ["Action", "Drama", "Romance", "Comedy", "Thriller"]
# Label order used for the meta model's features and for the reported output.
GENRE_LABELS = ["Action", "Comedy", "Drama", "Romance", "Thriller"]


def _title_from_wiki_url(url):
    """Return the article title encoded in the last path segment of *url*."""
    from urllib.parse import unquote

    # Tolerate a trailing slash and percent-encoded characters in the slug.
    slug = url.strip().rstrip("/").split("/")[-1]
    return unquote(slug).replace("_", " ")


def _lookup_page_id(name):
    """Return the Wikipedia page id for *name*, or None if nothing resolves.

    The exact title is tried first, then the top two search hits. Lookups are
    best-effort: any failure just moves on to the next candidate.
    """
    try:
        return wikipedia.page(title=name).pageid
    except Exception:
        pass  # fall back to the search results below
    try:
        candidates = wikipedia.search(name, results=2)
    except Exception:
        return None
    for candidate in candidates:
        try:
            return wikipedia.page(title=candidate).pageid
        except Exception:
            continue
    return None


def _fetch_cast_ids(url, limit=5):
    """Return page ids for up to *limit* names listed in the page's Cast section.

    Shows a warning and returns an empty list when the page cannot be loaded
    or has no "Cast" section, so prediction can continue without cast data.
    """
    title = _title_from_wiki_url(url)
    try:
        cast_section = wikipedia.page(title=title, auto_suggest=False).section("Cast")
    except Exception as err:
        st.warning(f"Could not load the Wikipedia page for '{title}': {err}")
        return []
    if not cast_section:
        # section() yields None when the page has no section with that title.
        st.warning("No 'Cast' section found on the Wikipedia page; continuing without cast information.")
        return []

    # Each line is expected to look like "<actor> as <role>"; keep the actor.
    names = [
        line.split(" as ")[0].strip()
        for line in cast_section.split("\n")
        if line.strip()
    ]
    page_ids = (_lookup_page_id(name) for name in names[:limit])
    return [page_id for page_id in page_ids if page_id is not None]


if st.button("Predict"):
    cast = []
    if wiki_url.strip():
        cast = _fetch_cast_ids(wiki_url)
        st.write("Wiki Ids of Top 5 Cast:", cast)

    st.write("Genre: ")
    clean_plot = preprocess.apply(text)

    # Base Model 1: DistilBERT
    _, _, _, tokenized_plot = utility.tokenize(clean_plot, list(DISTILBERT_LABELS))
    # Add a batch dimension and place the inputs on the same device as the model.
    input_ids = torch.as_tensor(
        np.asarray(tokenized_plot['input_ids']), dtype=torch.long, device=device
    ).unsqueeze(0)
    attention_mask = torch.as_tensor(
        np.asarray(tokenized_plot['attention_mask']), dtype=torch.long, device=device
    ).unsqueeze(0)
    with torch.no_grad():
        y_pred = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = y_pred['logits'][0].float().squeeze().cpu()
    distilbert_pred = torch.sigmoid(logits)

    # Reorder the scores from DISTILBERT_LABELS order to GENRE_LABELS order.
    # Indexing with a list copies the data; swapping through views of single
    # elements (as element-by-element assignment does) would overwrite values
    # before they are read.
    reorder = [DISTILBERT_LABELS.index(genre) for genre in GENRE_LABELS]
    distilbert_pred = distilbert_pred[reorder].numpy()

    # Base model 2: LR One Vs All
    # Only actors known to the model contribute features; compare as strings
    # so that int/str id mismatches do not silently drop every actor.
    known_actors = {str(actor) for actor in top_actors}
    cast_features = [str(actor) for actor in cast if str(actor) in known_actors]
    lr_model_pred = np.array(
        lr_model.predict_proba(cast_mlb.transform([cast_features]))
    )[0]

    # Concatenating Outputs of base models
    concat_features = np.concatenate((lr_model_pred, distilbert_pred))

    # Meta model 3: LR One Vs All
    probs = meta_model.predict_proba([concat_features])

    # Preparing Output
    out = [[label, prob] for label, prob in zip(GENRE_LABELS, probs[0])]
    st.write(out)