jw2yang commited on
Commit
a0e100e
1 Parent(s): 58b8c28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -21,7 +21,7 @@ model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], us
21
  url = 'https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth'
22
  checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
23
  model.load_state_dict(checkpoint["model"])
24
- model = model.cuda(); model.eval()
25
 
26
  '''
27
  build data transform
@@ -65,32 +65,32 @@ def show_cam_on_image(img: np.ndarray,
65
  def classify_image(inp):
66
 
67
  img_t = eval_transforms(inp)
68
- img_d = display_transforms(inp).permute(1, 2, 0).cpu().numpy()
69
  print(img_d.min(), img_d.max())
70
 
71
- prediction = model(img_t.unsqueeze(0).cuda()).softmax(-1).flatten()
72
 
73
  modulator = model.layers[0].blocks[2].modulation.modulator.norm(2, 1, keepdim=True)
74
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
75
- modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
76
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
77
  cam0 = show_cam_on_image(img_d, modulator, use_rgb=True)
78
 
79
  modulator = model.layers[0].blocks[5].modulation.modulator.norm(2, 1, keepdim=True)
80
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
81
- modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
82
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
83
  cam1 = show_cam_on_image(img_d, modulator, use_rgb=True)
84
 
85
  modulator = model.layers[0].blocks[8].modulation.modulator.norm(2, 1, keepdim=True)
86
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
87
- modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
88
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
89
  cam2 = show_cam_on_image(img_d, modulator, use_rgb=True)
90
 
91
  modulator = model.layers[0].blocks[11].modulation.modulator.norm(2, 1, keepdim=True)
92
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
93
- modulator = modulator.squeeze(1).detach().permute(1, 2, 0).cpu().numpy()
94
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
95
  cam3 = show_cam_on_image(img_d, modulator, use_rgb=True)
96
 
21
  url = 'https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_iso_16.pth'
22
  checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
23
  model.load_state_dict(checkpoint["model"])
24
+ model.eval()
25
 
26
  '''
27
  build data transform
65
  def classify_image(inp):
66
 
67
  img_t = eval_transforms(inp)
68
+ img_d = display_transforms(inp).permute(1, 2, 0).numpy()
69
  print(img_d.min(), img_d.max())
70
 
71
+ prediction = model(img_t.unsqueeze(0)).softmax(-1).flatten()
72
 
73
  modulator = model.layers[0].blocks[2].modulation.modulator.norm(2, 1, keepdim=True)
74
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
75
+ modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
76
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
77
  cam0 = show_cam_on_image(img_d, modulator, use_rgb=True)
78
 
79
  modulator = model.layers[0].blocks[5].modulation.modulator.norm(2, 1, keepdim=True)
80
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
81
+ modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
82
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
83
  cam1 = show_cam_on_image(img_d, modulator, use_rgb=True)
84
 
85
  modulator = model.layers[0].blocks[8].modulation.modulator.norm(2, 1, keepdim=True)
86
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
87
+ modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
88
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
89
  cam2 = show_cam_on_image(img_d, modulator, use_rgb=True)
90
 
91
  modulator = model.layers[0].blocks[11].modulation.modulator.norm(2, 1, keepdim=True)
92
  modulator = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(modulator)
93
+ modulator = modulator.squeeze(1).detach().permute(1, 2, 0).numpy()
94
  modulator = (modulator - modulator.min()) / (modulator.max() - modulator.min())
95
  cam3 = show_cam_on_image(img_d, modulator, use_rgb=True)
96