Demo750 commited on
Commit
a2d4767
1 Parent(s): a05a7bd

Update Predict.py

Browse files
Files changed (1) hide show
  1. Predict.py +53 -1
Predict.py CHANGED
@@ -8,6 +8,9 @@ import sys
8
  import joblib
9
  from DL_models import CustomResNet
10
 
 
 
 
11
  #Ad/Brand Gaze Prediction
12
 
13
  #Now the model is only able to process magazine images or images with full-page counterpages
@@ -264,4 +267,53 @@ def CNN_Prediction(adv_imgs, ctpg_imgs, ad_locations, Gaze_Type='AG'): #Gaze_Typ
264
  pred = torch.exp(pred*a_temp+b_temp) - 1
265
  gaze += pred/10
266
 
267
- return gaze
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import joblib
9
  from DL_models import CustomResNet
10
 
11
+ root = '/Users/jianpingye/Desktop/Marketing_Research/XGBoost_Gaze_Prediction_Platform/Gaze-Time-Prediction-for-Advertisement/XGBoost_Prediction_Model'
12
+ sys.path.append(root)
13
+
14
  #Ad/Brand Gaze Prediction
15
 
16
  #Now the model is only able to process magazine images or images with full-page counterpages
 
267
  pred = torch.exp(pred*a_temp+b_temp) - 1
268
  gaze += pred/10
269
 
270
+ return gaze
271
+
272
+ def HeatMap_CNN(adv_imgs, ctpg_imgs, ad_locations, Gaze_Type='AG'):
273
+ if torch.cuda.is_available():
274
+ device = 'cuda'
275
+ elif torch.backends.mps.is_available():
276
+ device = 'mps'
277
+ else:
278
+ device = 'cpu'
279
+
280
+ net = CustomResNet()
281
+ net.load_state_dict(torch.load('CNN_Gaze_Model/Fine-tune_'+Gaze_Type+'/Model_'+str(0)+'.pth',map_location=torch.device('cpu')))
282
+ net = net.to(device)
283
+ pred = net(adv_imgs/255.0,ctpg_imgs/255.0,ad_locations)
284
+
285
+ pred.backward()
286
+
287
+ # pull the gradients out of the model
288
+ gradients = net.get_activations_gradient()
289
+
290
+ # pool the gradients across the channels
291
+ pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
292
+
293
+ # get the activations of the last convolutional layer
294
+ activations = net.get_activations(adv_imgs).detach()
295
+
296
+ # weight the channels by corresponding gradients
297
+ for i in range(512):
298
+ activations[:, i, :, :] *= pooled_gradients[i]
299
+
300
+ # average the channels of the activations
301
+ heatmap = torch.mean(activations, dim=1).squeeze().to('cpu')
302
+
303
+ # relu on top of the heatmap
304
+ # expression (2) in https://arxiv.org/pdf/1610.02391.pdf
305
+ heatmap = np.maximum(heatmap, 0)
306
+
307
+ # normalize the heatmap
308
+ heatmap /= torch.max(heatmap)
309
+
310
+ img = torch.permute(adv_imgs[0],(1,2,0)).to(torch.uint8).numpy()
311
+ img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
312
+ heatmap = cv.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))
313
+ heatmap = np.uint8(255 * heatmap)
314
+ heatmap = cv.applyColorMap(heatmap, cv.COLORMAP_TURBO)
315
+ superimposed_img = heatmap * 0.8 + img * 0.5
316
+ superimposed_img /= np.max(superimposed_img)
317
+ superimposed_img = np.uint8(255 * superimposed_img)
318
+
319
+ return superimposed_img