""" This code is borrowed from https://github.com/buptLinfy/ZSE-SBIR """ import math import copy import torch import torch.nn as nn import torch.nn.functional as F def clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) class LayerNorm(nn.Module): def __init__(self, features, eps=1e-6): super(LayerNorm, self).__init__() self.a = nn.Parameter(torch.ones(features)) self.b = nn.Parameter(torch.zeros(features)) self.eps = eps def forward(self, x): mean = x.mean(-1, keepdim=True) std = x.std(-1, keepdim=True) return self.a * (x - mean) / (std + self.eps) + self.b class AddAndNorm(nn.Module): def __init__(self, size, dropout): super(AddAndNorm, self).__init__() self.norm = LayerNorm(size) self.dropout = nn.Dropout(dropout) def forward(self, x, y): return self.norm(x + self.dropout(y)) class EncoderLayer(nn.Module): "Encoder is made up of self-attn and feed forward (defined below)" def __init__(self, size, self_attn, feed_forward, dropout): super(EncoderLayer, self).__init__() self.self_attn = self_attn self.feed_forward = feed_forward self.sublayer = clones(AddAndNorm(size, dropout), 2) self.size = size def forward(self, q, k, v, mask): x = self.sublayer[0](v, self.self_attn(q, k, v, mask)) x = self.sublayer[1](x, self.feed_forward(x)) return x class Encoder(nn.Module): def __init__(self, layer, N): super(Encoder, self).__init__() self.layers = clones(layer, N) self.layer1 = clones(layer, N) self.layer2 = clones(layer, N) def forward(self, x_im, x_text, mask): for layer1, layer2 in zip(self.layer1, self.layer2): # 在此交换Q exchange Q here # layer1 处理 sk - layer1 process sk # x_text1 = layer1(x_text, x_im, x_text, mask) # layer2 处理 im - layer2 process im x_im = layer2(x_im, x_text, x_im, mask) # x_sk = x_text1 return x_im def attention(query, key, value, dropout=None, mask=None, pos=None): """ dk = dv = dmodel/h = 64,h=8 """ d_k = query.size(-1) scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) p_attn = F.softmax(scores, dim=-1) if dropout is not None: p_attn = dropout(p_attn) return torch.matmul(p_attn, value), p_attn class MultiHeadedAttention(nn.Module): def __init__(self, h, d_model, dropout=0.1): "Take in model size and number of heads." super(MultiHeadedAttention, self).__init__() assert d_model % h == 0 # We assume d_v always equals d_k self.d_k = d_model // h self.h = h self.linears = clones(nn.Linear(d_model, d_model), 4) self.attn = None self.dropout = nn.Dropout(p=dropout) def forward(self, query, key, value, mask=None): """ :param query: size(batch,seq,512) :param key: :param value: :param mask: :return: """ if mask is not None: # Same mask applied to all h heads. mask = mask.unsqueeze(1) nbatches = query.size(0) # 1) Do all the linear projections in batch from d_model => h x d_k # size(batch,h,seq,dk) query, key, value = \ [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) for lin, x in zip(self.linears, (query, key, value))] # 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) # 3) "Concat" using a view and apply a final linear. 
        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)


class PositionwiseFeedForward(nn.Module):
    """
    d_model = 512 and d_ff = 2048 are the values used in the original paper.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class Cross_Attention(nn.Module):
    def __init__(self, h=8, n=1, d_model=768, d_ff=1024, dropout=0.1):
        super(Cross_Attention, self).__init__()
        multi_head_attention = MultiHeadedAttention(h, d_model)
        ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        encoderLayer = EncoderLayer(d_model, multi_head_attention, ffn, dropout)
        self.encoder = Encoder(encoderLayer, n)
        self.text_projection = nn.Linear(512, d_model)

    def forward(self, x_patch, x_text):
        # Project the 512-dim text features to d_model before cross-attention.
        x_text = self.text_projection(x_text)
        x_sketch = self.encoder(x_patch, x_text, None)  # no mask
        return x_sketch
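

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original ZSE-SBIR code; the tensor
# sizes below are illustrative assumptions). With the wiring in
# Encoder.forward (q = x_patch, k = x_text, v = x_patch), the attention
# weights have shape (batch, h, len_patch, len_text) and are multiplied
# against the patch values, so the text and patch sequences must share the
# same length for the final matmul to be valid.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Cross_Attention(h=8, n=1, d_model=768, d_ff=1024, dropout=0.1)
    x_patch = torch.randn(2, 197, 768)  # (batch, tokens, d_model) ViT-style patch tokens
    x_text = torch.randn(2, 197, 512)   # (batch, tokens, 512); projected to d_model internally
    out = model(x_patch, x_text)
    print(out.shape)  # expected: torch.Size([2, 197, 768])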