| """ |
| Data augmentations for SSL pretraining (SimCLR, DINO). |
| """ |
| import torch |
| import torch.nn.functional as F |
| from typing import Tuple |
|
|
|
|
| @torch.no_grad() |
| def random_time_crop( |
| x: torch.Tensor, |
| ratio: Tuple[float, float] | float = (0.6, 0.9), |
| *, |
| resize_back: bool = True, |
| align_to: int | None = 40 |
| ) -> torch.Tensor: |
| """ |
| Randomly crop a contiguous sub-sequence per sample, optionally resize back to original T. |
| |
| Args: |
| x: (B, C, T) |
| ratio: crop length ratio in [low, high] or a float |
| resize_back: if True, linearly interpolate the cropped view back to length T |
| align_to: if not None, crop length is rounded to a multiple of align_to (>= align_to) |
| """ |
| assert x.dim() == 3, f"expected (B,C,T), got {tuple(x.shape)}" |
| B, C, T = x.shape |
| dev = x.device |
|
|
| def _sample_L() -> int: |
| if isinstance(ratio, (tuple, list)): |
| a, b = float(ratio[0]), float(ratio[1]) |
| r = torch.empty((), device=dev).uniform_(a, b).item() |
| else: |
| r = float(ratio) |
| L = max(2, int(round(T * r))) |
| if align_to and align_to > 1: |
| L = max(align_to, int(round(L / align_to)) * align_to) |
| return min(L, T) |
|
|
| Ls = [_sample_L() for _ in range(B)] |
| outs = [] |
| for b in range(B): |
| L = Ls[b] |
| max_start = max(0, T - L) |
| s = int(torch.randint(0, max_start + 1, (1,), device=dev).item()) |
| v = x[b, :, s:s+L] |
| if resize_back and v.shape[-1] != T: |
| v = F.interpolate(v[None], size=T, mode="linear", align_corners=False)[0] |
| outs.append(v) |
| return torch.stack(outs, dim=0) |
|
|
|
|
@torch.no_grad()
def channel_dropout(
    x: torch.Tensor,
    drop_prob: float = 0.2,
    min_keep: int = 1
) -> torch.Tensor:
    """
    Drop entire channels to zero with probability drop_prob (per sample, per channel).
    Ensures at least `min_keep` channels remain active in each sample.

    Args:
        x: (B, C, T)
        drop_prob: probability to drop each channel (0 never drops, 1 always drops)
        min_keep: minimum number of channels to keep per sample; samples that
            end up with fewer kept channels get `min_keep` randomly chosen
            channels switched back on (all channels if min_keep >= C)

    Returns:
        Tensor with the same shape and dtype as x, with dropped channels
        multiplied by zero.
    """
    assert x.dim() == 3, f"expected (B,C,T), got {tuple(x.shape)}"
    B, C, _ = x.shape

    # Draw in the default float dtype rather than x.dtype: torch.rand does not
    # support integer/bool dtypes, and a low-precision draw would quantise the
    # effective drop probability. `>=` keeps a channel with probability exactly
    # 1 - drop_prob, so drop_prob == 0 can never drop anything.
    keep_mask = torch.rand(B, C, 1, device=x.device) >= drop_prob

    # Samples with too few surviving channels get `min_keep` random channels
    # forced back on.
    kept_per_sample = keep_mask.sum(dim=1).squeeze(-1)  # (B,)
    short_rows = torch.nonzero(kept_per_sample < min_keep).flatten().tolist()
    for b in short_rows:
        forced = torch.randperm(C, device=x.device)[:min_keep]
        keep_mask[b, forced, 0] = True

    # The (B, C, 1) mask broadcasts over the time axis.
    return x * keep_mask.to(x.dtype)
|
|