# Spaces: Runtime error
import pickle
from urllib.parse import unquote, urlparse

import numpy as np
import pandas as pd
import streamlit as st
import torch
import wikipedia
from transformers import AutoModelForSequenceClassification

from preprocess import Preprocess
from utility import Utility
# Page header.
st.title("Movie Genre Predictor")
st.subheader("Enter the text you'd like to analyze.")

# Free-text inputs read by the prediction step: the plot summary, and an
# optional Wikipedia article URL that is used to look up the cast.
text = st.text_input(label='Enter plot of the movie')
wiki_url = st.text_input(
    label="Enter wikipedia url of the movie (Needed for fetching the cast information)"
)
def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file.

    Only use this on trusted, locally bundled artifacts: unpickling can
    execute arbitrary code.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Base model 1: DistilBERT classifier loaded from a local fine-tuned checkpoint.
model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-36819")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only: explicitly disable dropout

# Base model 2 (cast one-vs-rest LR), its cast encoder, and the meta model.
lr_model = _load_pickle("models/cast_plot_lr")
cast_mlb = _load_pickle("models/cast_mlb")
column_names = _load_pickle("models/column_names")
# NOTE(review): assumes the first 11 columns are non-actor features and the
# remainder are the actor ids the cast model was trained on -- confirm
# against the training pipeline.
top_actors = list(column_names)[11:]
meta_model = _load_pickle("models/meta_model")

utility = Utility()
preprocess = Preprocess()
# Label order used for the DistilBERT checkpoint's logits (the list passed to
# utility.tokenize) and the label order of meta_model's predict_proba output.
_DISTILBERT_LABELS = ("Action", "Drama", "Romance", "Comedy", "Thriller")
_META_LABELS = ("Action", "Comedy", "Drama", "Romance", "Thriller")
# Gather indices that rearrange DistilBERT scores into meta-model order:
# [0, 3, 1, 2, 4] (Comedy -> slot 1, Drama -> slot 2, Romance -> slot 3).
_DISTILBERT_TO_META = [_DISTILBERT_LABELS.index(label) for label in _META_LABELS]
# Number of entries from the top of the "Cast" section looked up on Wikipedia.
_MAX_CAST = 5


def _wiki_title_from_url(url):
    """Return the article title named by a Wikipedia URL.

    Ignores any query string or ``#fragment`` and a trailing slash, then
    percent-decodes the last path segment and maps underscores to spaces,
    e.g. ``https://en.wikipedia.org/wiki/Am%C3%A9lie#Cast`` -> ``"Amélie"``.
    A bare title without slashes is returned with underscores replaced.
    """
    path = urlparse(url.strip()).path.rstrip("/")
    return unquote(path.rsplit("/", 1)[-1]).replace("_", " ")


def _resolve_actor_pageid(actor):
    """Return the Wikipedia page id for the person named *actor*, or None.

    Tries *actor* as a page title first, then each of the top two search
    results for it, and returns the id of the first page that loads.
    """
    try:
        return wikipedia.page(title=actor).pageid
    except wikipedia.exceptions.WikipediaException:
        pass  # missing or ambiguous title: fall back to search
    try:
        candidates = wikipedia.search(actor, results=2)
    except wikipedia.exceptions.WikipediaException:
        return None
    for candidate in candidates:
        try:
            return wikipedia.page(title=candidate).pageid
        except wikipedia.exceptions.WikipediaException:
            continue
    return None


def _fetch_top_cast_ids(wiki_url, limit=_MAX_CAST):
    """Return Wikipedia page ids for the first *limit* entries of a film's cast.

    Loads the article named by *wiki_url* and reads its "Cast" section. On
    each non-empty line, the text before " as " is taken as the actor name.
    Names that cannot be resolved are skipped, so fewer than *limit* ids may
    be returned. Returns [] if the article has no "Cast" section.

    Raises:
        wikipedia.exceptions.WikipediaException: the film article itself
            could not be loaded.
    """
    title = _wiki_title_from_url(wiki_url)
    cast_section = wikipedia.page(title=title, auto_suggest=False).section("Cast")
    if cast_section is None:  # section() yields None when the heading is absent
        return []
    names = [line.split(" as ")[0].strip() for line in cast_section.split("\n")]
    names = [name for name in names if name]  # blank lines must not use up slots
    ids = []
    for name in names[:limit]:
        pageid = _resolve_actor_pageid(name)
        if pageid is not None:
            ids.append(pageid)
    return ids


def _predict_genres(plot, cast_ids):
    """Return ``[[genre, probability], ...]`` for *plot* and *cast_ids*.

    Stacks two base models:
    - DistilBERT on the cleaned plot;
    - the cast one-vs-rest LR model on the cast ids present in ``top_actors``.

    Their per-genre probabilities are concatenated (cast model first) and
    scored by ``meta_model``. Rows follow ``_META_LABELS`` order.
    """
    clean_plot = preprocess.apply(plot)

    # Base model 1: DistilBERT.
    _, _, _, tokenized_plot = utility.tokenize(clean_plot, list(_DISTILBERT_LABELS))
    # Add a batch dimension; the inputs must be on the same device as the model.
    input_ids = torch.tensor(
        np.asarray(tokenized_plot["input_ids"])[np.newaxis],
        dtype=torch.long,
        device=device,
    )
    attention_mask = torch.tensor(
        np.asarray(tokenized_plot["attention_mask"])[np.newaxis],
        dtype=torch.long,
        device=device,
    )
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(input_ids=input_ids, attention_mask=attention_mask)["logits"][0]
    distilbert_probs = torch.sigmoid(logits.float().squeeze().cpu())
    # Reorder with one gather: list (advanced) indexing returns a copy.
    # Do NOT swap with `t[i] = t[j]` element by element -- `t[j]` is a view,
    # so a later read sees an already-overwritten slot.
    distilbert_probs = distilbert_probs[_DISTILBERT_TO_META].numpy()

    # Base model 2: LR one-vs-all over the cast members the model knows.
    cast_features = [str(actor) for actor in cast_ids if actor in top_actors]
    lr_probs = np.asarray(lr_model.predict_proba(cast_mlb.transform([cast_features])))[0]

    # Meta model 3: consumes [cast-model probs, DistilBERT probs].
    concat_features = np.concatenate((lr_probs, distilbert_probs))
    probs = meta_model.predict_proba([concat_features])
    return [[label, prob] for label, prob in zip(_META_LABELS, probs[0])]


if st.button("Predict"):
    cast = []
    if wiki_url.strip():
        try:
            cast = _fetch_top_cast_ids(wiki_url)
        except wikipedia.exceptions.WikipediaException as err:
            # Cast lookup is best-effort: continue with an empty cast.
            st.warning(f"Could not fetch cast information from Wikipedia: {err}")
        st.write("Wiki Ids of Top 5 Cast:", cast)
    if text.strip():
        st.write("Genre: ")
        st.write(_predict_genres(text, cast))
    else:
        st.warning("Please enter the plot of the movie.")