Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch | |
import torch.nn.functional as F | |
class ScaledL2Norm(nn.Module):
    """L2-normalize the input along dim 1 and apply a learnable per-channel scale.

    The input is normalized with ``F.normalize(x, p=2, dim=1)`` and then
    multiplied by ``scale``, a learnable vector of length ``in_channels``
    that is broadcast over every dimension except dim 1.

    Args:
        in_channels: Length of the ``scale`` parameter; must equal the size
            of dim 1 of the tensors passed to :meth:`forward`.
        initial_scale: Value every entry of ``scale`` is set to by
            :meth:`reset_parameters`.
    """

    def __init__(self, in_channels: int, initial_scale: float) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.initial_scale = initial_scale
        # Allocated uninitialized; reset_parameters() fills it right away.
        self.scale = nn.Parameter(torch.empty(in_channels))
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` L2-normalized along dim 1, times the per-channel scale.

        Args:
            x: Tensor with at least 2 dimensions whose dim 1 has size
                ``in_channels`` (e.g. ``(N, C)``, ``(N, C, L)``,
                ``(N, C, H, W)``).

        Returns:
            A tensor with the same shape as ``x``.
        """
        normalized = F.normalize(x, p=2, dim=1)
        # Reshape scale to (1, C, 1, ..., 1) so it lines up with dim 1 for
        # any input rank, instead of assuming a 4-D (N, C, H, W) input.
        trailing_ones = [1] * (x.dim() - 2)
        return normalized * self.scale.view(1, -1, *trailing_ones)

    def reset_parameters(self) -> None:
        """Set every entry of ``scale`` to ``initial_scale``."""
        nn.init.constant_(self.scale, self.initial_scale)