# Spaces: Audio-Deepfake-Detection (Sleeping)
# Path: fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1/fairseq/modules/fp32_group_norm.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Group norm done in fp32 (for fp16 training)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Fp32GroupNorm(nn.GroupNorm):
    """Group normalization whose computation is carried out in float32.

    The input and the affine parameters (when present) are cast to float32
    before calling ``F.group_norm``; the result is cast back to the type of
    the original input. Construction arguments are those of ``nn.GroupNorm``.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: every argument is forwarded to nn.GroupNorm.
        super().__init__(*args, **kwargs)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply group normalization in float32.

        Args:
            input: Tensor to normalize.

        Returns:
            The normalized tensor, cast back to the type of ``input``.
        """
        # weight/bias may be None (checked explicitly below), in which case
        # None is passed through to F.group_norm.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        output = F.group_norm(
            input.float(),
            self.num_groups,
            weight,
            bias,
            self.eps,
        )
        # Hand the result back in the caller's original type.
        return output.type_as(input)