import math

import torch
from torch import nn
from transformers import SiglipVisionModel, SiglipVisionConfig


# 384 / 14 = 27.428571428571427 is not an integer, so the actual position
# embedding has 729 entries and sqrt(729) * 14 = 378; i.e. the implementation
# uses the floor.
class SiglipEncoder(nn.Module):
    """Thin wrapper around ``SiglipVisionModel`` that returns patch features."""

    def __init__(self, vision_config):
        """Build the SigLIP vision tower.

        Args:
            vision_config: Mapping of keyword arguments forwarded verbatim to
                ``SiglipVisionConfig``.
        """
        super().__init__()
        config = SiglipVisionConfig(**vision_config)
        self.model = SiglipVisionModel(config)

    def forward(self, images):
        """Return the ``last_hidden_state`` of the vision model for ``images``."""
        return self.model(images).last_hidden_state


class GLU(nn.Module):
    """Projection followed by a SiLU-gated feed-forward block.

    Computes ``dense_4h_to_h(silu(gate_proj(h)) * dense_h_to_4h(h))`` where
    ``h = gelu(layer_norm(linear_proj(x)))``.
    """

    def __init__(self, args, in_features):
        """Create the layers.

        Args:
            args: Config object; ``hidden_size`` and ``intermediate_size`` are
                read from it.
            in_features: Size of the last dimension of the input tensor.
        """
        super().__init__()
        # NOTE: attribute names and creation order are kept as-is so that
        # checkpoint keys and parameter-initialisation order do not change.
        self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(args.hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )
        self.gate_proj = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )
        self.dense_4h_to_h = nn.Linear(
            args.intermediate_size, args.hidden_size, bias=False
        )

    def forward(self, x):
        """Map ``(..., in_features)`` to ``(..., hidden_size)``."""
        x = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        # Gated unit: the SiLU-activated gate scales the up-projection.
        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
        return self.dense_4h_to_h(x)


class Adapter(nn.Module):
    """Downsample vision tokens 2x2, project them, and wrap with BOI/EOI tokens."""

    def __init__(self, eva_hidden_size, args):
        """Create the adapter layers.

        Args:
            eva_hidden_size: Channel width of the incoming vision features.
            args: Config object; ``hidden_size`` (and, via ``GLU``,
                ``intermediate_size``) are read from it.
        """
        super().__init__()
        # Learned begin-of-image / end-of-image embeddings, initialised to ones.
        self.boi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        self.eoi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        self.conv = nn.Conv2d(
            in_channels=eva_hidden_size,
            out_channels=args.hidden_size,
            kernel_size=2,
            stride=2,
        )
        self.linear_proj = GLU(args, args.hidden_size)

    def forward(self, image_emb):
        """Convert vision tokens into the final image token sequence.

        Args:
            image_emb: Tensor of shape ``(b, s, e)`` where ``s`` is a perfect
                square (tokens of a square patch grid) and ``e`` equals the
                ``eva_hidden_size`` given at construction.

        Returns:
            Tensor of shape ``(b, (g // 2) ** 2 + 2, hidden_size)`` with
            ``g = sqrt(s)``: the BOI token, the projected tokens, the EOI token.

        Raises:
            ValueError: If ``s`` is not a perfect square.
        """
        b, s, e = image_emb.shape
        # Exact integer square root; a float sqrt followed by int() would
        # silently truncate for non-square token counts.
        grid_size = math.isqrt(s)
        if grid_size * grid_size != s:
            raise ValueError(
                f"Adapter expects a square grid of tokens, but got {s} tokens"
            )

        # (b, s, e) -> (b, e, g, g); reshape also accepts non-contiguous input.
        image_emb = image_emb.reshape(b, grid_size, grid_size, e).permute(0, 3, 1, 2)
        image_emb = self.conv(image_emb)  # (b, hidden_size, g // 2, g // 2)
        image_emb = image_emb.flatten(2).transpose(1, 2)  # (b, (g // 2) ** 2, hidden_size)
        image_emb = self.linear_proj(image_emb)  # (b, (g // 2) ** 2, hidden_size)

        # expand() creates broadcast views; torch.cat materialises the result,
        # so no intermediate per-batch copies of boi/eoi are needed.
        boi = self.boi.expand(b, -1, -1)
        eoi = self.eoi.expand(b, -1, -1)
        return torch.cat([boi, image_emb, eoi], dim=1)


class VisionModel(torch.nn.Module):
    """SigLIP encoder followed by the ``Adapter``."""

    def __init__(self, config):
        """Build the encoder and adapter.

        Args:
            config: Config object; ``torch_dtype``, ``vision_config`` (a mapping
                containing at least ``'hidden_size'``) and the attributes read
                by ``Adapter`` are taken from it.
        """
        super().__init__()
        self.dtype = config.torch_dtype
        self.vit = SiglipEncoder(config.vision_config)
        self.adapter = Adapter(config.vision_config['hidden_size'], config)

    def forward(self, image):
        """Encode ``image`` and return adapter output cast to ``self.dtype``."""
        image = image.to(self.dtype)
        vit_output = self.vit(image)
        return self.adapter(vit_output).to(self.dtype)