Spaces:
Runtime error
Runtime error
Merge branch 'main' of https://huggingface.co/spaces/A-M-S/movie-genre
Browse files
app.py
CHANGED
@@ -16,7 +16,7 @@ st.caption("Either enter Wiki URL or the Cast info of the movie. Cast will be fe
|
|
16 |
wiki_url = st.text_input("Enter Wiki URL of the movie (Needed for fetching the cast information)")
|
17 |
cast_input = st.text_input("Enter Wiki IDs of the cast (Should be separated by comma)")
|
18 |
|
19 |
-
model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-
|
20 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
model.to(device)
|
22 |
|
@@ -61,7 +61,7 @@ if st.button("Predict"):
|
|
61 |
# Use Meta Model approach when cast information is available otherwise use DistilBERT model
|
62 |
if len(cast)!=0:
|
63 |
# Base Model 1: DistilBERT
|
64 |
-
id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller"])
|
65 |
input_ids = [np.asarray(tokenized_plot['input_ids'])]
|
66 |
attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
|
67 |
|
@@ -80,11 +80,17 @@ if st.button("Predict"):
|
|
80 |
|
81 |
# Concatenating Outputs of base models
|
82 |
r1 = distilbert_pred[3]
|
83 |
-
r2 = distilbert_pred[
|
84 |
-
r3 = distilbert_pred[
|
|
|
|
|
|
|
85 |
distilbert_pred[1] = r1
|
86 |
distilbert_pred[2] = r2
|
87 |
distilbert_pred[3] = r3
|
|
|
|
|
|
|
88 |
pred1 = distilbert_pred
|
89 |
pred2 = lr_model_pred
|
90 |
distilbert_pred = pred1.detach().numpy()
|
@@ -95,13 +101,13 @@ if st.button("Predict"):
|
|
95 |
probs = meta_model.predict_proba([concat_features])
|
96 |
|
97 |
# Preparing Output
|
98 |
-
id2label = {0:"Action",1:"Comedy",2:"
|
99 |
i = 0
|
100 |
for prob in probs[0]:
|
101 |
out.append([id2label[i], prob])
|
102 |
i += 1
|
103 |
else:
|
104 |
-
id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller"])
|
105 |
input_ids = [np.asarray(tokenized_plot['input_ids'])]
|
106 |
attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
|
107 |
|
16 |
wiki_url = st.text_input("Enter Wiki URL of the movie (Needed for fetching the cast information)")
|
17 |
cast_input = st.text_input("Enter Wiki IDs of the cast (Should be separated by comma)")
|
18 |
|
19 |
+
model = AutoModelForSequenceClassification.from_pretrained("./checkpoint-49092")
|
20 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
model.to(device)
|
22 |
|
61 |
# Use Meta Model approach when cast information is available otherwise use DistilBERT model
|
62 |
if len(cast)!=0:
|
63 |
# Base Model 1: DistilBERT
|
64 |
+
id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller","Crime","Horror"])
|
65 |
input_ids = [np.asarray(tokenized_plot['input_ids'])]
|
66 |
attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
|
67 |
|
80 |
|
81 |
# Concatenating Outputs of base models
|
82 |
r1 = distilbert_pred[3]
|
83 |
+
r2 = distilbert_pred[5]
|
84 |
+
r3 = distilbert_pred[1]
|
85 |
+
r4 = distilbert_pred[6]
|
86 |
+
r5 = distilbert_pred[2]
|
87 |
+
r6 = distilbert_pred[4]
|
88 |
distilbert_pred[1] = r1
|
89 |
distilbert_pred[2] = r2
|
90 |
distilbert_pred[3] = r3
|
91 |
+
distilbert_pred[4] = r4
|
92 |
+
distilbert_pred[5] = r5
|
93 |
+
distilbert_pred[6] = r6
|
94 |
pred1 = distilbert_pred
|
95 |
pred2 = lr_model_pred
|
96 |
distilbert_pred = pred1.detach().numpy()
|
101 |
probs = meta_model.predict_proba([concat_features])
|
102 |
|
103 |
# Preparing Output
|
104 |
+
id2label = {0: "Action",1: "Comedy",2: "Crime",3: "Drama",4: "Horror",5: "Romance",6: "Thriller"}
|
105 |
i = 0
|
106 |
for prob in probs[0]:
|
107 |
out.append([id2label[i], prob])
|
108 |
i += 1
|
109 |
else:
|
110 |
+
id2label, label2id, tokenizer, tokenized_plot = utility.tokenize(clean_plot, ["Action","Drama", "Romance", "Comedy", "Thriller","Crime","Horror"])
|
111 |
input_ids = [np.asarray(tokenized_plot['input_ids'])]
|
112 |
attention_mask = [np.asarray(tokenized_plot['attention_mask'])]
|
113 |
|