```python
from abc import abstractmethod
from typing import Optional, Any, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.NeuralNetwork import transformer
from modules.Attention import Attention
from modules.cond import cast
from modules.sample import sampling_util

oai_ops = cast.disable_weight_init


class TimestepBlock1(nn.Module):
    """#### Abstract base class for blocks whose forward pass takes a timestep embedding."""

    @abstractmethod
    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the timestep block.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `emb` (torch.Tensor): The timestep embedding tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
```
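Any module that needs the timestep embedding subclasses `TimestepBlock1`; the dispatch loop in `forward_timestep_embed1` below detects such blocks with `isinstance`. A minimal sketch of a subclass (`AddEmbBlock` is hypothetical, for illustration only):

```python
# Hypothetical TimestepBlock1 subclass: projects the embedding to the
# channel count and adds it to the input.
class AddEmbBlock(TimestepBlock1):
    def __init__(self, channels: int, emb_channels: int):
        super().__init__()
        self.proj = nn.Linear(emb_channels, channels)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # Broadcast the projected embedding over the spatial dimensions.
        return x + self.proj(emb)[:, :, None, None]
```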
```python
def forward_timestep_embed1(
    ts: nn.ModuleList,
    x: torch.Tensor,
    emb: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    transformer_options: Optional[Dict[str, Any]] = None,
    output_shape: Optional[torch.Size] = None,
    time_context: Optional[torch.Tensor] = None,
    num_video_frames: Optional[int] = None,
    image_only_indicator: Optional[bool] = None,
) -> torch.Tensor:
    """#### Run `x` through a list of blocks, dispatching on each layer's type.

    #### Args:
        - `ts` (nn.ModuleList): The list of timestep blocks.
        - `x` (torch.Tensor): The input tensor.
        - `emb` (torch.Tensor): The timestep embedding tensor.
        - `context` (torch.Tensor, optional): The cross-attention context tensor. Defaults to None.
        - `transformer_options` (dict, optional): The transformer options. Defaults to None, which is treated as an empty dict (avoids the shared-mutable-default pitfall of `{}`).
        - `output_shape` (torch.Size, optional): The target shape passed to upsample layers. Defaults to None.
        - `time_context` (torch.Tensor, optional): The time context tensor. Defaults to None.
        - `num_video_frames` (int, optional): The number of video frames. Defaults to None.
        - `image_only_indicator` (bool, optional): The image-only indicator. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    if transformer_options is None:
        transformer_options = {}
    for layer in ts:
        if isinstance(layer, TimestepBlock1):
            x = layer(x, emb)
        elif isinstance(layer, transformer.SpatialTransformer):
            x = layer(x, context, transformer_options)
            if "transformer_index" in transformer_options:
                transformer_options["transformer_index"] += 1
        elif isinstance(layer, Upsample1):
            x = layer(x, output_shape=output_shape)
        else:
            x = layer(x)
    return x
```
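A short usage sketch of the dispatch loop, reusing the hypothetical `AddEmbBlock` from above; plain modules fall through to the final `layer(x)` branch:

```python
# Sketch: mixed block list (shapes are illustrative).
blocks = nn.ModuleList([
    AddEmbBlock(channels=64, emb_channels=128),  # dispatched as layer(x, emb)
    nn.SiLU(),                                   # falls through to layer(x)
])
x = torch.randn(1, 64, 32, 32)
emb = torch.randn(1, 128)
out = forward_timestep_embed1(blocks, x, emb)
assert out.shape == x.shape
```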
```python
class Upsample1(nn.Module):
    """#### Upsampling layer: nearest-neighbor 2x upsample with an optional convolution."""

    def __init__(
        self,
        channels: int,
        use_conv: bool,
        dims: int = 2,
        out_channels: Optional[int] = None,
        padding: int = 1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        operations: Any = oai_ops,
    ):
        """#### Initialize the upsample layer.

        #### Args:
            - `channels` (int): The number of input channels.
            - `use_conv` (bool): Whether to apply a 3x3 convolution after upsampling.
            - `dims` (int, optional): The number of spatial dimensions. Defaults to 2.
            - `out_channels` (int, optional): The number of output channels. Defaults to None (same as `channels`).
            - `padding` (int, optional): The convolution padding. Defaults to 1.
            - `dtype` (torch.dtype, optional): The parameter data type. Defaults to None.
            - `device` (torch.device, optional): The parameter device. Defaults to None.
            - `operations` (Any, optional): The factory for layer operations. Defaults to `oai_ops`.
        """
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = operations.conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                padding=padding,
                dtype=dtype,
                device=device,
            )

    def forward(
        self, x: torch.Tensor, output_shape: Optional[torch.Size] = None
    ) -> torch.Tensor:
        """#### Forward pass for the upsample layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor (NCHW; this path assumes 2D spatial data).
            - `output_shape` (torch.Size, optional): An explicit target shape overriding the 2x default. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The upsampled tensor.
        """
        assert x.shape[1] == self.channels
        shape = [x.shape[2] * 2, x.shape[3] * 2]
        if output_shape is not None:
            shape[0] = output_shape[2]
            shape[1] = output_shape[3]

        x = F.interpolate(x, size=shape, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
```
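A quick shape check as a sketch; with `use_conv=False` the `operations.conv_nd` factory is never invoked, so this runs standalone:

```python
# Sketch: nearest-neighbor upsampling without the optional convolution.
up = Upsample1(channels=64, use_conv=False)
x = torch.randn(1, 64, 16, 16)
assert up(x).shape == (1, 64, 32, 32)  # default: double H and W
# output_shape is only indexed at positions 2 and 3, so a plain tuple works:
assert up(x, output_shape=(1, 64, 33, 31)).shape == (1, 64, 33, 31)
```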
```python
class Downsample1(nn.Module):
    """#### Downsampling layer: strided 3x3 convolution that halves the spatial resolution."""

    def __init__(
        self,
        channels: int,
        use_conv: bool,
        dims: int = 2,
        out_channels: Optional[int] = None,
        padding: int = 1,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        operations: Any = oai_ops,
    ):
        """#### Initialize the downsample layer.

        #### Args:
            - `channels` (int): The number of input channels.
            - `use_conv` (bool): Whether to use convolution.
            - `dims` (int, optional): The number of spatial dimensions. Defaults to 2.
            - `out_channels` (int, optional): The number of output channels. Defaults to None (same as `channels`).
            - `padding` (int, optional): The convolution padding. Defaults to 1.
            - `dtype` (torch.dtype, optional): The parameter data type. Defaults to None.
            - `device` (torch.device, optional): The parameter device. Defaults to None.
            - `operations` (Any, optional): The factory for layer operations. Defaults to `oai_ops`.
        """
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D inputs, stride only the two spatial dimensions and keep the
        # leading (temporal) dimension at full resolution.
        stride = 2 if dims != 3 else (1, 2, 2)
        self.op = operations.conv_nd(
            dims,
            self.channels,
            self.out_channels,
            3,
            stride=stride,
            padding=padding,
            dtype=dtype,
            device=device,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the downsample layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The downsampled tensor.
        """
        assert x.shape[1] == self.channels
        return self.op(x)
```
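A shape sketch (this assumes `oai_ops.conv_nd` builds a standard strided convolution, as the constructor above uses it):

```python
# Sketch: a stride-2, padding-1, kernel-3 conv takes 32x32 down to 16x16.
down = Downsample1(channels=64, use_conv=True)
x = torch.randn(1, 64, 32, 32)
assert down(x).shape == (1, 64, 16, 16)
```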
```python
class ResBlock1(TimestepBlock1):
    """#### Residual block that can condition on a timestep embedding."""

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        dropout: float,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        use_scale_shift_norm: bool = False,
        dims: int = 2,
        use_checkpoint: bool = False,
        up: bool = False,
        down: bool = False,
        kernel_size: int = 3,
        exchange_temb_dims: bool = False,
        skip_t_emb: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        operations: Any = oai_ops,
    ):
        """#### Initialize the residual block layer.

        #### Args:
            - `channels` (int): The number of input channels.
            - `emb_channels` (int): The number of timestep-embedding channels.
            - `dropout` (float): The dropout rate.
            - `out_channels` (int, optional): The number of output channels. Defaults to None (same as `channels`).
            - `use_conv` (bool, optional): Whether to use convolution. Defaults to False.
            - `use_scale_shift_norm` (bool, optional): Whether to inject the embedding as a post-norm scale and shift instead of an additive bias. Defaults to False.
            - `dims` (int, optional): The number of spatial dimensions. Defaults to 2.
            - `use_checkpoint` (bool, optional): Whether to use gradient checkpointing. Defaults to False.
            - `up` (bool, optional): Whether to use upsampling. Defaults to False.
            - `down` (bool, optional): Whether to use downsampling. Defaults to False.
            - `kernel_size` (int, optional): The convolution kernel size. Defaults to 3.
            - `exchange_temb_dims` (bool, optional): Whether to exchange embedding dimensions. Defaults to False.
            - `skip_t_emb` (bool, optional): Whether to skip the timestep embedding entirely. Defaults to False.
            - `dtype` (torch.dtype, optional): The parameter data type. Defaults to None.
            - `device` (torch.device, optional): The parameter device. Defaults to None.
            - `operations` (Any, optional): The factory for layer operations. Defaults to `oai_ops`.
        """
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            operations.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(
                dims,
                channels,
                self.out_channels,
                kernel_size,
                padding=padding,
                dtype=dtype,
                device=device,
            ),
        )

        self.updown = up or down
        self.h_upd = self.x_upd = nn.Identity()

        self.skip_t_emb = skip_t_emb
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            operations.Linear(
                emb_channels,
                # Scale-shift norm needs two vectors (scale and shift) per channel.
                (2 * self.out_channels if use_scale_shift_norm else self.out_channels),
                dtype=dtype,
                device=device,
            ),
        )
        self.out_layers = nn.Sequential(
            operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            operations.conv_nd(
                dims,
                self.out_channels,
                self.out_channels,
                kernel_size,
                padding=padding,
                dtype=dtype,
                device=device,
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            # 1x1 convolution to match channel counts on the residual path.
            self.skip_connection = operations.conv_nd(
                dims, channels, self.out_channels, 1, dtype=dtype, device=device
            )

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the residual block layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `emb` (torch.Tensor): The timestep embedding tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return sampling_util.checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """#### Internal forward pass for the residual block layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `emb` (torch.Tensor): The timestep embedding tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        h = self.in_layers(x)

        emb_out = None
        if not self.skip_t_emb:
            emb_out = self.emb_layers(emb).type(h.dtype)
            # Broadcast the embedding over the spatial dimensions.
            while len(emb_out.shape) < len(h.shape):
                emb_out = emb_out[..., None]

        if self.use_scale_shift_norm:
            # emb_layers produced 2 * out_channels features: split them into a
            # per-channel scale and shift applied right after normalization.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            h = out_norm(h)
            if emb_out is not None:
                scale, shift = torch.chunk(emb_out, 2, dim=1)
                h = h * (1 + scale) + shift
            h = out_rest(h)
        else:
            if emb_out is not None:
                h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
```
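A shape sketch (assumes the `oai_ops` factory and `sampling_util.checkpoint` behave as the class above uses them):

```python
# Sketch: a channel-changing ResBlock1 conditioned on a timestep embedding.
block = ResBlock1(channels=64, emb_channels=128, dropout=0.0, out_channels=96)
x = torch.randn(1, 64, 32, 32)
emb = torch.randn(1, 128)
assert block(x, emb).shape == (1, 96, 32, 32)  # 1x1 skip conv matches channels
```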
```python
ops = cast.disable_weight_init


class ResnetBlock(nn.Module):
    """#### VAE-style ResNet block (norm -> SiLU -> conv, twice, plus skip)."""

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float,
        temb_channels: int = 512,
    ):
        """#### Initialize the ResNet block layer.

        #### Args:
            - `in_channels` (int): The number of input channels.
            - `out_channels` (int, optional): The number of output channels. Defaults to None (same as `in_channels`).
            - `conv_shortcut` (bool, optional): Whether to use a convolutional shortcut. Defaults to False.
            - `dropout` (float): The dropout rate.
            - `temb_channels` (int, optional): The number of timestep-embedding channels. Defaults to 512 (unused in this variant).
        """
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.swish = torch.nn.SiLU(inplace=True)
        self.norm1 = Attention.Normalize(in_channels)
        self.conv1 = ops.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Attention.Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout, inplace=True)
        self.conv2 = ops.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            # 1x1 "network-in-network" shortcut to match channel counts.
            self.nin_shortcut = ops.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the ResNet block layer.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `temb` (torch.Tensor): The timestep embedding tensor (accepted for interface compatibility; unused in this variant).

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        h = x
        h = self.norm1(h)
        h = self.swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = self.swish(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h
```
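A final shape sketch (assumes `Attention.Normalize` and `ops.Conv2d` behave as the class above uses them; `temb` is ignored by this variant):

```python
# Sketch: channel-changing ResNet block; the 1x1 nin_shortcut aligns the skip.
rb = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0)
x = torch.randn(1, 64, 32, 32)
out = rb(x, temb=torch.zeros(1, 512))
assert out.shape == (1, 128, 32, 32)
```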