# Hugging Face Spaces status captured with this file: Runtime error
import streamlit as st | |
from transformers import AutoModelForSequenceClassification | |
import pandas as pd | |
import numpy as np | |
import torch | |
import pickle | |
import wikipedia | |
from preprocess import Preprocess | |
from utility import Utility | |
# ---------------------------------------------------------------------------
# Page layout: title and the three free-text inputs read by the Predict step.
# ---------------------------------------------------------------------------
st.title("Movie Genre Predictor")
text = st.text_input('Enter plot of the movie')
st.caption("Either enter Wiki URL or the Cast info of the movie. Cast will be fetched from the Wiki page if cast is not provided")
wiki_url = st.text_input("Enter Wiki URL of the movie (Needed for fetching the cast information)")
cast_input = st.text_input("Enter Wiki IDs of the cast (Should be separated by comma)")


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards.

    SECURITY: pickle executes arbitrary code while loading. Only point this at
    model artifacts shipped with the app, never at user-supplied files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# ---------------------------------------------------------------------------
# Models.
# NOTE(review): Streamlit re-executes the script on every interaction, so all
# of the loads below repeat each time; consider caching them if the installed
# Streamlit version offers a resource cache.
# ---------------------------------------------------------------------------
# Fine-tuned sequence-classification checkpoint used on the plot text.
model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-49092")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Cast-based classifier and the binarizer that turns a cast list into its input row.
lr_model = _load_pickle("models/cast_plot_lr")
cast_mlb = _load_pickle("models/cast_mlb")

# Actors known to the binarizer; anything else is dropped before transform().
# (An earlier revision derived this list from "models/column_names"[11:].)
top_actors = list(cast_mlb.classes_)

# Classifier stacked on top of the two base models' probabilities.
meta_model = _load_pickle("models/meta_model")

utility = Utility()
preprocess = Preprocess()

# Accumulates [label, probability] pairs for display.
out = []
# Label order handed to utility.tokenize for the plot classifier.
DISTILBERT_LABELS = ["Action", "Drama", "Romance", "Comedy", "Thriller", "Crime", "Horror"]
# Label order in which the meta model's probabilities are reported.
META_LABELS = ["Action", "Comedy", "Crime", "Drama", "Horror", "Romance", "Thriller"]
# For each META_LABELS slot, the DISTILBERT_LABELS index holding the same genre
# (evaluates to [0, 3, 5, 1, 6, 2, 4]).
DISTILBERT_TO_META = [DISTILBERT_LABELS.index(label) for label in META_LABELS]
# Text that closes the "Cast" section heading in the rendered Wikipedia HTML.
CAST_SECTION_MARKER = " title=\"Edit section: Cast\">edit</a>"


def _parse_cast_input(raw):
    """Split a comma-separated string of cast ids, dropping blanks and padding."""
    return [token.strip() for token in raw.split(",") if token.strip()]


def _fetch_cast_page_ids(url, limit=5):
    """Return Wikipedia page ids for the first *limit* links after the Cast heading.

    The movie page title is taken from the last path segment of *url*. Links
    that cannot be resolved to a page are skipped. An empty list is returned
    when the page has no recognizable Cast section.
    """
    title = url.strip().rstrip("/").split("/")[-1].replace("_", " ")
    html_page = wikipedia.page(title=title, auto_suggest=False).html()
    if CAST_SECTION_MARKER not in html_page:
        # Without the marker, split() would hand back the whole page and the
        # "cast" would be the first arbitrary links on it.
        return []

    cast_html = html_page.split(CAST_SECTION_MARKER)[-1]
    page_ids = []
    for chunk in cast_html.split("<a href=")[1:limit + 1]:
        quoted = chunk.split("\"")
        if len(quoted) < 2:
            continue  # href was not double-quoted; nothing usable here
        actor_title = quoted[1].split("/")[-1].replace("_", " ")
        try:
            page_ids.append(wikipedia.page(title=actor_title).pageid)
        except Exception:
            # Best-effort lookup: a missing/ambiguous page or a network hiccup
            # for one link must not abort the whole prediction.
            continue
    return page_ids


def _distilbert_probabilities(plot):
    """Run the plot classifier and return ``(id2label, probabilities)``.

    ``probabilities`` is a 1-D NumPy array of per-label sigmoid scores in
    DISTILBERT_LABELS order; ``id2label`` is whatever utility.tokenize returns.
    """
    id2label, _label2id, _tokenizer, tokenized = utility.tokenize(plot, DISTILBERT_LABELS)
    # Build the batch-of-one inputs directly on the model's device; the model
    # was moved to `device`, so CPU tensors would fail when CUDA is in use.
    input_ids = torch.as_tensor(
        np.asarray(tokenized['input_ids']), dtype=torch.long, device=device
    ).unsqueeze(0)
    attention_mask = torch.as_tensor(
        np.asarray(tokenized['attention_mask']), dtype=torch.long, device=device
    ).unsqueeze(0)
    with torch.no_grad():  # inference only
        logits = model(input_ids=input_ids, attention_mask=attention_mask)['logits'][0]
    probabilities = torch.sigmoid(logits.float()).squeeze().cpu().numpy()
    return id2label, probabilities


def _predict_genres(plot, cast):
    """Return ``[label, probability]`` pairs for *plot*, using *cast* when given.

    With a non-empty cast the plot classifier and the cast classifier are
    stacked through the meta model; otherwise the plot classifier is used alone.
    """
    id2label, plot_probs = _distilbert_probabilities(plot)
    if not cast:
        return [[id2label[i], float(prob)] for i, prob in enumerate(plot_probs)]

    # Base model 2: cast-based classifier, restricted to actors it knows.
    cast_features = [str(actor) for actor in cast if actor in top_actors]
    cast_probs = np.array(lr_model.predict_proba(cast_mlb.transform([cast_features])))[0]

    # Reorder the plot scores into META_LABELS order with one gather. (Swapping
    # tensor elements in place through views reads already-overwritten slots.)
    plot_probs_meta_order = plot_probs[DISTILBERT_TO_META]

    # Meta model consumes [cast probabilities, plot probabilities].
    stacked = np.concatenate((cast_probs, plot_probs_meta_order))
    meta_probs = meta_model.predict_proba([stacked])[0]
    return [[META_LABELS[i], float(prob)] for i, prob in enumerate(meta_probs)]


if st.button("Predict"):
    # Explicit cast ids win; otherwise try to scrape them from the Wiki page.
    if cast_input.strip():
        cast = _parse_cast_input(cast_input)
    elif wiki_url.strip():
        cast = _fetch_cast_page_ids(wiki_url)
    else:
        cast = []

    st.write("Wiki Ids of Top 5 Cast:", ", ".join(str(actor) for actor in cast))
    st.write("Genre: ")

    clean_plot = preprocess.apply(text)
    out.extend(_predict_genres(clean_plot, cast))
    # NOTE(review): the original nesting of this call could not be recovered
    # from the source; it is kept inside the button branch so nothing is
    # rendered before a prediction exists.
    st.write(out)