| import os |
| import warnings |
|
|
| import click |
| import lightning.pytorch as pl |
| import torch |
| from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint |
| from lightning.pytorch.loggers import TensorBoardLogger |
| from pytorchvideo.transforms import Normalize, Permute, RandAugment |
| from torch.utils.data import DataLoader, WeightedRandomSampler |
| from torchvision.transforms import transforms as T |
| from torchvision.transforms._transforms_video import ToTensorVideo |
| from torchvision.transforms import InterpolationMode |
|
|
| from backbone.dataset import SyntaxDataset |
| from backbone.pl_model import SyntaxLightningModule |
|
|
|
|
| |
# Silence a known, noisy torch.distributed warning emitted under Lightning DDP.
warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")


# Trade a little float32 matmul precision for tensor-core speed (Ampere+ GPUs).
torch.set_float32_matmul_precision("medium")
|
|
|
|
def get_transforms(video_size, imagenet_mean, imagenet_std, train: bool = True):
    """
    Build the augmentation/preprocessing pipeline for video clips.

    Input:
        uint8 video tensor of shape (T, H, W, C).

    Returns:
        A transform producing a normalized (C, T, H, W) tensor ready for a
        3D-ResNet backbone. The train pipeline adds RandAugment, a random
        horizontal flip, and a randomly chosen resize interpolation; the
        eval pipeline is fully deterministic.
    """
    if not train:
        # Deterministic eval path: to-tensor, fixed bicubic resize, normalize.
        return T.Compose([
            ToTensorVideo(),
            T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    # One resize op per interpolation mode; RandomChoice picks one per sample.
    resize_variants = [
        T.Resize(size=video_size, interpolation=mode, antialias=True)
        for mode in (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
    ]

    pipeline = [
        ToTensorVideo(),                 # (T, H, W, C) uint8 -> (C, T, H, W) float
        Permute(dims=[1, 0, 2, 3]),      # -> (T, C, H, W) layout expected by RandAugment
        RandAugment(magnitude=10, num_layers=2),
        T.RandomHorizontalFlip(),
        Permute(dims=[1, 0, 2, 3]),      # back to (C, T, H, W)
        T.RandomChoice(resize_variants),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(pipeline)
|
|
|
|
def make_dataloader(dataset, batch_size: int, num_workers: int, use_weighted_sampler: bool, *, shuffle: bool = True):
    """
    Build a DataLoader with an optional WeightedRandomSampler.

    Args:
        dataset: map-style dataset; must expose ``get_sample_weights()``
            (returning a 1-D tensor of per-sample weights) when
            ``use_weighted_sampler`` is True.
        batch_size: samples per batch.
        num_workers: worker process count; 0 disables multiprocessing.
        use_weighted_sampler: if True, draw samples (with replacement)
            according to the dataset-provided weights; ordering is then
            controlled by the sampler, so shuffling is forced off.
        shuffle: whether to shuffle when no sampler is used. Defaults to
            True (the original behavior); pass False e.g. for validation
            loaders that should keep a fixed order.

    Returns:
        A ``torch.utils.data.DataLoader`` that drops the last incomplete
        batch, pins memory, and keeps workers alive between epochs when
        ``num_workers > 0``.
    """
    if use_weighted_sampler:
        # .cpu(): WeightedRandomSampler requires weights on the CPU.
        sample_weights = dataset.get_sample_weights().cpu()
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
        # A sampler and shuffle=True are mutually exclusive in DataLoader.
        shuffle = False
    else:
        sampler = None

    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        shuffle=shuffle,
        drop_last=True,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
|
|
|
|
def make_model(num_classes: int, lr: float, weight_decay: float, max_epochs: int, weight_path: str = None):
    """
    Construct the backbone LightningModule.

    Args:
        num_classes: number of output neurons (typically 2: classification
            + regression heads).
        lr: learning rate.
        weight_decay: optimizer weight decay.
        max_epochs: training length, forwarded to the module (e.g. for
            scheduler setup).
        weight_path: optional checkpoint path to initialize weights from.
    """
    module_kwargs = dict(
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        max_epochs=max_epochs,
        weight_path=weight_path,
    )
    return SyntaxLightningModule(**module_kwargs)
|
|
|
|
def make_callbacks(phase: str):
    """
    Assemble the Trainer callback list:
    - learning-rate monitoring (logged per epoch)
    - checkpointing on the ``val_rmse`` metric (lower is better)

    During the "pre" phase only the single best checkpoint is kept;
    any other phase retains the top 3. The last epoch is always saved.
    """
    top_k = 1 if phase == "pre" else 3
    return [
        LearningRateMonitor(logging_interval="epoch"),
        ModelCheckpoint(
            monitor="val_rmse",
            mode="min",
            save_top_k=top_k,
            filename="model-{epoch:02d}-{val_rmse:.3f}",
            save_last=True,
        ),
    ]
|
|
|
|
def make_trainer(max_epochs: int, logdir: str, logger_name: str, devices: list[int], precision: str):
    """
    Build a Lightning Trainer.

    Args:
        max_epochs: number of training epochs.
        logdir: TensorBoard log root directory.
        logger_name: experiment subdirectory name.
        devices: list of GPU device ids.
        precision: numeric precision mode (e.g. "bf16-mixed").

    Returns:
        A ``pl.Trainer`` with an empty callback list; callers append
        callbacks before calling ``fit``.
    """
    use_gpu = torch.cuda.is_available()
    # Multi-GPU runs need DDP with unused-parameter detection enabled;
    # a single device lets Lightning pick the strategy automatically.
    multi_device = len(devices) > 1
    strategy = "ddp_find_unused_parameters_true" if multi_device else "auto"

    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="gpu" if use_gpu else "cpu",
        devices=devices,
        strategy=strategy,
        precision=precision,
        callbacks=[],
        log_every_n_steps=10,
        logger=TensorBoardLogger(save_dir=logdir, name=logger_name),
    )
|
|
|
|
@click.command()
@click.option(
    "-r",
    "--dataset-root",
    type=click.Path(exists=True),
    default=".",
    show_default=True,
    help="Корень датасета (JSON и DICOM-пути считаются относительно него).",
)
@click.option("--fold", type=int, default=4, show_default=True, help="Номер фолда.")
@click.option(
    "-a",
    "--artery",
    type=str,
    default="right",
    show_default=True,
    help="Название артерии: left или right.",
)
@click.option(
    "-nc",
    "--num-classes",
    type=int,
    default=2,
    show_default=True,
    help="Число выходных нейронов (обычно 2: clf + reg).",
)
@click.option("-b", "--batch-size", type=int, default=50, show_default=True, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, show_default=True, help="Число кадров в клипе.")
@click.option(
    "-v",
    "--video-size",
    type=click.Tuple([int, int]),
    default=(256, 256),
    show_default=True,
    help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, show_default=True, help="Число эпох для full train.")
@click.option("--num-workers", type=int, default=8, show_default=True, help="Число DataLoader workers.")
@click.option(
    "--devices",
    # BUGFIX: the previous `type=list[int]` is not a click ParamType — click
    # would call `list[int]("0")`, turning each value into a list of characters,
    # and converting the default crashed outright. With `type=int` plus
    # `multiple=True`, repeated `--devices N` flags collect into a tuple of ints.
    type=int,
    multiple=True,
    default=(0,),
    show_default=True,
    help="Список GPU id",
)
@click.option("--precision", type=str, default="bf16-mixed", show_default=True, help="Режим точности.")
@click.option(
    "--logdir",
    type=click.Path(),
    default="./logs/backbone",
    show_default=True,
    help="Каталог для логов и чекпоинтов backbone.",
)
@click.option(
    "--use-weighted-sampler",
    is_flag=True,
    default=False,
    show_default=True,
    help="Использовать ли WeightedRandomSampler по score-интервалам.",
)
@click.option("--seed", type=int, default=42, show_default=True, help="Сид для воспроизводимости.")
def main(
    dataset_root,
    fold,
    artery,
    num_classes,
    batch_size,
    frames_per_clip,
    video_size,
    max_epochs,
    num_workers,
    devices,
    precision,
    logdir,
    use_weighted_sampler,
    seed,
):
    """
    Entry point for training the backbone model.

    Sequence:
        1) pretrain: train only the final fc layer;
        2) full train: fine-tune the whole model, starting from the last
           pretrain checkpoint.
    """
    pl.seed_everything(seed)

    # `multiple=True` yields a tuple; normalize to the list of GPU ids that
    # make_trainer / pl.Trainer expect.
    devices = list(devices)

    artery = artery.lower()
    artery_bin = {"left": 0, "right": 1}.get(artery)
    if artery_bin is None:
        raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")

    # Standard ImageNet channel statistics, matching the pretrained backbone.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Fold metadata paths, resolved by SyntaxDataset relative to dataset_root.
    train_meta = f"folds/step2_fold{fold:02d}_train.json"
    eval_meta = f"folds/step2_fold{fold:02d}_eval.json"

    train_set = SyntaxDataset(
        root=dataset_root,
        meta=train_meta,
        train=True,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=False,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
    )

    val_set = SyntaxDataset(
        root=dataset_root,
        meta=eval_meta,
        train=False,
        length=frames_per_clip,
        label=f"syntax_{artery}",
        artery_bin=artery_bin,
        validation=True,
        transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
    )

    # Pretraining only updates the fc layer, so it fits a doubled batch size.
    train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers, use_weighted_sampler)
    train_loader_post = make_dataloader(train_set, batch_size, num_workers, use_weighted_sampler)
    val_loader = make_dataloader(val_set, 1, num_workers, use_weighted_sampler=False)

    # Probe one batch to report the input tensor shape, then drop the iterator
    # explicitly so its worker processes are released before training starts.
    probe = iter(train_loader_pre)
    x, *_ = next(probe)
    del probe
    video_shape = x.shape[1:]
    print(f"Backbone input video shape: {video_shape}")

    callbacks_pre = make_callbacks(phase="pre")
    callbacks_full = make_callbacks(phase="full")

    # Short warm-up phase for the final layer only.
    num_pre_epochs = 10

    model_pre = make_model(
        num_classes=num_classes,
        lr=3e-4,
        weight_decay=0.01,
        max_epochs=num_pre_epochs,
        weight_path=None,
    )

    trainer_pre = make_trainer(
        max_epochs=num_pre_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_pre_fold{fold:02d}",
        devices=devices,
        precision=precision,
    )
    trainer_pre.callbacks.extend(callbacks_pre)
    trainer_pre.fit(model_pre, train_loader_pre, val_loader)

    # Full fine-tuning resumes from the last pretrain checkpoint with a
    # lower learning rate.
    model_full = make_model(
        num_classes=num_classes,
        lr=1e-4,
        weight_decay=0.01,
        max_epochs=max_epochs,
        weight_path=trainer_pre.checkpoint_callback.last_model_path,
    )

    trainer_full = make_trainer(
        max_epochs=max_epochs,
        logdir=logdir,
        logger_name=f"{artery}BinSyntax_R3D_full_fold{fold:02d}",
        devices=devices,
        precision=precision,
    )
    trainer_full.callbacks.extend(callbacks_full)
    trainer_full.fit(model_full, train_loader_post, val_loader)
|
|
|
|
if __name__ == "__main__":
    # click parses CLI arguments and invokes the training pipeline.
    main()