import torch
from torch import nn

from transformers import SiglipVisionModel, SiglipVisionConfig

class SiglipEncoder(nn.Module):
    """Thin wrapper around the Hugging Face SigLIP vision tower."""

    def __init__(self, vision_config):
        super().__init__()
        # vision_config arrives as a plain dict; expand it into the config class.
        config = SiglipVisionConfig(**vision_config)
        self.model = SiglipVisionModel(config)

    def forward(self, images):
        # SigLIP has no CLS token, so every position in last_hidden_state is a
        # patch embedding: (batch, num_patches, hidden_size).
        outputs = self.model(images).last_hidden_state
        return outputs

class GLU(nn.Module):
    """SwiGLU-style projection block mapping vision features into the LM width."""

    def __init__(self, args, in_features):
        super().__init__()
        self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(args.hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.dense_4h_to_h = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)

    def forward(self, x):
        x = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        # Gated linear unit: the silu-activated gate modulates the up projection.
        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
        x = self.dense_4h_to_h(x)
        return x

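# Minimal usage sketch for GLU (sizes are illustrative, not from any shipped
# config): args only needs .hidden_size and .intermediate_size, so a
# SimpleNamespace can stand in for the real config object.
#
#     from types import SimpleNamespace
#     glu = GLU(SimpleNamespace(hidden_size=64, intermediate_size=128), in_features=32)
#     glu(torch.randn(2, 16, 32)).shape  # -> torch.Size([2, 16, 64])
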
class Adapter(nn.Module):
    """Downsamples ViT patch features 2x per side and projects them into the LM."""

    def __init__(self, eva_hidden_size, args):
        super().__init__()
        # Learned begin-of-image / end-of-image marker embeddings.
        self.boi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        self.eoi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        # kernel_size=2, stride=2 halves each spatial side, i.e. 4x fewer tokens.
        self.conv = nn.Conv2d(in_channels=eva_hidden_size, out_channels=args.hidden_size,
                              kernel_size=2, stride=2)
        self.linear_proj = GLU(args, args.hidden_size)

    def forward(self, image_emb):
        b, s, e = image_emb.shape
        # Restore the square patch grid: (b, s, e) -> (b, e, g, g) with g = sqrt(s).
        grid_size = int(s**0.5)
        image_emb = image_emb.view(b, grid_size, grid_size, e).permute(0, 3, 1, 2)
        image_emb = self.conv(image_emb)
        # Back to a token sequence: (b, hidden, g/2, g/2) -> (b, (g/2)**2, hidden).
        image_emb = image_emb.flatten(2).transpose(1, 2)
        image_emb = self.linear_proj(image_emb)
        # Wrap the image tokens with the boi/eoi markers.
        image_emb = torch.cat(
            [self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)],
            dim=1,
        )
        return image_emb

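# Adapter usage sketch (illustrative sizes): 64 patch features of width 32
# form an 8x8 grid; the stride-2 conv leaves 4x4 = 16 tokens, and boi/eoi
# bring the sequence to 18 tokens of width args.hidden_size.
#
#     from types import SimpleNamespace
#     adapter = Adapter(32, SimpleNamespace(hidden_size=64, intermediate_size=128))
#     adapter(torch.randn(2, 64, 32)).shape  # -> torch.Size([2, 18, 64])
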
class VisionModel(torch.nn.Module):
    """SigLIP vision tower plus adapter, producing LM-ready image tokens."""

    def __init__(self, config):
        super().__init__()
        self.dtype = config.torch_dtype
        self.vit = SiglipEncoder(config.vision_config)
        self.adapter = Adapter(config.vision_config['hidden_size'], config)

    def forward(self, image):
        # Cast pixels to the model dtype before the ViT, and cast the adapter
        # output back in case any intermediate op promoted the dtype.
        image = image.to(self.dtype)
        vit_output = self.vit(image)
        return self.adapter(vit_output).to(self.dtype)

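
# Hedged smoke test (not part of the original module): builds a deliberately
# tiny SigLIP tower and pushes one random batch through VisionModel to check
# the end-to-end token count. Every hyperparameter below is illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        torch_dtype=torch.float32,
        hidden_size=64,          # adapter / LM width (assumed value)
        intermediate_size=128,   # GLU expansion width (assumed value)
        vision_config=dict(
            hidden_size=32,
            intermediate_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
            image_size=32,
            patch_size=4,        # 8x8 = 64 patches before the conv downsample
        ),
    )
    model = VisionModel(config)
    tokens = model(torch.randn(2, 3, 32, 32))
    # 64 patches -> stride-2 conv -> 16 tokens, plus boi/eoi -> 18 tokens.
    print(tokens.shape)  # torch.Size([2, 18, 64])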