jw2yang commited on
Commit
6b79e13
·
1 Parent(s): d71a454

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -70,25 +70,25 @@ def classify_image(inp):
70
 
71
  prediction = model(img_t.unsqueeze(0)).softmax(-1).flatten()
72
 
73
- modulator = model.layers[0].blocks[2].modulation.modulator.norm(2, 1, keepdim=True)
74
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
75
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
76
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
77
  cam0 = show_cam_on_image(img_d, modulator, use_rgb=True)
78
 
79
- modulator = model.layers[0].blocks[5].modulation.modulator.norm(2, 1, keepdim=True)
80
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
81
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
82
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
83
  cam1 = show_cam_on_image(img_d, modulator, use_rgb=True)
84
 
85
- modulator = model.layers[0].blocks[8].modulation.modulator.norm(2, 1, keepdim=True)
86
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
87
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
88
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
89
  cam2 = show_cam_on_image(img_d, modulator, use_rgb=True)
90
 
91
- modulator = model.layers[0].blocks[11].modulation.modulator.norm(2, 1, keepdim=True)
92
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
93
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
94
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
@@ -107,16 +107,16 @@ gr.Interface(
107
  outputs=[
108
  gr.outputs.Image(
109
  type="pil",
110
- label="Modulator at layer 3"),
111
  gr.outputs.Image(
112
  type="pil",
113
- label="Modulator at layer 6"),
114
  gr.outputs.Image(
115
  type="pil",
116
- label="Modulator at layer 9"),
117
  gr.outputs.Image(
118
  type="pil",
119
- label="Modulator at layer 12"),
120
  label,
121
  ],
122
  examples=[["./donut.png"], ["./horses.png"], ["./pencil.png"]],
 
70
 
71
  prediction = model(img_t.unsqueeze(0)).softmax(-1).flatten()
72
 
73
+ modulator = model.layers[0].blocks[11].modulation.modulator.norm(2, 1, keepdim=True)
74
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
75
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
76
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
77
  cam0 = show_cam_on_image(img_d, modulator, use_rgb=True)
78
 
79
+ modulator = model.layers[0].blocks[8].modulation.modulator.norm(2, 1, keepdim=True)
80
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
81
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
82
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
83
  cam1 = show_cam_on_image(img_d, modulator, use_rgb=True)
84
 
85
+ modulator = model.layers[0].blocks[5].modulation.modulator.norm(2, 1, keepdim=True)
86
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
87
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
88
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
89
  cam2 = show_cam_on_image(img_d, modulator, use_rgb=True)
90
 
91
+ modulator = model.layers[0].blocks[2].modulation.modulator.norm(2, 1, keepdim=True)
92
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
93
  modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
94
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
 
107
  outputs=[
108
  gr.outputs.Image(
109
  type="pil",
110
+ label="Modulator at layer 12"),
111
  gr.outputs.Image(
112
  type="pil",
113
+ label="Modulator at layer 9"),
114
  gr.outputs.Image(
115
  type="pil",
116
+ label="Modulator at layer 6"),
117
  gr.outputs.Image(
118
  type="pil",
119
+ label="Modulator at layer 3"),
120
  label,
121
  ],
122
  examples=[["./donut.png"], ["./horses.png"], ["./pencil.png"]],