|
|
|
|
|
|
|
|
|
|
|
|
|
"""Normalization modules.""" |
|
|
|
import typing as tp |
|
|
|
import einops |
|
import torch |
|
from torch import nn |
|
|
|
|
|
class ConvLayerNorm(nn.LayerNorm):
    """Convolution-friendly LayerNorm.

    Moves the last dimension of the input to position 1, runs the standard
    ``nn.LayerNorm`` over the trailing dimensions, then moves it back.
    For a ``(B, C, T)`` tensor, normalization therefore runs over ``C``
    independently for each time step.

    Presumably intended for conv feature maps laid out as
    ``(batch, channels, ..., time)``. In that case ``normalized_shape`` should
    match the dimensions between the batch axis and the last axis, e.g. ``C``
    or ``(C, F)``. TODO confirm against callers.
    """

    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        # All LayerNorm options (eps, elementwise_affine, ...) pass straight through.
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over every dimension except the first and the last.

        Args:
            x: Input tensor of rank >= 2. The last dimension is treated as the
                time axis.

        Returns:
            A tensor with the same shape as ``x``, normalized over the dims
            between the batch axis and the last axis.
        """
        # 'b ... t -> b t ...': bring the last axis next to batch so that the
        # normalized dims become the trailing ones, as nn.LayerNorm expects.
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # 'b t ... -> b ... t': restore the original layout.
        x = x.movedim(1, -1)
        # Previously a bare `return`, which silently returned None.
        return x
|
|