Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 4,561 Bytes
			
			| fcc02a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | import torch
import torch.nn as nn
from toolkit.models.zipper_resampler import ContextualAlphaMask
# Conv1d MLP
# MLP that can alternately be used as a conv1d on dim 1
class MLPC(nn.Module):
    """Two-layer GELU perceptron built from linear layers or 1x1 convolutions.

    With ``do_conv=False`` the input is layer-normalised over its last
    dimension (which must have size ``in_dim``) and then passed through two
    ``nn.Linear`` layers. With ``do_conv=True`` the two layers are
    ``nn.Conv1d`` modules with kernel size 1, so they mix dimension 1 of the
    input instead, and no normalisation layer is created or applied.

    Args:
        in_dim: number of input features (linear) or input channels (conv).
        out_dim: number of output features (linear) or output channels (conv).
        hidden_dim: width between the two layers.
        do_conv: build kernel-size-1 ``nn.Conv1d`` layers instead of
            ``nn.Linear`` layers.
        use_residual: add the untouched input back onto the output. Requires
            ``in_dim == out_dim`` (checked with an ``assert``).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, do_conv=False, use_residual=True):
        super().__init__()
        self.do_conv = do_conv
        self.use_residual = use_residual

        # A skip connection only makes sense when the widths line up.
        if use_residual:
            assert in_dim == out_dim

        if do_conv:
            # Conv path: deliberately no normalisation.
            first = nn.Conv1d(in_dim, hidden_dim, 1)
            second = nn.Conv1d(hidden_dim, out_dim, 1)
        else:
            self.layernorm = nn.LayerNorm(in_dim)
            first = nn.Linear(in_dim, hidden_dim)
            second = nn.Linear(hidden_dim, out_dim)

        self.fc1 = first
        self.fc2 = second
        self.act_fn = nn.GELU()

    def forward(self, x):
        """Apply (optional norm) -> fc1 -> GELU -> fc2 -> (optional skip)."""
        hidden = x if self.do_conv else self.layernorm(x)
        out = self.fc2(self.act_fn(self.fc1(hidden)))
        return out + x if self.use_residual else out
class ZipperBlock(nn.Module):
    """Resize a token sequence along its token axis, then its feature axis.

    The input is expected as ``(batch_size, in_tokens, in_size)``. Two
    ``MLPC`` stages are applied in order, neither with a residual connection:

    1. ``zip_token`` runs in conv mode, so it treats dimension 1 (the tokens)
       as channels and maps ``in_tokens -> out_tokens`` without any permute.
    2. ``zip_size`` runs in linear mode on the last dimension and maps
       ``in_size -> out_size``.

    The result therefore has shape ``(batch_size, out_tokens, out_size)``.

    Args:
        in_size: width of each incoming token vector.
        in_tokens: number of incoming tokens.
        out_size: width of each outgoing token vector.
        out_tokens: number of outgoing tokens.
        hidden_size: hidden width of the feature-axis MLP.
        hidden_tokens: hidden width of the token-axis MLP.
    """

    def __init__(self, in_size, in_tokens, out_size, out_tokens, hidden_size, hidden_tokens):
        super().__init__()
        self.in_size, self.in_tokens = in_size, in_tokens
        self.out_size, self.out_tokens = out_size, out_tokens
        self.hidden_size, self.hidden_tokens = hidden_size, hidden_tokens

        # Token axis first: kernel-size-1 convs act on dim 1 directly.
        self.zip_token = MLPC(
            in_tokens,
            out_tokens,
            hidden_tokens,
            do_conv=True,
            use_residual=False,
        )
        # Then the feature axis, via layer norm + linear layers on the last dim.
        self.zip_size = MLPC(
            in_size,
            out_size,
            hidden_size,
            use_residual=False,
        )

    def forward(self, x):
        """Map ``(batch, in_tokens, in_size)`` to ``(batch, out_tokens, out_size)``."""
        return self.zip_size(self.zip_token(x))
# CLIPFusionModule
# Fuses vision and text embeddings of any size into a single embedding.
# Remaps both the token count and the vector size.
class CLIPFusionModule(nn.Module):
    """Blend vision embeddings into text embeddings of a different shape.

    The vision embeddings are first resampled by a ``ZipperBlock`` to the
    text embeddings' token count and hidden size. Each of the ``num_blocks``
    fusion blocks then concatenates the text embeddings with the running
    result along the last dimension, projects back to the text hidden size,
    and adds the running result as a residual. Finally the result is scaled
    by the output of ``ContextualAlphaMask`` (evaluated on the text
    embeddings) and by a learned per-token ``alpha``, and added onto the
    text embeddings.

    Args:
        text_hidden_size: width of each text token vector.
        text_tokens: number of text tokens.
        vision_hidden_size: width of each vision token vector.
        vision_tokens: number of vision tokens.
        num_blocks: how many concat-and-project fusion blocks to stack.
    """

    def __init__(
            self,
            text_hidden_size: int = 768,
            text_tokens: int = 77,
            vision_hidden_size: int = 1024,
            vision_tokens: int = 257,
            num_blocks: int = 1,
    ):
        super().__init__()
        self.text_hidden_size = text_hidden_size
        self.text_tokens = text_tokens
        self.vision_hidden_size = vision_hidden_size
        self.vision_tokens = vision_tokens

        # Brings the vision sequence to the text sequence's shape.
        self.resampler = ZipperBlock(
            in_size=vision_hidden_size,
            in_tokens=vision_tokens,
            out_size=text_hidden_size,
            out_tokens=text_tokens,
            hidden_size=vision_hidden_size * 2,
            hidden_tokens=vision_tokens * 2,
        )

        # Each fusion block sees text and vision features side by side, hence
        # the doubled input width; the token count is left unchanged.
        fusion_blocks = [
            ZipperBlock(
                in_size=text_hidden_size * 2,
                in_tokens=text_tokens,
                out_size=text_hidden_size,
                out_tokens=text_tokens,
                hidden_size=text_hidden_size * 2,
                hidden_tokens=text_tokens * 2,
            )
            for _ in range(num_blocks)
        ]
        self.zipper_blocks = torch.nn.ModuleList(fusion_blocks)

        self.ctx_alpha = ContextualAlphaMask(dim=text_hidden_size)

        # One learnable scale per text token, initialised to 0.01.
        initial_alpha = torch.zeros([text_tokens]) + 0.01
        self.alpha = nn.Parameter(initial_alpha)

    def forward(self, text_embeds, vision_embeds):
        """Return the text embeddings plus a gated vision contribution.

        With the default sizes, ``text_embeds`` is ``(batch_size, 77, 768)``,
        ``vision_embeds`` is ``(batch_size, 257, 1024)`` and the output is
        ``(batch_size, 77, 768)``.
        """
        fused = self.resampler(vision_embeds)

        for block in self.zipper_blocks:
            joined = torch.cat([text_embeds, fused], dim=-1)
            fused = block(joined) + fused

        # Gate computed from the text embeddings.
        gate = self.ctx_alpha(text_embeds)
        # Reshape alpha from (text_tokens,) to (1, text_tokens, 1) for broadcasting.
        per_token = self.alpha.unsqueeze(0).unsqueeze(-1)

        return gate * fused * per_token + text_embeds
 | 
