# Deploy_Restoration/net/Transformer.py
# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File : SGFMT.py
# Design based on ViT (Vision Transformer)
import torch.nn as nn
from net.IntmdSequential import IntermediateSequential


# Implements multi-head self-attention; serves as the bottleneck layer of the U-Net.
class SelfAttention(nn.Module):
    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        # x: (B, N, C) — batch size, number of tokens, embedding dimension
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        # Scaled dot-product attention over the token dimension.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Recombine heads and project back to the embedding dimension.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
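
# Minimal usage sketch for SelfAttention (illustrative only; the batch/token
# sizes 2 and 64 are hypothetical, dim=512 follows the defaults noted below):
#   attn = SelfAttention(dim=512, heads=8)
#   y = attn(torch.randn(2, 64, 512))  # output keeps the input shape (2, 64, 512)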


# Wraps a module with a residual (skip) connection: out = fn(x) + x.
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


# LayerNorm applied before the wrapped module (pre-norm Transformer block).
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


# Pre-norm wrapper that also applies dropout to the wrapped module's output.
class PreNormDrop(nn.Module):
    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        return self.dropout(self.fn(self.norm(x)))


# Position-wise feed-forward network (MLP) used inside each Transformer block.
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=dropout_rate),
        )

    def forward(self, x):
        return self.net(x)
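
# Shape sketch for FeedForward (purely illustrative; sizes follow the 512/4096
# defaults noted in TransformerModel below):
#   ff = FeedForward(dim=512, hidden_dim=4096, dropout_rate=0.1)
#   y = ff(torch.randn(2, 64, 512))  # (2, 64, 512): expand to 4096, project back to 512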


# Stack of `depth` Transformer blocks: pre-norm self-attention followed by a
# pre-norm feed-forward MLP, each wrapped in a residual connection.
class TransformerModel(nn.Module):
    def __init__(
        self,
        dim,      # 512
        depth,    # 4
        heads,    # 8
        mlp_dim,  # 4096
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
    ):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend(
                [
                    Residual(
                        PreNormDrop(
                            dim,
                            dropout_rate,
                            SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
                        )
                    ),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
                    ),
                ]
            )
            # dim = dim / 2
        self.net = IntermediateSequential(*layers)

    def forward(self, x):
        return self.net(x)
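

if __name__ == "__main__":
    # Hedged smoke test, not part of the original module: builds a
    # TransformerModel with the sizes suggested by the comments above and runs
    # a random (batch, tokens, dim) tensor through it. Depending on how
    # IntermediateSequential is configured, forward() may return either the
    # final features alone or a (features, intermediate_outputs) tuple, so
    # both cases are handled.
    import torch

    model = TransformerModel(dim=512, depth=4, heads=8, mlp_dim=4096)
    dummy = torch.randn(2, 64, 512)  # hypothetical batch of 2 sequences, 64 tokens, dim 512
    out = model(dummy)
    if isinstance(out, tuple):
        out = out[0]
    print(out.shape)  # expected: torch.Size([2, 64, 512])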