A-M-S committed on
Commit
27d72f6
2 Parent(s): f08cdc3 7accc69

Merge branch 'main' of https://huggingface.co/spaces/A-M-S/movie-genre

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -16,7 +16,7 @@ st.caption("Either enter Wiki URL or the Cast info of the movie. Cast will be fe
16
  wiki_url = st.text_input("Enter Wiki URL of the movie (Needed for fetching the cast information)")
17
  cast_input = st.text_input("Enter Wiki IDs of the cast (Should be separated by comma)")
18
 
19
- model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-36819")
20
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
  model.to(device)
22
 
@@ -61,7 +61,7 @@ if st.button("Predict"):
61
  # Use Meta Model approach when cast information is available otherwise use DistilBERT model
62
  if len(cast)!=0:
63
  # Base Model 1: DistilBERT
64
- id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller"])
65
  input_ids = [np.asarray(tokenized_plot['input_ids'])]
66
  attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
67
 
@@ -80,11 +80,17 @@ if st.button("Predict"):
80
 
81
  # Concatenating Outputs of base models
82
  r1 = distilbert_pred[3]
83
- r2 = distilbert_pred[1]
84
- r3 = distilbert_pred[2]
 
 
 
85
  distilbert_pred[1] = r1
86
  distilbert_pred[2] = r2
87
  distilbert_pred[3] = r3
 
 
 
88
  pred1 = distilbert_pred
89
  pred2 = lr_model_pred
90
  distilbert_pred = pred1.detach().numpy()
@@ -95,13 +101,13 @@ if st.button("Predict"):
95
  probs = meta_model.predict_proba([concat_features])
96
 
97
  # Preparing Output
98
- id2label = {0:"Action",1:"Comedy",2:"Drama",3:"Romance",4:"Thriller"}
99
  i = 0
100
  for prob in probs[0]:
101
  out.append([id2label[i], prob])
102
  i += 1
103
  else:
104
- id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller"])
105
  input_ids = [np.asarray(tokenized_plot['input_ids'])]
106
  attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
107
 
16
  wiki_url = st.text_input("Enter Wiki URL of the movie (Needed for fetching the cast information)")
17
  cast_input = st.text_input("Enter Wiki IDs of the cast (Should be separated by comma)")
18
 
19
+ model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-49092")
20
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
  model.to(device)
22
 
61
  # Use Meta Model approach when cast information is available otherwise use DistilBERT model
62
  if len(cast)!=0:
63
  # Base Model 1: DistilBERT
64
+ id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller","Crime","Horror"])
65
  input_ids = [np.asarray(tokenized_plot['input_ids'])]
66
  attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
67
 
80
 
81
  # Concatenating Outputs of base models
82
  r1 = distilbert_pred[3]
83
+ r2 = distilbert_pred[5]
84
+ r3 = distilbert_pred[1]
85
+ r4 = distilbert_pred[6]
86
+ r5 = distilbert_pred[2]
87
+ r6 = distilbert_pred[4]
88
  distilbert_pred[1] = r1
89
  distilbert_pred[2] = r2
90
  distilbert_pred[3] = r3
91
+ distilbert_pred[4] = r4
92
+ distilbert_pred[5] = r5
93
+ distilbert_pred[6] = r6
94
  pred1 = distilbert_pred
95
  pred2 = lr_model_pred
96
  distilbert_pred = pred1.detach().numpy()
101
  probs = meta_model.predict_proba([concat_features])
102
 
103
  # Preparing Output
104
+ id2label = {0: "Action",1: "Comedy",2: "Crime",3: "Drama",4: "Horror",5: "Romance",6: "Thriller"}
105
  i = 0
106
  for prob in probs[0]:
107
  out.append([id2label[i], prob])
108
  i += 1
109
  else:
110
+ id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller","Crime","Horror"])
111
  input_ids = [np.asarray(tokenized_plot['input_ids'])]
112
  attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
113