| from typing import Sequence | |
| import random | |
| from torch import nn | |
| import torchvision.transforms.functional as TF | |
class TransformsFixRotation(nn.Module):
    r"""
    Rotate the input by one angle drawn uniformly at random from a fixed set.

    On each call, one entry of ``angles`` is picked with ``random.choice`` and
    passed to torchvision's ``TF.rotate`` as the rotation angle.

    Example: `rotation_transform = TransformsFixRotation(angles=[-30, -15, 0, 15, 30])`

    Args:
        angles: A sequence of angles, any other iterable of angles (set,
            generator, numpy array, 1-D tensor), or a single scalar angle.
            Sequences are stored as given. Other iterables are materialised
            into a list. A non-iterable value is wrapped into a one-element list.

    Raises:
        ValueError: If ``angles`` contains no angles.
    """

    def __init__(self, angles):
        super().__init__()
        if not isinstance(angles, Sequence):
            try:
                # random.choice needs an indexable container, so materialise
                # sets, generators, arrays and tensors into a list.
                angles = list(angles)
            except TypeError:
                # Not iterable (e.g. an int/float or a 0-d array/tensor):
                # treat it as the single allowed angle.
                angles = [angles, ]
        if len(angles) == 0:
            # Fail at construction rather than with an IndexError from
            # random.choice on the first forward() call.
            raise ValueError("angles must contain at least one angle")
        self.angles = angles

    def forward(self, x):
        """Rotate ``x`` by a randomly chosen angle from ``self.angles``.

        ``x`` is forwarded unchanged to ``TF.rotate``. Presumably it is a PIL
        image or image tensor; confirm against callers.
        """
        angle = random.choice(self.angles)
        # float() turns numpy/torch scalar elements into a plain Python
        # number, which TF.rotate accepts.
        return TF.rotate(x, float(angle))

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(angles={self.angles})"