"""
This code is borrowed from https://github.com/buptLinfy/ZSE-SBIR
"""
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
def clones(module, N):
    "Produce N identical copies of a module as an nn.ModuleList."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class LayerNorm(nn.Module):
    "Layer normalization with learnable per-feature scale (a) and shift (b)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a = nn.Parameter(torch.ones(features))
        self.b = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a * (x - mean) / (std + self.eps) + self.b
class AddAndNorm(nn.Module):
    "Residual connection followed by layer normalization (post-norm), with dropout on the sublayer output."

    def __init__(self, size, dropout):
        super(AddAndNorm, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y):
        return self.norm(x + self.dropout(y))
class EncoderLayer(nn.Module):
    "EncoderLayer is made up of self-attn and feed forward (defined below), each wrapped in AddAndNorm."

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(AddAndNorm(size, dropout), 2)
        self.size = size

    def forward(self, q, k, v, mask):
        # The residual connection is taken around v (the value branch).
        x = self.sublayer[0](v, self.self_attn(q, k, v, mask))
        x = self.sublayer[1](x, self.feed_forward(x))
        return x
class Encoder(nn.Module):
    "Stack of N cross-attention encoder layers operating on the image-patch branch."

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)  # defined but not used in forward
        self.layer1 = clones(layer, N)
        self.layer2 = clones(layer, N)

    def forward(self, x_im, x_text, mask):
        for layer1, layer2 in zip(self.layer1, self.layer2):
            # Exchange Q here.
            # layer1 would process the sketch/text (sk) branch (currently disabled):
            # x_text1 = layer1(x_text, x_im, x_text, mask)
            # layer2 processes the image (im) branch:
            x_im = layer2(x_im, x_text, x_im, mask)
            # x_sk = x_text1
        return x_im
def attention(query, key, value, dropout=None, mask=None, pos=None):
    """
    Scaled dot-product attention.
    d_k = d_v = d_model / h (e.g. 64 when d_model = 512 and h = 8).
    """
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
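# Shape walk-through (illustrative note, not from the original source): with
# query (B, h, L_q, d_k), key (B, h, L_k, d_k) and value (B, h, L_k, d_k),
# `scores` and `p_attn` are (B, h, L_q, L_k) and the returned context is
# (B, h, L_q, d_k); key and value must therefore share the length L_k.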
class MultiHeadedAttention(nn.Module):
def __init__(self, h, d_model, dropout=0.1):
"Take in model size and number of heads."
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
# We assume d_v always equals d_k
self.d_k = d_model // h
self.h = h
self.linears = clones(nn.Linear(d_model, d_model), 4)
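        # linears[0..2] project query, key and value; linears[3] is the output projection.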
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
"""
:param query: size(batch,seq,512)
:param key:
:param value:
:param mask:
:return:
"""
if mask is not None:
# Same mask applied to all h heads.
mask = mask.unsqueeze(1)
nbatches = query.size(0)
# 1) Do all the linear projections in batch from d_model => h x d_k
# size(batch,h,seq,dk)
query, key, value = \
[lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for lin, x in zip(self.linears, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch.
x, self.attn = attention(query, key, value, mask=mask,
dropout=self.dropout)
# 3) "Concat" using a view and apply a final linear.
x = x.transpose(1, 2).contiguous() \
.view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x)
class PositionwiseFeedForward(nn.Module):
"""
d_model = 512
d_ff = 2048 为论文中数值
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class Cross_Attention(nn.Module):
    "Cross-attention between image patch tokens and (projected) text tokens."

    def __init__(self, h=8, n=1, d_model=768, d_ff=1024, dropout=0.1):  # (self, args, h=8, n=1, d_model=768, d_ff=1024, dropout=0.1):
        super(Cross_Attention, self).__init__()
        multi_head_attention = MultiHeadedAttention(h, d_model)
        ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        encoderLayer = EncoderLayer(d_model, multi_head_attention, ffn, dropout)
        self.encoder = Encoder(encoderLayer, n)
        self.text_projection = nn.Linear(512, d_model)
    def forward(self, x_patch, x_text):
        length = x_text.shape[0]  # (unused)
        x_text = self.text_projection(x_text)
        x_sketch = self.encoder(x_patch, x_text, None)  # no mask
        return x_sketch
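
# --- Usage sketch (illustrative, not part of the borrowed code) --------------
# A minimal example of how Cross_Attention might be called. The shapes below
# (batch of 2, 197 patch tokens at width 768, text tokens at width 512) are
# assumptions for the sketch, not values taken from the original training code.
# Note that with the q/k/v wiring in EncoderLayer (key from x_text, value from
# x_patch), the text and patch sequences must have the same length for the
# attention matmul to go through, so the sketch uses 197 tokens for both.
if __name__ == "__main__":
    cross_attn = Cross_Attention(h=8, n=1, d_model=768, d_ff=1024, dropout=0.1)
    x_patch = torch.randn(2, 197, 768)  # image patch tokens, already at d_model
    x_text = torch.randn(2, 197, 512)   # text tokens, projected to d_model inside forward
    out = cross_attn(x_patch, x_text)
    print(out.shape)                    # torch.Size([2, 197, 768])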