MesserMMP's picture
Add model code and full model weights
f621d73
import os
import warnings
import click
import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from pytorchvideo.transforms import Normalize, Permute, RandAugment
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.transforms import transforms as T
from torchvision.transforms._transforms_video import ToTensorVideo
from torchvision.transforms import InterpolationMode
from backbone.dataset import SyntaxDataset
from backbone.pl_model import SyntaxLightningModule
# Отключаем предупреждение Lightning о device id при DDP-инициализации
warnings.filterwarnings("ignore", message="No device id is provided via `init_process_group`")
# Устанавливаем точность матричных умножений (оптимизация производительности)
torch.set_float32_matmul_precision("medium")
def get_transforms(video_size, imagenet_mean, imagenet_std, train: bool = True):
"""
Создаёт пайплайн аугментаций/преобразований для видео.
Входные данные:
- видео в формате Tensor (T, H, W, C), dtype uint8
Результат:
- Tensor (C, T, H, W), нормализованный, готовый к подаче в 3D-ResNet.
"""
interpolation_choices = [InterpolationMode.BILINEAR, InterpolationMode.BICUBIC]
if train:
return T.Compose([
# Переводит (T, H, W, C) → (C, T, H, W), значения в [0,1]
ToTensorVideo(),
# Меняем порядок осей: (C, T, H, W) → (T, C, H, W) для RandAugment
Permute(dims=[1, 0, 2, 3]),
# Случайные аугментации по времени/пространству
RandAugment(magnitude=10, num_layers=2),
# Случайное горизонтальное отражение
T.RandomHorizontalFlip(),
# Возвращаемся к формату (C, T, H, W)
Permute(dims=[1, 0, 2, 3]),
# Случайный выбор интерполяции для изменения размера
T.RandomChoice([
T.Resize(size=video_size, interpolation=interp, antialias=True)
for interp in interpolation_choices
]),
# Нормализация по статистикам ImageNet
Normalize(mean=imagenet_mean, std=imagenet_std),
])
else:
# Для валидации/инференса используем только приведение к тензору и resize
return T.Compose([
ToTensorVideo(),
T.Resize(size=video_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
Normalize(mean=imagenet_mean, std=imagenet_std),
])
def make_dataloader(dataset, batch_size: int, num_workers: int, use_weighted_sampler: bool):
"""
Создаёт DataLoader с опциональным WeightedRandomSampler.
Если use_weighted_sampler=True:
- семплирование идёт с учётом весов, возвращаемых dataset.get_sample_weights()
- shuffle выключается, так как порядок определяется сэмплером
"""
if use_weighted_sampler:
sample_weights = dataset.get_sample_weights().cpu()
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
shuffle = False
else:
sampler = None
shuffle = True
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
shuffle=shuffle,
drop_last=True,
pin_memory=True,
persistent_workers=(num_workers > 0),
)
def make_model(num_classes: int, lr: float, weight_decay: float, max_epochs: int, weight_path: str = None):
"""
Конструктор LightningModule для backbone.
num_classes:
количество выходных нейронов (обычно 2: классификация + регрессия).
"""
return SyntaxLightningModule(
num_classes=num_classes,
lr=lr,
weight_decay=weight_decay,
max_epochs=max_epochs,
weight_path=weight_path,
)
def make_callbacks(phase: str):
"""
Создаёт список callback'ов для Trainer:
- мониторинг learning rate
- сохранение чекпоинтов по метрике val_rmse
"""
lr_monitor = LearningRateMonitor(logging_interval="epoch")
checkpoint = ModelCheckpoint(
monitor="val_rmse",
save_top_k=1 if phase == "pre" else 3,
mode="min",
filename="model-{epoch:02d}-{val_rmse:.3f}",
save_last=True,
)
return [lr_monitor, checkpoint]
def make_trainer(max_epochs: int, logdir: str, logger_name: str, devices: list[int], precision: str):
"""
Создаёт объект Trainer с заданными параметрами:
- logdir: путь к директории для логов TensorBoard
- logger_name: имя поддиректории для текущего эксперимента
- devices: количество GPU-устройств
- precision: режим числовой точности (например, "bf16-mixed")
"""
logger = TensorBoardLogger(save_dir=logdir, name=logger_name)
# Если устройств больше одного — используем DDP, иначе оставляем стратегию по умолчанию
strategy = "ddp_find_unused_parameters_true" if len(devices) > 1 else "auto"
return pl.Trainer(
max_epochs=max_epochs,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices=devices,
strategy=strategy,
precision=precision,
callbacks=[],
log_every_n_steps=10,
logger=logger,
)
@click.command()
@click.option(
"-r",
"--dataset-root",
type=click.Path(exists=True),
default=".",
show_default=True,
help="Корень датасета (JSON и DICOM-пути считаются относительно него).",
)
@click.option("--fold", type=int, default=4, show_default=True, help="Номер фолда.")
@click.option(
"-a",
"--artery",
type=str,
default="right",
show_default=True,
help="Название артерии: left или right.",
)
@click.option(
"-nc",
"--num-classes",
type=int,
default=2,
show_default=True,
help="Число выходных нейронов (обычно 2: clf + reg).",
)
@click.option("-b", "--batch-size", type=int, default=50, show_default=True, help="Размер batch.")
@click.option("-f", "--frames-per-clip", type=int, default=32, show_default=True, help="Число кадров в клипе.")
@click.option(
"-v",
"--video-size",
type=click.Tuple([int, int]),
default=(256, 256),
show_default=True,
help="Размер кадра (H, W).",
)
@click.option("--max-epochs", type=int, default=10, show_default=True, help="Число эпох для full train.")
@click.option("--num-workers", type=int, default=8, show_default=True, help="Число DataLoader workers.")
@click.option(
"--devices",
type=list[int],
multiple=True,
default=[0],
show_default=True,
help="Список GPU id",
)
@click.option("--precision", type=str, default="bf16-mixed", show_default=True, help="Режим точности.")
@click.option(
"--logdir",
type=click.Path(),
default="./logs/backbone",
show_default=True,
help="Каталог для логов и чекпоинтов backbone.",
)
@click.option(
"--use-weighted-sampler",
is_flag=True,
default=False,
show_default=True,
help="Использовать ли WeightedRandomSampler по score-интервалам.",
)
@click.option("--seed", type=int, default=42, show_default=True, help="Сид для воспроизводимости.")
def main(
dataset_root,
fold,
artery,
num_classes,
batch_size,
frames_per_clip,
video_size,
max_epochs,
num_workers,
devices,
precision,
logdir,
use_weighted_sampler,
seed,
):
"""
Точка входа для обучения backbone-модели.
Последовательность:
1) pretrain: обучение только финального слоя fc
2) full train: дообучение всей модели с началом из последнего чекпоинта pretrain.
"""
# Фиксация сида во всех поддерживаемых библиотеках
pl.seed_everything(seed)
artery = artery.lower()
artery_bin = {"left": 0, "right": 1}.get(artery)
if artery_bin is None:
raise ValueError(f"Unknown artery '{artery}', expected 'left' or 'right'")
# Статистики ImageNet для нормализации входа
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
# Пути к JSON-метаданным фолдов относительно dataset_root
train_meta = f"folds/step2_fold{fold:02d}_train.json"
eval_meta = f"folds/step2_fold{fold:02d}_eval.json"
# Инициализация тренировочного датасета
train_set = SyntaxDataset(
root=dataset_root,
meta=train_meta,
train=True,
length=frames_per_clip,
label=f"syntax_{artery}",
artery_bin=artery_bin,
validation=False,
transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=True),
)
# Инициализация валидационного датасета
val_set = SyntaxDataset(
root=dataset_root,
meta=eval_meta,
train=False,
length=frames_per_clip,
label=f"syntax_{artery}",
artery_bin=artery_bin,
validation=True,
transform=get_transforms(video_size, imagenet_mean, imagenet_std, train=False),
)
# DataLoader'ы: для pretrain можно брать увеличенный batch
train_loader_pre = make_dataloader(train_set, batch_size * 2, num_workers, use_weighted_sampler)
train_loader_post = make_dataloader(train_set, batch_size, num_workers, use_weighted_sampler)
val_loader = make_dataloader(val_set, 1, num_workers, use_weighted_sampler=False)
# Получаем форму видео (C, T, H, W) из одного batch для информации
x, *_ = next(iter(train_loader_pre))
video_shape = x.shape[1:]
print(f"Backbone input video shape: {video_shape}")
# Callback'и для pretrain и full train
callbacks_pre = make_callbacks(phase="pre")
callbacks_full = make_callbacks(phase="full")
# ------------------- Pretrain (fc only) -------------------
num_pre_epochs = 10
model_pre = make_model(
num_classes=num_classes,
lr=3e-4,
weight_decay=0.01,
max_epochs=num_pre_epochs,
weight_path=None,
)
trainer_pre = make_trainer(
max_epochs=num_pre_epochs,
logdir=logdir,
logger_name=f"{artery}BinSyntax_R3D_pre_fold{fold:02d}",
devices=devices,
precision=precision,
)
trainer_pre.callbacks.extend(callbacks_pre)
trainer_pre.fit(model_pre, train_loader_pre, val_loader)
# ------------------- Full train (finetune) -------------------
model_full = make_model(
num_classes=num_classes,
lr=1e-4,
weight_decay=0.01,
max_epochs=max_epochs,
weight_path=trainer_pre.checkpoint_callback.last_model_path,
)
trainer_full = make_trainer(
max_epochs=max_epochs,
logdir=logdir,
logger_name=f"{artery}BinSyntax_R3D_full_fold{fold:02d}",
devices=devices,
precision=precision,
)
trainer_full.callbacks.extend(callbacks_full)
trainer_full.fit(model_full, train_loader_post, val_loader)
if __name__ == "__main__":
main()