|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange, pack, unpack |
|
from .normalize import Normalize |
|
from .ops import nonlinearity, video_to_image |
|
from .conv import CausalConv3d |
|
from .block import Block |
|
|
|
class ResnetBlock2D(Block):
    """Pre-activation residual block built from 2D convolutions.

    The forward pass computes::

        shortcut(x) + conv2(dropout(act(norm2(conv1(act(norm1(x)))))))

    where ``shortcut`` is the identity when the channel count is unchanged
    and a learned convolutional projection otherwise.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output. When ``None``
            (the default) it falls back to ``in_channels``.
        conv_shortcut: Only used when ``in_channels != out_channels``.
            If ``True`` the skip path is projected with a 3x3 convolution;
            otherwise a 1x1 convolution is used.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout):
        super().__init__()
        self.in_channels = in_channels
        # Resolve the default once and use the resolved value for every layer
        # below; the raw argument may be None, which the conv layers reject.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # NOTE(review): Normalize comes from .normalize; its exact kind
        # (e.g. GroupNorm) is not visible here.
        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        # The skip path needs a projection only when channel counts differ.
        # Attribute names are kept stable because they become state_dict keys.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    @video_to_image
    def forward(self, x):
        """Apply the residual block to ``x`` and return the sum of both paths.

        NOTE(review): ``video_to_image`` (from ``.ops``) wraps this method;
        presumably it reshapes video tensors so these 2D layers can process
        them frame by frame - confirm in ``.ops``.
        """
        # Residual branch: norm -> activation -> conv, applied twice.
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Project the skip path so its channel count matches the branch output.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h
|
|
|
class ResnetBlock3D(Block):
    """Pre-activation residual block built from ``CausalConv3d`` layers.

    The forward pass computes::

        shortcut(x) + conv2(dropout(act(norm2(conv1(act(norm1(x)))))))

    where ``shortcut`` is the identity when the channel count is unchanged
    and a learned ``CausalConv3d`` projection otherwise.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output. When ``None``
            (the default) it falls back to ``in_channels``.
        conv_shortcut: Only used when ``in_channels != out_channels``.
            If ``True`` the skip path uses a kernel-size-3 ``CausalConv3d``;
            otherwise a kernel-size-1 ``CausalConv3d`` is used.
        dropout: Dropout probability applied before the second convolution.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        # Resolve the default once and use the resolved value for every layer
        # below; the raw argument may be None, which the conv layers reject.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # NOTE(review): CausalConv3d comes from .conv; presumably it pads
        # only toward earlier frames on the time axis - confirm there.
        self.norm1 = Normalize(in_channels)
        self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)

        # The skip path needs a projection only when channel counts differ.
        # Attribute names are kept stable because they become state_dict keys.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
            else:
                self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)

    def forward(self, x):
        """Apply the residual block to ``x`` and return the sum of both paths.

        NOTE(review): assumes ``x`` is whatever ``CausalConv3d`` accepts
        (presumably a 5D video tensor) - confirm against ``.conv``.
        """
        # Residual branch: norm -> activation -> conv, applied twice.
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Project the skip path so its channel count matches the branch output.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h