A-M-S commited on
Commit
8fd825b
2 Parent(s): 3788106 7496ff3

Merge branch 'main' of https://huggingface.co/spaces/A-M-S/movie-genre

Browse files
Files changed (1) hide show
  1. app.py +78 -56
app.py CHANGED
@@ -11,9 +11,10 @@ from utility import Utility
11
 
12
  st.title("Movie Genre Predictor")
13
 
14
- st.subheader("Enter the text you'd like to analyze.")
15
  text = st.text_input('Enter plot of the movie')
16
- wiki_url = st.text_input("Enter wikipedia url of the movie (Needed for fetching the cast information)")
 
 
17
 
18
  model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-36819")
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -28,69 +29,90 @@ meta_model = pickle.load(open("models/meta_model","rb"))
28
 
29
  utility = Utility()
30
  preprocess = Preprocess()
 
31
 
32
  if st.button("Predict"):
33
  cast = []
34
- if len(wiki_url)!=0:
35
- cast_wiki = wikipedia.page(title=wiki_url.split("/")[-1].replace("_"," "), auto_suggest=False).section("Cast")
36
- cast_names = [val.split(" as ")[0] for val in cast_wiki.split("\n")]
37
- for actor in cast_names[:5]:
 
 
38
  try:
39
- cast.append(wikipedia.page(title=actor).pageid)
40
  except:
41
- search_results = wikipedia.search(actor,results=2)
42
- try:
43
- cast.append(wikipedia.page(title=search_results[0]).pageid)
44
- except:
45
- try:
46
- cast.append(wikipedia.page(title=search_results(actor)[1]).pageid)
47
- except:
48
- pass
49
 
50
- st.write("Wiki Ids of Top 5 Cast:",cast)
 
 
 
51
  st.write("Genre: ")
52
 
53
  clean_plot = preprocess.apply(text)
54
 
55
- # Base Model 1: DistilBERT
56
- id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller"])
57
- input_ids = [np.asarray(tokenized_plot['input_ids'])]
58
- attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
59
-
60
- y_pred = model(torch.IntTensor(input_ids), torch.IntTensor(attention_mask))
61
- pred = torch.FloatTensor(y_pred['logits'][0])
62
-
63
- sigmoid = torch.nn.Sigmoid()
64
- distilbert_pred = sigmoid(pred.squeeze().cpu())
65
-
66
- # Base model 2: LR One Vs All
67
- cast_features = []
68
- for actor in cast:
69
- if actor in top_actors:
70
- cast_features.append(str(actor))
71
- lr_model_pred = lr_model.predict_proba(cast_mlb.transform([cast_features]))
72
-
73
- # Concatenating Outputs of base models
74
- r1 = distilbert_pred[3]
75
- r2 = distilbert_pred[1]
76
- r3 = distilbert_pred[2]
77
- distilbert_pred[1] = r1
78
- distilbert_pred[2] = r2
79
- distilbert_pred[3] = r3
80
- pred1 = distilbert_pred
81
- pred2 = lr_model_pred
82
- distilbert_pred = pred1.detach().numpy()
83
- lr_model_pred = np.array(pred2)[0]
84
- concat_features = np.concatenate((lr_model_pred,distilbert_pred))
85
-
86
- # Meta model 3: LR One Vs All
87
- probs = meta_model.predict_proba([concat_features])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Preparing Output
90
- out = []
91
- id2label = {0:"Action",1:"Comedy",2:"Drama",3:"Romance",4:"Thriller"}
92
- i = 0
93
- for prob in probs[0]:
94
- out.append([id2label[i], prob])
95
- i += 1
96
  st.write(out)
 
11
 
12
  st.title("Movie Genre Predictor")
13
 
 
14
  text = st.text_input('Enter plot of the movie')
15
+ st.caption("Either enter Wiki URL or the Cast info of the movie. Cast will be fetched from the Wiki page if cast is not provided")
16
+ wiki_url = st.text_input("Enter Wiki URL of the movie (Needed for fetching the cast information)")
17
+ cast_input = st.text_input("Enter Wiki IDs of the cast (Should be separated by comma)")
18
 
19
  model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-36819")
20
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
29
 
30
  utility = Utility()
31
  preprocess = Preprocess()
32
+ out = []
33
 
34
  if st.button("Predict"):
35
  cast = []
36
+ if len(wiki_url)!=0 and len(cast_input)==0:
37
+ html_page = wikipedia.page(title=wiki_url.split("/")[-1].replace("_"," "), auto_suggest=False).html()
38
+ cast_wiki = html_page.split(" title=\"Edit section: Cast\">edit</a>")[-1]
39
+ anchor_tags = cast_wiki.split("<a href=")[1:6]
40
+ top5_cast_links = [val.split("\"")[1] for val in anchor_tags]
41
+ for actor in top5_cast_links:
42
  try:
43
+ cast.append(wikipedia.page(title=actor.split("/")[-1].replace("_"," ")).pageid)
44
  except:
45
+ pass
46
+ else:
47
+ if ", " in cast_input:
48
+ cast = cast_input.split(", ")
49
+ else:
50
+ cast = cast_input.split(",")
 
 
51
 
52
+ cast_str = ""
53
+ for actor in cast:
54
+ cast_str += actor + ", "
55
+ st.write("Wiki Ids of Top 5 Cast:",cast_str)
56
  st.write("Genre: ")
57
 
58
  clean_plot = preprocess.apply(text)
59
 
60
+ # Use Meta Model approach when cast information is available otherwise use DistilBERT model
61
+ if len(cast)!=0:
62
+ # Base Model 1: DistilBERT
63
+ id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller"])
64
+ input_ids = [np.asarray(tokenized_plot['input_ids'])]
65
+ attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
66
+
67
+ y_pred = model(torch.IntTensor(input_ids), torch.IntTensor(attention_mask))
68
+ pred = torch.FloatTensor(y_pred['logits'][0])
69
+
70
+ sigmoid = torch.nn.Sigmoid()
71
+ distilbert_pred = sigmoid(pred.squeeze().cpu())
72
+
73
+ # Base model 2: LR One Vs All
74
+ cast_features = []
75
+ for actor in cast:
76
+ if actor in top_actors:
77
+ cast_features.append(str(actor))
78
+ lr_model_pred = lr_model.predict_proba(cast_mlb.transform([cast_features]))
79
+
80
+ # Concatenating Outputs of base models
81
+ r1 = distilbert_pred[3]
82
+ r2 = distilbert_pred[1]
83
+ r3 = distilbert_pred[2]
84
+ distilbert_pred[1] = r1
85
+ distilbert_pred[2] = r2
86
+ distilbert_pred[3] = r3
87
+ pred1 = distilbert_pred
88
+ pred2 = lr_model_pred
89
+ distilbert_pred = pred1.detach().numpy()
90
+ lr_model_pred = np.array(pred2)[0]
91
+ concat_features = np.concatenate((lr_model_pred,distilbert_pred))
92
+
93
+ # Meta model 3: LR One Vs All
94
+ probs = meta_model.predict_proba([concat_features])
95
+
96
+ # Preparing Output
97
+ id2label = {0:"Action",1:"Comedy",2:"Drama",3:"Romance",4:"Thriller"}
98
+ i = 0
99
+ for prob in probs[0]:
100
+ out.append([id2label[i], prob])
101
+ i += 1
102
+ else:
103
+ id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller"])
104
+ input_ids = [np.asarray(tokenized_plot['input_ids'])]
105
+ attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
106
+
107
+ y_pred = model(torch.IntTensor(input_ids), torch.IntTensor(attention_mask))
108
+ pred = torch.FloatTensor(y_pred['logits'][0])
109
+
110
+ sigmoid = torch.nn.Sigmoid()
111
+ probs = sigmoid(pred.squeeze().cpu())
112
+
113
+ i = 0
114
+ for prob in probs:
115
+ out.append([id2label[i], np.asscalar(prob)])
116
+ i += 1
117
 
 
 
 
 
 
 
 
118
  st.write(out)