ziqima's picture
initial commit
4893ce0
raw
history blame contribute delete
311 Bytes
import torch
import torch.nn as nn
class LayerNorm1d(nn.BatchNorm1d):
    """``nn.BatchNorm1d`` that accepts channels-last input.

    Despite its name, this module performs *batch* normalization (it
    subclasses ``nn.BatchNorm1d``); it is NOT ``torch.nn.LayerNorm``.
    Statistics are computed per channel over every non-channel dimension,
    and running statistics, momentum and affine parameters behave exactly as
    in ``nn.BatchNorm1d``.

    The only difference from the parent class is the expected layout.
    ``nn.BatchNorm1d`` wants channels in dimension 1, i.e. ``(N, C)`` or
    ``(N, C, L)``. This module takes channels in the LAST dimension, i.e.
    ``(N, C)`` or ``(N, L, C)``, and returns a tensor in the same layout as
    its input.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Batch-normalize a channels-last tensor.

        Args:
            input: tensor of shape ``(N, C)`` or ``(N, L, C)``, where ``C``
                must equal ``num_features``.

        Returns:
            A contiguous tensor with the same shape as ``input``.

        Raises:
            IndexError / ValueError: for inputs that are not 2-D or 3-D
                (raised by ``movedim`` or by ``nn.BatchNorm1d`` itself).
        """
        # Move channels from the last dim to dim 1, the layout BatchNorm1d
        # expects. For 3-D input this equals transpose(1, 2). For 2-D input
        # it is a no-op, so (N, C) works just as it does in the parent class
        # (a hard-coded transpose(1, 2) would raise IndexError there).
        channels_first = input.movedim(-1, 1).contiguous()
        normalized = super().forward(channels_first)
        # Restore the caller's channels-last layout.
        return normalized.movedim(1, -1).contiguous()