import torch
import torch.nn as nn
import torch.nn.functional as F
from src.smb.level import MarioLevel
from performer_pytorch import SelfAttention


class CustomUpsample(nn.Module):
    """Transposed-convolution upsampling with per-target-size hyperparameters."""

    def __init__(self, in_channels, out_channels, target_size):
        super(CustomUpsample, self).__init__()
        # Each supported target size maps to a (stride, kernel_size, padding)
        # triple chosen so that the transposed convolution lands exactly on it.
        if target_size == 4:      # upsampling from 2x2 to 4x4
            stride, kernel_size, padding = 1, 3, 0
        elif target_size == 7:    # upsampling from 4x4 to 7x7
            stride, kernel_size, padding = 2, 3, 1
        elif target_size == 8:    # upsampling from 4x4 to 8x8
            stride, kernel_size, padding = 2, 4, 1
        elif target_size == 14:   # upsampling from 7x7 to 14x14
            stride, kernel_size, padding = 2, 2, 0
        elif target_size == 16:   # upsampling from 8x8 to 16x16
            stride, kernel_size, padding = 2, 4, 1
        else:
            raise ValueError("Invalid target_size specified.")
        self.upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        return self.upsample(x)
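
# Hedged sanity check (added, not part of the original module): nn.ConvTranspose2d
# produces out = (in - 1) * stride - 2 * padding + kernel_size, so the branches
# above give 2->4, 4->7, 4->8, 7->14, and 8->16 respectively. For example:
#   up = CustomUpsample(in_channels=256, out_channels=256, target_size=4)
#   assert up(torch.randn(1, 256, 2, 2)).shape == (1, 256, 4, 4)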


class PerformerSelfAttention(nn.Module):
    def __init__(self, channels, size, n_heads=4):
        super(PerformerSelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.n_heads = n_heads
        # Linear-complexity self-attention from the performer-pytorch package
        self.performer_attention = SelfAttention(
            dim=channels,
            heads=self.n_heads
        )
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        batch_size = x.size(0)
        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C)
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        # Adapt the input shape for the Performer attention layer
        query_key_value = x_ln.view(batch_size, self.size * self.size, self.channels)
        attention_value = self.performer_attention(query_key_value)
        # Residual connections around attention and the feed-forward block
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        # Restore the spatial layout: (B, H*W, C) -> (B, C, H, W)
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
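
# Shape trace (added, illustrative sketch assuming square size x size inputs):
# (B, C, H, W) -> tokens (B, H*W, C) -> Performer attention -> (B, C, H, W), e.g.
#   sa = PerformerSelfAttention(channels=128, size=8)
#   assert sa(torch.randn(2, 128, 8, 8)).shape == (2, 128, 8, 8)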


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        # Two 3x3 convolutions, each followed by GroupNorm; GELU after the first
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.group_norm1 = nn.GroupNorm(1, mid_channels)
        self.gelu1 = nn.GELU()
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.group_norm2 = nn.GroupNorm(1, out_channels)

    def forward(self, x):
        # Apply the first convolution block
        x1 = self.conv1(x)
        x1 = self.group_norm1(x1)
        x1 = self.gelu1(x1)
        # Apply the second convolution block
        x2 = self.conv2(x1)
        x2 = self.group_norm2(x2)
        # Optional residual connection with GELU activation
        if self.residual:
            return F.gelu(x + x2)
        else:
            return x2
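
# Note (added): GroupNorm with num_groups=1 normalizes over all of (C, H, W)
# per sample, i.e. it behaves like LayerNorm for feature maps. Quick check:
#   block = DoubleConv(64, 64, residual=True)
#   assert block(torch.randn(2, 64, 16, 16)).shape == (2, 64, 16, 16)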


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=32):
        super().__init__()
        # Max pooling followed by two DoubleConv layers
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2, ceil_mode=True),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # Embedding layer to incorporate time information
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        # Project the time embedding and broadcast it over the spatial dimensions
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
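
# Illustrative check (added): MaxPool2d(2, ceil_mode=True) halves H and W,
# rounding up (e.g. 7 -> 4), and the time embedding t of shape (B, emb_dim)
# is projected to (B, out_channels) before being broadcast over the grid:
#   down = Down(64, 128)
#   y = down(torch.randn(2, 64, 16, 16), torch.randn(2, 32))
#   assert y.shape == (2, 128, 8, 8)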


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=32, target_size=7):
        super().__init__()
        self.up = CustomUpsample(
            in_channels=in_channels // 2,
            out_channels=in_channels // 2,
            target_size=target_size,
        )
        # DoubleConv layers applied after concatenation with the skip connection
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Embedding layer to incorporate time information
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        # Upsample the input tensor
        x = self.up(x)
        # Concatenate with the skip tensor from the encoder along the channel axis
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        # Project the time embedding and broadcast it over the spatial dimensions
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
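
# Illustrative check (added): the decoder halves arrive as in_channels // 2
# channels each (the upsampled tensor plus an equally sized skip tensor), and
# the concatenation is reduced to out_channels:
#   up = Up(512, 128, target_size=4)
#   y = up(torch.randn(2, 256, 2, 2), torch.randn(2, 256, 4, 4), torch.randn(2, 32))
#   assert y.shape == (2, 128, 4, 4)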


class UNet(nn.Module):
    def __init__(self, c_in=MarioLevel.n_types, c_out=MarioLevel.n_types, time_dim=32, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)                  # 64x16x16
        self.down1 = Down(64, 128)                       # 128x8x8
        self.sa1 = PerformerSelfAttention(128, 8)        # 128x8x8
        self.down2 = Down(128, 256)                      # 256x4x4
        self.sa2 = PerformerSelfAttention(256, 4)        # 256x4x4
        self.down3 = Down(256, 256)                      # 256x2x2
        self.sa3 = PerformerSelfAttention(256, 2)        # 256x2x2
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)                 # 256x2x2
        self.up1 = Up(512, 128, target_size=4)           # 128x4x4
        self.sa4 = PerformerSelfAttention(128, 4)        # 128x4x4
        self.up2 = Up(256, 64, target_size=8)            # 64x8x8
        self.sa5 = PerformerSelfAttention(64, 8)         # 64x8x8
        self.up3 = Up(128, 64, target_size=16)           # 64x16x16
        self.sa6 = PerformerSelfAttention(64, 16)        # 64x16x16
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)  # 11x16x16

    def pos_encoding(self, t, channels):
        # Standard sinusoidal encoding of the diffusion timestep: sin terms
        # fill the first half of the vector, cos terms the second half.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc
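
    # Worked example (added; `model` is a hypothetical UNet instance, and t is
    # assumed to arrive as shape (B, 1)): with channels=32, inv_freq has 16
    # entries, so sin and cos each contribute 16 values:
    #   t = torch.full((4, 1), 500.0)
    #   emb = model.pos_encoding(t, 32)   # emb.shape == (4, 32)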

    def forward(self, x, t):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)
        # Encoder
        x1 = self.inc(x)          # 64x16x16
        x2 = self.down1(x1, t)    # 128x8x8
        x2 = self.sa1(x2)         # 128x8x8
        x3 = self.down2(x2, t)    # 256x4x4
        x3 = self.sa2(x3)         # 256x4x4
        x4 = self.down3(x3, t)    # 256x2x2
        x4 = self.sa3(x4)         # 256x2x2
        # Bottleneck
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)        # 256x2x2
        # Decoder with skip connections
        x = self.up1(x4, x3, t)   # 128x4x4
        x = self.sa4(x)           # 128x4x4
        x = self.up2(x, x2, t)    # 64x8x8
        x = self.sa5(x)           # 64x8x8
        x = self.up3(x, x1, t)    # 64x16x16
        x = self.sa6(x)           # 64x16x16
        output = self.outc(x)     # 11x16x16
        return output
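

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # training pipeline): run one forward pass on random input at the 16x16
    # level size used throughout the network.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = UNet(device=device).to(device)
    x = torch.randn(2, MarioLevel.n_types, 16, 16, device=device)
    t = torch.randint(1, 1000, (2,), device=device)
    out = model(x, t)
    print(out.shape)  # expected: (2, MarioLevel.n_types, 16, 16)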