# DRv2/Model/fur_rl/models/transformer.py
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
# helpers
def clone(module, N):
    '''Copy the given module N times.'''
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def subsequent_mask(size):
    '''Causal mask: position i may attend to position j only where j <= i.'''
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype(bool)
    return torch.from_numpy(~mask)
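# For example, subsequent_mask(3) returns a (1, 3, 3) bool tensor:
# [[[ True, False, False],
#   [ True,  True, False],
#   [ True,  True,  True]]]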
def attention(query, key, value, mask=None, dropout=None):
    '''Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V; also returns the attention weights.'''
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Module):
def __init__(self, h, d_model, dropout=0.1):
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
self.d_k = d_model // h
self.h = h
self.linears = clone(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
if mask is not None:
mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # Project q/k/v, then split into h heads: (batch, seq, d_model) -> (batch, h, seq, d_k).
        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linears, (query, key, value))]
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # Merge the heads back: (batch, h, seq, d_k) -> (batch, seq, h * d_k).
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x)
class Feedforward(nn.Module):
def __init__(self,d_model,d_ff,dropout=0.1):
super(Feedforward,self).__init__()
self.w_1=nn.Linear(d_model,d_ff)
self.w_2=nn.Linear(d_ff,d_model)
self.dropout=nn.Dropout(dropout)
def forward(self,x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
class LayerNorm(nn.Module):
def __init__(self,features,eps=1e-6):
super(LayerNorm,self).__init__()
self.a_2=nn.Parameter(torch.ones(features))
self.b_2=nn.Parameter(torch.zeros(features))
self.eps=eps
def forward(self,x):
mean=x.mean(-1,keepdim=True)
std=x.std(-1,keepdim=True)
return self.a_2*(x-mean)/(std+self.eps)+self.b_2
class Generator(nn.Module):
def __init__(self,d_model,vocab):
super(Generator,self).__init__()
self.proj=nn.Linear(d_model,vocab)
def forward(self,x):
return F.log_softmax(self.proj(x),dim=-1)
# Encoder: stacks enc_depth copies of EncoderLayer.
class Encoder(nn.Module):
def __init__(self, layer, N):
'''N encoder layers '''
super(Encoder,self).__init__()
self.layers = clone(layer, N)
self.norm = LayerNorm(layer.size)
def forward(self,x,mask=None):
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class SublayerConnection(nn.Module):
    '''LayerNorm + sublayer + dropout + residual connection (pre-norm).'''
def __init__(self,size,dropout):
super(SublayerConnection,self).__init__()
self.norm=LayerNorm(size)
self.dropout=nn.Dropout(dropout)
def forward(self,x,sublayer):
return x+self.dropout(sublayer(self.norm(x)))
class EncoderLayer(nn.Module):
def __init__(self,size,self_attn,feed_forward,dropout):
'''size is the embedding dimension'''
super(EncoderLayer,self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.sublayer = clone(SublayerConnection(size,dropout),2)
self.size = size
def forward(self,x,mask=None):
x = self.sublayer[0](x, lambda x: self.self_attn(x,x,x,mask))
return self.sublayer[1](x, self.feed_forward)
class Decoder(nn.Module):
def __init__(self,layer,N):
super(Decoder,self).__init__()
self.layers = clone(layer,N)
self.norm = LayerNorm(layer.size)
def forward(self,x, memory,src_mask=None,tgt_mask=None):
for layer in self.layers:
x = layer(x, memory, src_mask, tgt_mask)
return self.norm(x)
class DecoderLayer(nn.Module):
def __init__(self,size,self_attn,src_attn,feed_forward,dropout):
super(DecoderLayer,self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.sublayer = clone(SublayerConnection(size,dropout),3)
def forward(self,x,memory,src_mask=None,tgt_mask=None):
m = memory
x = self.sublayer[0](x,lambda x: self.self_attn(x,x,x,tgt_mask))
x = self.sublayer[1](x,lambda x: self.src_attn(x,m,m,src_mask))
return self.sublayer[2](x,self.feed_forward)
class CrossAttLayer(nn.Module):
def __init__(self,d_model,self_attn,feed_forward,dropout=0.1):
super(CrossAttLayer, self).__init__()
self.size = d_model
self.self_attn = self_attn
# self.self_attn_0 = copy.deepcopy(self_attn)
self.feed_forward = feed_forward
self.dropout = nn.Dropout(dropout)
self.sublayer = clone(SublayerConnection(d_model, dropout), 2)
        # self.sublayer = clone(SublayerConnection(d_model, dropout), 3)  # could be three sublayers, with self-attention as the first
def forward(self,q,k,v,src_mask=None):
# k = self.sublayer[0](k, lambda k: self.self_attn_0(k,k,k))
# q = self.sublayer[0](q, lambda q: self.self_attn_0(q,q,q))
# x = self.sublayer[1](q, lambda q: self.self_attn(q,k,k,src_mask))
# x = self.sublayer[2](x, self.feed_forward)
x = self.sublayer[0](q, lambda q: self.self_attn(q,k,v,src_mask))
x = self.sublayer[1](x, self.feed_forward)
return x
class CrossAtt(nn.Module):
def __init__(self, crossAttlayer, N=1):
super(CrossAtt, self).__init__()
self.layers = clone(crossAttlayer,N)
self.norm = LayerNorm(crossAttlayer.size)
def forward(self, q, k, v, src_mask=None):
for crossAttnLayer in self.layers:
q = crossAttnLayer(q, k, v, src_mask)
return self.norm(q)
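# Usage sketch (shapes here are assumptions): a query sequence attends to a separate key/value sequence, e.g.
#   ca = CrossAtt(CrossAttLayer(512, MultiHeadedAttention(8, 512), Feedforward(512, 1024)))
#   out = ca(torch.randn(2, 4, 512), torch.randn(2, 9, 512), torch.randn(2, 9, 512))  # -> (2, 4, 512)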
class Transformer(nn.Module):
def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
super(Transformer,self).__init__()
c = copy.deepcopy
attn = MultiHeadedAttention(heads,d_model)
ff = Feedforward(d_model,d_ff,dropout)
self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self,src_embeded,tgt_embeded,src_mask=None,tgt_mask=None):
return self.decode(self.encode(src_embeded,src_mask),tgt_embeded,src_mask,tgt_mask)
def encode(self,src_embeded,src_mask=None):
return self.encoder(src_embeded,src_mask)
def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)
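# A minimal usage sketch (assumes inputs are already embedded to d_model):
#   model = Transformer()
#   src = torch.randn(2, 10, 512)                        # (batch, src_len, d_model)
#   tgt = torch.randn(2, 7, 512)                         # (batch, tgt_len, d_model)
#   out = model(src, tgt, tgt_mask=subsequent_mask(7))   # -> (2, 7, 512)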
class ModelOne(nn.Module):
def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
super(ModelOne,self).__init__()
c = copy.deepcopy
attn = MultiHeadedAttention(heads,d_model)
ff = Feedforward(d_model,d_ff,dropout)
self.CrossAtt = CrossAtt(CrossAttLayer(d_model,c(attn),c(ff),dropout),N=1)
self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
    def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
        # q, k, v are unused in this variant; they are kept for interface parity with the other models.
        # x = self.CrossAtt(q, img_embed, img_embed)
        # x2 = self.CrossAtt(q, des_embed, des_embed)
        des_embed_self = self.CrossAtt(des_embed, des_embed, des_embed)  # self-attention over the description
        x3 = self.CrossAtt(img_embed, des_embed_self, des_embed_self)    # image queries attend to the description
        # src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
        src_embeded = torch.cat((x3, obj_embed), dim=1)
x = self.encode(src_embeded,src_mask)
x = self.decode(x, tgt_embeded,src_mask, tgt_mask)
return x
def encode(self,src_embeded,src_mask=None):
return self.encoder(src_embeded,src_mask)
def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)
class Model005(nn.Module):
def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
super(Model005,self).__init__()
c = copy.deepcopy
attn = MultiHeadedAttention(heads,d_model)
ff = Feedforward(d_model,d_ff,dropout)
self.CrossAtt = CrossAtt(CrossAttLayer(d_model,c(attn),c(ff),dropout),N=1)
self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
x = self.CrossAtt(q, img_embed, img_embed)
src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
x = self.encode(src_embeded,src_mask)
x = self.decode(x, tgt_embeded,src_mask, tgt_mask)
return x
def encode(self,src_embeded,src_mask=None):
return self.encoder(src_embeded,src_mask)
def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)
class Model006(nn.Module):
def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
super(Model006,self).__init__()
c = copy.deepcopy
attn = MultiHeadedAttention(heads,d_model)
ff = Feedforward(d_model,d_ff,dropout)
self.CrossAtt = CrossAtt(CrossAttLayer(d_model,c(attn),c(ff),dropout),N=1)
self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
    def forward(self, q, k, v, tgt_embeded, des_embed, obj_embed, img_embed, src_mask=None, tgt_mask=None):
        # Note: the same CrossAtt block (shared weights) is applied twice here.
        x = self.CrossAtt(img_embed, img_embed, img_embed)  # self-attention over image features
        x = self.CrossAtt(obj_embed, x, x)                  # object queries attend to the image
        src_embeded = torch.cat((x, des_embed, obj_embed), dim=1)
x = self.encode(src_embeded,src_mask)
x = self.decode(x, tgt_embeded,src_mask, tgt_mask)
return x
def encode(self,src_embeded,src_mask=None):
return self.encoder(src_embeded,src_mask)
def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)
class ModelAttn(nn.Module):
def __init__(self,d_model=512,heads=8,enc_depth=8,dec_depth=8,d_ff=1024,dropout=0.1):
super(ModelAttn,self).__init__()
c = copy.deepcopy
attn = MultiHeadedAttention(heads,d_model)
ff = Feedforward(d_model,d_ff,dropout)
self.CrossAtt = CrossAtt(CrossAttLayer(d_model,c(attn),c(ff),dropout),N=1)
self.encoder = Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),enc_depth)
self.decoder = Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),dec_depth)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
    def forward(self, q, k, v):
        # Only the cross-attention stack is exercised in forward(); the encoder/decoder built
        # in __init__ are reachable through encode()/decode() below.
        x = self.CrossAtt(q, k, v)
        return x
def encode(self,src_embeded,src_mask=None):
return self.encoder(src_embeded,src_mask)
def decode(self,memory,tgt_embeded,src_mask=None,tgt_mask=None):
return self.decoder(tgt_embeded,memory,src_mask,tgt_mask)
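
if __name__ == '__main__':
    # Illustrative smoke test (not part of the original training pipeline); all batch and
    # sequence sizes below are assumptions, chosen only to exercise the forward passes.
    d_model = 512
    model = ModelOne(d_model=d_model, enc_depth=2, dec_depth=2)
    des = torch.randn(2, 5, d_model)   # description embeddings
    obj = torch.randn(2, 3, d_model)   # object embeddings
    img = torch.randn(2, 4, d_model)   # image embeddings
    tgt = torch.randn(2, 6, d_model)   # target-side embeddings
    out = model(None, None, None, tgt, des, obj, img, tgt_mask=subsequent_mask(6))
    print(out.shape)  # torch.Size([2, 6, 512])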