import torch class Normalizer(torch.nn.Module): def forward(self, pixels: torch.Tensor) -> torch.Tensor: pixels = pixels.float() return pixels / 255.0