aaronespasa committed on
Commit 02c225a
1 Parent(s): a5e6fcf

Update app.py

Files changed (1)
  1. app.py +21 -4
app.py CHANGED
@@ -6,6 +6,10 @@ import os
 import numpy as np
 from PIL import Image
 import zipfile
+import cv2
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.image import show_cam_on_image
 
 with zipfile.ZipFile("examples.zip","r") as zip_ref:
     zip_ref.extractall(".")
@@ -25,7 +29,7 @@ model = InceptionResnetV1(
     device=DEVICE
 )
 
-checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
+checkpoint = torch.load("resnetinceptionv1_epoch_32.pth")
 model.load_state_dict(checkpoint['model_state_dict'])
 model.to(DEVICE)
 model.eval()
@@ -52,11 +56,24 @@ def predict(input_image:Image.Image, true_label:str):
     face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
 
     # convert the face into a numpy array to be able to plot it
-    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
+    prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
+    prev_face = prev_face.astype('uint8')
 
     face = face.to(DEVICE)
     face = face.to(torch.float32)
     face = face / 255.0
+    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
+
+    target_layers=[model.block8.branch1[-1]]
+    use_cuda = True if torch.cuda.is_available() else False
+    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
+    targets = [ClassifierOutputTarget(0)]
+
+    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
+    grayscale_cam = grayscale_cam[0, :]
+    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
+    face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)
+
     with torch.no_grad():
         output = torch.sigmoid(model(face).squeeze(0))
         prediction = "real" if output.item() < 0.5 else "fake"
@@ -68,7 +85,7 @@ def predict(input_image:Image.Image, true_label:str):
         'real': real_prediction,
         'fake': fake_prediction
     }
-    return confidences, true_label, face_image_to_plot
+    return confidences, true_label, face_with_mask
 
 interface = gr.Interface(
     fn=predict,
@@ -79,7 +96,7 @@ interface = gr.Interface(
     outputs=[
         gr.outputs.Label(label="Class"),
         "text",
-        gr.outputs.Image(label="Face")
+        gr.outputs.Image(label="Face with Explainability")
     ],
     examples=[[examples[i]["path"], examples[i]["label"]] for i in range(10)]
 ).launch()
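For reference, the change wires pytorch-grad-cam into predict(): a GradCAM object built on target layer model.block8.branch1[-1] produces a heat map for the cropped face, show_cam_on_image renders it, and cv2.addWeighted blends it over the original crop, which is then returned through the new "Face with Explainability" output. Below is a minimal, self-contained sketch of that overlay step, assuming a random 256x256 crop and a small stand-in CNN rather than InceptionResnetV1; the committed code also passes use_cuda to GradCAM, which only older pytorch-grad-cam releases accept, so it is left out here.

import cv2
import numpy as np
import torch
import torch.nn as nn
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Stand-in CNN (hypothetical placeholder for InceptionResnetV1); its second
# conv layer plays the role of model.block8.branch1[-1] as the Grad-CAM target.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 1),
).eval()

# Fake 256x256 RGB face crop in [0, 255], standing in for the resized MTCNN crop.
face_uint8 = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
face = torch.from_numpy(face_uint8).permute(2, 0, 1).unsqueeze(0).float() / 255.0

cam = GradCAM(model=model, target_layers=[model[2]])
targets = [ClassifierOutputTarget(0)]  # explain the model's single logit

grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)[0, :]

# show_cam_on_image expects the RGB image as float32 in [0, 1].
heatmap = show_cam_on_image(face_uint8.astype(np.float32) / 255.0, grayscale_cam, use_rgb=True)
face_with_mask = cv2.addWeighted(face_uint8, 1, heatmap, 0.5, 0)
print(face_with_mask.shape, face_with_mask.dtype)  # (256, 256, 3) uint8

ClassifierOutputTarget(0) points Grad-CAM at the single output logit, matching how the app thresholds torch.sigmoid(model(face)) at 0.5 to decide real vs. fake.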