VictorSanh committed on
Commit
0bfa212
1 Parent(s): 719f253
Files changed (1) hide show
  1. vision.py +7 -7
vision.py CHANGED
@@ -192,7 +192,7 @@ class SiglipVisionModelOutput(ModelOutput):
192
 
193
 
194
  class SiglipVisionEmbeddings(nn.Module):
195
- def __init__(self, config: VMistralVisionConfig):
196
  super().__init__()
197
  self.config = config
198
  self.embed_dim = config.hidden_size
@@ -565,7 +565,7 @@ class SiglipMLP(nn.Module):
565
 
566
  # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
567
  class SiglipEncoderLayer(nn.Module):
568
- def __init__(self, config: VMistralVisionConfig):
569
  super().__init__()
570
  self.embed_dim = config.hidden_size
571
  self.self_attn = (
@@ -1001,7 +1001,7 @@ class SiglipEncoder(nn.Module):
1001
 
1002
 
1003
  class SiglipVisionTransformer(nn.Module):
1004
- def __init__(self, config: VMistralVisionConfig):
1005
  super().__init__()
1006
  self.config = config
1007
  embed_dim = config.hidden_size
@@ -1012,7 +1012,7 @@ class SiglipVisionTransformer(nn.Module):
1012
  self.head = SiglipMultiheadAttentionPoolingHead(config)
1013
 
1014
  # @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1015
- # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=VMistralVisionConfig)
1016
  def forward(
1017
  self,
1018
  pixel_values,
@@ -1058,7 +1058,7 @@ class SiglipVisionTransformer(nn.Module):
1058
  class SiglipMultiheadAttentionPoolingHead(nn.Module):
1059
  """Multihead Attention Pooling."""
1060
 
1061
- def __init__(self, config: VMistralVisionConfig):
1062
  super().__init__()
1063
 
1064
  self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@@ -1084,7 +1084,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1084
  # SIGLIP_START_DOCSTRING,
1085
  # )
1086
  class SiglipVisionModel(nn.Module):
1087
- def __init__(self, config: VMistralVisionConfig):
1088
  super().__init__()
1089
 
1090
  self.vision_model = SiglipVisionTransformer(config)
@@ -1096,7 +1096,7 @@ class SiglipVisionModel(nn.Module):
1096
  # return self.vision_model.embeddings.patch_embedding
1097
 
1098
  # @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1099
- # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=VMistralVisionConfig)
1100
  def forward(
1101
  self,
1102
  pixel_values,
 
192
 
193
 
194
  class SiglipVisionEmbeddings(nn.Module):
195
+ def __init__(self, config: Img2HTMLVisionConfig):
196
  super().__init__()
197
  self.config = config
198
  self.embed_dim = config.hidden_size
 
565
 
566
  # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
567
  class SiglipEncoderLayer(nn.Module):
568
+ def __init__(self, config: Img2HTMLVisionConfig):
569
  super().__init__()
570
  self.embed_dim = config.hidden_size
571
  self.self_attn = (
 
1001
 
1002
 
1003
  class SiglipVisionTransformer(nn.Module):
1004
+ def __init__(self, config: Img2HTMLVisionConfig):
1005
  super().__init__()
1006
  self.config = config
1007
  embed_dim = config.hidden_size
 
1012
  self.head = SiglipMultiheadAttentionPoolingHead(config)
1013
 
1014
  # @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1015
+ # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Img2HTMLVisionConfig)
1016
  def forward(
1017
  self,
1018
  pixel_values,
 
1058
  class SiglipMultiheadAttentionPoolingHead(nn.Module):
1059
  """Multihead Attention Pooling."""
1060
 
1061
+ def __init__(self, config: Img2HTMLVisionConfig):
1062
  super().__init__()
1063
 
1064
  self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
 
1084
  # SIGLIP_START_DOCSTRING,
1085
  # )
1086
  class SiglipVisionModel(nn.Module):
1087
+ def __init__(self, config: Img2HTMLVisionConfig):
1088
  super().__init__()
1089
 
1090
  self.vision_model = SiglipVisionTransformer(config)
 
1096
  # return self.vision_model.embeddings.patch_embedding
1097
 
1098
  # @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1099
+ # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Img2HTMLVisionConfig)
1100
  def forward(
1101
  self,
1102
  pixel_values,