sayakpaul HF staff committed on
Commit
dad7fe2
1 Parent(s): aac57a1

fix: constants.

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -9,12 +9,12 @@ from torchvision.models.feature_extraction import create_feature_extractor
9
  from torchvision.transforms import functional as F
10
  import glob
11
 
12
- cait_model = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
13
- transform = timm.data.create_transform(
14
- **timm.data.resolve_data_config(cait_model.pretrained_cfg)
15
  )
16
 
17
- patch_size = 16
18
 
19
 
20
  def create_attn_extractor(block_id=0):
@@ -23,7 +23,7 @@ def create_attn_extractor(block_id=0):
23
  https://github.com/huggingface/pytorch-image-models/discussions/926
24
  """
25
  feature_extractor = create_feature_extractor(
26
- cait_model,
27
  return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
28
  tracer_kwargs={"leaf_modules": [PatchEmbed]},
29
  )
@@ -34,8 +34,8 @@ def get_cls_attention_map(
34
  image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
35
  ):
36
  """Prepares attention maps so that they can be visualized."""
37
- w_featmap = image.shape[3] // patch_size
38
- h_featmap = image.shape[2] // patch_size
39
 
40
  attention_scores = attn_score_dict[block_key]
41
  nh = attention_scores.shape[1] # Number of attention heads.
@@ -51,7 +51,7 @@ def get_cls_attention_map(
51
  # Resize the attention patches to 224x224 (224: 14x16)
52
  attentions = F.resize(
53
  attentions,
54
- size=(h_featmap * patch_size, w_featmap * patch_size),
55
  interpolation=3,
56
  )
57
  print(attentions.shape)
@@ -85,8 +85,8 @@ def serialize_images(processed_map):
85
  def generate_class_attn_map(image, block_id=0):
86
  """Collates the above utilities together for generating
87
  a class attention map."""
88
- image_tensor = transform(image).unsqueeze(0)
89
- feature_extractor = create_attn_extractor(cait_model, block_id)
90
 
91
  with torch.no_grad():
92
  out = feature_extractor(image_tensor)
 
9
  from torchvision.transforms import functional as F
10
  import glob
11
 
12
+ CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
13
+ TRANSFORM = timm.data.create_transform(
14
+ **timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
15
  )
16
 
17
+ PATCH_SIZE = 16
18
 
19
 
20
  def create_attn_extractor(block_id=0):
 
23
  https://github.com/huggingface/pytorch-image-models/discussions/926
24
  """
25
  feature_extractor = create_feature_extractor(
26
+ CAIT_MODEL,
27
  return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
28
  tracer_kwargs={"leaf_modules": [PatchEmbed]},
29
  )
 
34
  image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
35
  ):
36
  """Prepares attention maps so that they can be visualized."""
37
+ w_featmap = image.shape[3] // PATCH_SIZE
38
+ h_featmap = image.shape[2] // PATCH_SIZE
39
 
40
  attention_scores = attn_score_dict[block_key]
41
  nh = attention_scores.shape[1] # Number of attention heads.
 
51
  # Resize the attention patches to 224x224 (224: 14x16)
52
  attentions = F.resize(
53
  attentions,
54
+ size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
55
  interpolation=3,
56
  )
57
  print(attentions.shape)
 
85
  def generate_class_attn_map(image, block_id=0):
86
  """Collates the above utilities together for generating
87
  a class attention map."""
88
+ image_tensor = TRANSFORM(image).unsqueeze(0)
89
+ feature_extractor = create_attn_extractor(block_id)
90
 
91
  with torch.no_grad():
92
  out = feature_extractor(image_tensor)