xinyu1205 commited on
Commit
f623175
1 Parent(s): a71d73e

Update models/tag2text.py

Browse files
Files changed (1) hide show
  1. models/tag2text.py +14 -2
models/tag2text.py CHANGED
@@ -26,7 +26,14 @@ def read_json(rpath):
26
  with open(rpath, 'r') as f:
27
  return json.load(f)
28
 
 
 
29
  delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
 
 
 
 
 
30
 
31
  class Tag2Text_Caption(nn.Module):
32
  def __init__(self,
@@ -36,7 +43,7 @@ class Tag2Text_Caption(nn.Module):
36
  vit_grad_ckpt = False,
37
  vit_ckpt_layer = 0,
38
  prompt = 'a picture of ',
39
- threshold = 0.7,
40
  ):
41
  """
42
  Args:
@@ -105,6 +112,10 @@ class Tag2Text_Caption(nn.Module):
105
  tie_encoder_decoder_weights(self.tag_encoder,self.vision_multi,'',' ')
106
  self.tag_array = tra_array
107
 
 
 
 
 
108
  def del_selfattention(self):
109
  del self.vision_multi.embeddings
110
  for layer in self.vision_multi.encoder.layer:
@@ -130,7 +141,8 @@ class Tag2Text_Caption(nn.Module):
130
 
131
  logits = self.fc(mlr_tagembedding[0])
132
 
133
- targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
 
134
 
135
  tag = targets.cpu().numpy()
136
  tag[:,delete_tag_index] = 0
 
26
  with open(rpath, 'r') as f:
27
  return json.load(f)
28
 
29
+ # delete some tags that may disturb captioning
30
+ # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
31
  delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
32
+
33
+ # adjust thresholds for some tags
34
+ # default threshold: 0.68
35
+ # 2701: "person"; 2828: "man"; 1167: "woman";
36
+ tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7}
37
 
38
  class Tag2Text_Caption(nn.Module):
39
  def __init__(self,
 
43
  vit_grad_ckpt = False,
44
  vit_ckpt_layer = 0,
45
  prompt = 'a picture of ',
46
+ threshold = 0.68,
47
  ):
48
  """
49
  Args:
 
112
  tie_encoder_decoder_weights(self.tag_encoder,self.vision_multi,'',' ')
113
  self.tag_array = tra_array
114
 
115
+ self.class_threshold = torch.ones(self.num_class) * self.threshold
116
+ for key,value in tag_thrshold.items():
117
+ self.class_threshold[key] = value
118
+
119
  def del_selfattention(self):
120
  del self.vision_multi.embeddings
121
  for layer in self.vision_multi.encoder.layer:
 
141
 
142
  logits = self.fc(mlr_tagembedding[0])
143
 
144
+ # targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
145
+ targets = torch.where(torch.sigmoid(logits) > self.class_threshold.to(image.device) , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
146
 
147
  tag = targets.cpu().numpy()
148
  tag[:,delete_tag_index] = 0