VictorSanh commited on
Commit
2ca24ef
1 Parent(s): 9947e3f

ops in fp16

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +11 -4
modeling_siglip.py CHANGED
@@ -95,10 +95,11 @@ def _trunc_normal_(tensor, mean, std, a, b):
95
 
96
  # Use inverse cdf transform for normal distribution to get truncated
97
  # standard normal
98
- if tensor.dtype == torch.bfloat16:
 
99
  tensor = tensor.to(torch.float32)
100
  tensor.erfinv_()
101
- tensor = tensor.to(torch.bfloat16)
102
  else:
103
  tensor.erfinv_()
104
 
@@ -107,7 +108,13 @@ def _trunc_normal_(tensor, mean, std, a, b):
107
  tensor.add_(mean)
108
 
109
  # Clamp to ensure it's in the proper range
110
- tensor.clamp_(min=a, max=b)
 
 
 
 
 
 
111
 
112
 
113
  def trunc_normal_tf_(
@@ -732,7 +739,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
732
  nn.init.normal_(module.attention.in_proj_weight.data)
733
  nn.init.zeros_(module.attention.in_proj_bias.data)
734
  elif isinstance(module, SiglipModel):
735
- logit_scale_init = torch.log(torch.tensor(1.0))
736
  module.logit_scale.data.fill_(logit_scale_init)
737
  module.logit_bias.data.zero_()
738
  elif isinstance(module, (nn.Linear, nn.Conv2d)):
 
95
 
96
  # Use inverse cdf transform for normal distribution to get truncated
97
  # standard normal
98
+ if tensor.dtype == torch.float16:
99
+ # The `erfinv_` op is not (yet?) defined in float16
100
  tensor = tensor.to(torch.float32)
101
  tensor.erfinv_()
102
+ tensor = tensor.to(torch.float16)
103
  else:
104
  tensor.erfinv_()
105
 
 
108
  tensor.add_(mean)
109
 
110
  # Clamp to ensure it's in the proper range
111
+ if tensor.dtype == torch.float16:
112
+ # The `clamp_` op is not (yet?) defined in float16
113
+ tensor = tensor.to(torch.float32)
114
+ tensor.clamp_(min=a, max=b)
115
+ tensor = tensor.to(torch.float16)
116
+ else:
117
+ tensor.clamp_(min=a, max=b)
118
 
119
 
120
  def trunc_normal_tf_(
 
739
  nn.init.normal_(module.attention.in_proj_weight.data)
740
  nn.init.zeros_(module.attention.in_proj_bias.data)
741
  elif isinstance(module, SiglipModel):
742
+ logit_scale_init = torch.tensor(0.0)
743
  module.logit_scale.data.fill_(logit_scale_init)
744
  module.logit_bias.data.zero_()
745
  elif isinstance(module, (nn.Linear, nn.Conv2d)):