# movie-genre / app.py
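# Run locally with: streamlit run app.py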
import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
import wikipedia
from transformers import AutoModelForSequenceClassification

from preprocess import Preprocess
from utility import Utility
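# The app predicts movie genres from a plot summary. When cast information is
# available (entered directly or scraped from the movie's Wikipedia page), a
# stacked meta model combines a fine-tuned DistilBERT classifier with a
# cast-based one-vs-all logistic regression; otherwise DistilBERT is used alone.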
st.title("Movie Genre Predictor")
text = st.text_input("Enter the plot of the movie")
st.caption("Enter either the Wiki URL or the cast info of the movie. The cast will be fetched from the Wiki page if it is not provided.")
wiki_url = st.text_input("Enter the Wiki URL of the movie (needed for fetching the cast information)")
cast_input = st.text_input("Enter the Wiki IDs of the cast (comma-separated)")
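# Load the fine-tuned DistilBERT checkpoint and move it to the GPU when available.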
model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-49092")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
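# Pickled base and meta models: one-vs-all logistic regression over cast features,
# the multilabel encoder for cast IDs (cast_mlb), and the stacking meta model.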
lr_model = pickle.load(open("models/cast_plot_lr","rb"))
cast_mlb = pickle.load(open("models/cast_mlb","rb"))
#column_names = pickle.load(open("models/column_names","rb"))
#top_actors = list(column_names)[11:]
top_actors = list(cast_mlb.classes_)
meta_model = pickle.load(open("models/meta_model","rb"))
utility = Utility()
preprocess = Preprocess()
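# Helper classes from this repo: Preprocess cleans the plot text, Utility tokenizes it for DistilBERT.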
out = []
if st.button("Predict"):
    cast = []
    if len(wiki_url) != 0 and len(cast_input) == 0:
        # Scrape the "Cast" section of the Wikipedia page and keep the first five linked actors
        html_page = wikipedia.page(title=wiki_url.split("/")[-1].replace("_", " "), auto_suggest=False).html()
        cast_wiki = html_page.split(" title=\"Edit section: Cast\">edit</a>")[-1]
        anchor_tags = cast_wiki.split("<a href=")[1:6]
        top5_cast_links = [val.split("\"")[1] for val in anchor_tags]
        for actor in top5_cast_links:
            try:
                cast.append(wikipedia.page(title=actor.split("/")[-1].replace("_", " ")).pageid)
            except Exception:
                # Skip actors whose Wikipedia page cannot be resolved
                pass
    elif len(cast_input) != 0:
        # Cast entered directly as comma-separated Wiki IDs
        cast = [actor.strip() for actor in cast_input.split(",")]

    st.write("Wiki Ids of Top 5 Cast:", ", ".join(str(actor) for actor in cast))
    st.write("Genre: ")
    clean_plot = preprocess.apply(text)
    # Use the meta-model approach when cast information is available, otherwise fall back to DistilBERT alone
    if len(cast) != 0:
        # Base model 1: DistilBERT over the cleaned plot
        id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action", "Drama", "Romance", "Comedy", "Thriller", "Crime", "Horror"])
        input_ids = torch.tensor([tokenized_plot['input_ids']], dtype=torch.long).to(device)
        attention_mask = torch.tensor([tokenized_plot['attention_mask']], dtype=torch.long).to(device)
        with torch.no_grad():
            y_pred = model(input_ids, attention_mask)
        sigmoid = torch.nn.Sigmoid()
        distilbert_pred = sigmoid(y_pred['logits'][0].detach().cpu())

        # Base model 2: one-vs-all logistic regression over the cast features
        cast_features = [str(actor) for actor in cast if actor in top_actors]
        lr_model_pred = lr_model.predict_proba(cast_mlb.transform([cast_features]))
        # Concatenate the outputs of the base models. The DistilBERT scores are
        # reordered from the label order passed to utility.tokenize
        # [Action, Drama, Romance, Comedy, Thriller, Crime, Horror]
        # into the alphabetical order used by the meta model's id2label below
        # [Action, Comedy, Crime, Drama, Horror, Romance, Thriller].
        reorder = torch.tensor([0, 3, 5, 1, 6, 2, 4])
        distilbert_pred = distilbert_pred[reorder].numpy()
        lr_model_pred = np.asarray(lr_model_pred)[0]
        concat_features = np.concatenate((lr_model_pred, distilbert_pred))

        # Meta model: one-vs-all logistic regression over the concatenated features
        probs = meta_model.predict_proba([concat_features])
        # Prepare the output: label each probability with its genre
        id2label = {0: "Action", 1: "Comedy", 2: "Crime", 3: "Drama", 4: "Horror", 5: "Romance", 6: "Thriller"}
        for i, prob in enumerate(probs[0]):
            out.append([id2label[i], float(prob)])
    else:
        # No cast information: use the DistilBERT model alone
        id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action", "Drama", "Romance", "Comedy", "Thriller", "Crime", "Horror"])
        input_ids = torch.tensor([tokenized_plot['input_ids']], dtype=torch.long).to(device)
        attention_mask = torch.tensor([tokenized_plot['attention_mask']], dtype=torch.long).to(device)
        with torch.no_grad():
            y_pred = model(input_ids, attention_mask)
        sigmoid = torch.nn.Sigmoid()
        probs = sigmoid(y_pred['logits'][0].detach().cpu())
        for i, prob in enumerate(probs):
            out.append([id2label[i], float(prob)])

    st.write(out)