import torch
import torch.nn as nn


class ContextualAlphaMask(nn.Module):
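    """Small MLP gate: maps each token embedding of size `dim` to a scalar
    in (0, 1) via a sigmoid. The final layer's weights are initialized near
    zero so the gate starts out roughly uniform (~0.5) at the start of
    training."""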
    def __init__(
            self,
            dim: int = 768,
    ):
        super().__init__()
        self.dim = dim
        half_dim = dim // 2
        quarter_dim = dim // 4
        self.fc1 = nn.Linear(self.dim, self.dim)
        self.fc2 = nn.Linear(self.dim, half_dim)
        self.norm1 = nn.LayerNorm(half_dim)
        self.fc3 = nn.Linear(half_dim, half_dim)
        self.fc4 = nn.Linear(half_dim, quarter_dim)
        self.norm2 = nn.LayerNorm(quarter_dim)
        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
        self.fc6 = nn.Linear(quarter_dim, 1)
        # initialize fc6 weights near zero so the sigmoid gate starts out
        # close to 0.5 for every token
        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
        self.act_fn = nn.GELU()

    def forward(self, x):
        # x: (batch_size, num_tokens, dim), e.g. (batch_size, 77, 768)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x = self.norm1(x)
        x = self.act_fn(x)
        x = self.fc3(x)
        x = self.act_fn(x)
        x = self.fc4(x)
        x = self.norm2(x)
        x = self.act_fn(x)
        x = self.fc5(x)
        x = self.act_fn(x)
        x = self.fc6(x)
        x = torch.sigmoid(x)
        return x


class ZipperModule(nn.Module):
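    """Single resampling block. A 1x1 Conv1d remaps the token axis
    (in_tokens -> hidden_tokens -> out_tokens) while Linear layers remap the
    feature axis (in_size -> hidden_size -> out_size), with GELU activations
    in between and an optional residual connection (only valid when input and
    output shapes match)."""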
    def __init__(
            self,
            in_size,
            in_tokens,
            out_size,
            out_tokens,
            hidden_size,
            hidden_tokens,
            use_residual=False,
    ):
        super().__init__()
        self.in_size = in_size
        self.in_tokens = in_tokens
        self.out_size = out_size
        self.out_tokens = out_tokens
        self.hidden_size = hidden_size
        self.hidden_tokens = hidden_tokens
        self.use_residual = use_residual
        self.act_fn = nn.GELU()
        self.layernorm = nn.LayerNorm(self.in_size)
        self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1)
        # act
        self.fc1 = nn.Linear(self.in_size, self.hidden_size)
        # act
        self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1)
        # act
        self.fc2 = nn.Linear(self.hidden_size, self.out_size)

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.conv1(x)
        x = self.act_fn(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.conv2(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            x = x + residual
        return x


class ZipperResampler(nn.Module):
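    """Stacks `num_blocks` ZipperModules to map a (batch, in_tokens, in_size)
    sequence to (batch, out_tokens, out_size), then scales the result by a
    per-token ContextualAlphaMask gate. If `is_conv_input` is True, a
    (batch, channels, H, W) feature map is first flattened and transposed to
    (batch, H*W, channels)."""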
    def __init__(
            self,
            in_size,
            in_tokens,
            out_size,
            out_tokens,
            hidden_size,
            hidden_tokens,
            num_blocks=1,
            is_conv_input=False,
    ):
        super().__init__()
        self.is_conv_input = is_conv_input
        module_list = []
        for i in range(num_blocks):
            this_in_size = in_size
            this_in_tokens = in_tokens
            this_out_size = out_size
            this_out_tokens = out_tokens
            this_hidden_size = hidden_size
            this_hidden_tokens = hidden_tokens
            use_residual = False
            # maintain middle sizes as hidden_size
            if i == 0:  # first block
                this_in_size = in_size
                this_in_tokens = in_tokens
                if num_blocks == 1:
                    this_out_size = out_size
                    this_out_tokens = out_tokens
                else:
                    this_out_size = hidden_size
                    this_out_tokens = hidden_tokens
            elif i == num_blocks - 1:  # last block (num_blocks == 1 is already covered by the first-block branch)
                this_out_size = out_size
                this_out_tokens = out_tokens
                if num_blocks == 1:
                    this_in_size = in_size
                    this_in_tokens = in_tokens
                else:
                    this_in_size = hidden_size
                    this_in_tokens = hidden_tokens
            else:  # middle blocks
                this_out_size = hidden_size
                this_out_tokens = hidden_tokens
                this_in_size = hidden_size
                this_in_tokens = hidden_tokens
                use_residual = True
            module_list.append(ZipperModule(
                in_size=this_in_size,
                in_tokens=this_in_tokens,
                out_size=this_out_size,
                out_tokens=this_out_tokens,
                hidden_size=this_hidden_size,
                hidden_tokens=this_hidden_tokens,
                use_residual=use_residual
            ))
        self.blocks = nn.ModuleList(module_list)
        self.ctx_alpha = ContextualAlphaMask(
            dim=out_size,
        )

    def forward(self, x):
        if self.is_conv_input:
            # flatten
            x = x.view(x.size(0), x.size(1), -1)
            # rearrange to (batch, tokens, size)
            x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)
        alpha = self.ctx_alpha(x)
        return x * alpha
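

if __name__ == "__main__":
    # Minimal usage sketch: the shapes below are hypothetical, chosen only to
    # illustrate resampling a sequence of 257 tokens of width 1024 down to
    # 77 tokens of width 768 with two ZipperModule blocks.
    resampler = ZipperResampler(
        in_size=1024,
        in_tokens=257,
        out_size=768,
        out_tokens=77,
        hidden_size=512,
        hidden_tokens=128,
        num_blocks=2,
    )
    dummy = torch.randn(2, 257, 1024)  # (batch, tokens, size)
    out = resampler(dummy)
    print(out.shape)  # expected: torch.Size([2, 77, 768])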