# theia/scripts/train/train_rvfm.py
# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.
"""
Training script for Theia, referred to as the robot visual foundation model
(RVFM) in the code.

Configuration is handled by Hydra; to change settings, see theia/configs.
"""
import math
import os.path as osp
import random
import warnings
from typing import Any, Callable
import hydra
import wandb
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim.lr_scheduler import LRScheduler
from torchvision.transforms.v2 import Compose
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
from theia.models.rvfm import RobotVisionFM
from theia.optimizers.utils import param_groups_weight_decay
from theia.utils.logging import create_meters, log_metrics
from theia.utils.seed import seed_everything
from theia.foundation_models.common import MODEL_FEATURE_SIZES, get_model_feature_size
from theia.dataset.data_utils import get_frame_dataloader, get_frame_iterator, get_image_video_dataset
from theia.dataset.oxe.oxe_transforms import totensor
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
def train(
rvfm: nn.Module,
target_model_names: list[str],
optimizer: torch.optim.Optimizer,
lr_scheduler: LRScheduler,
train_dataset: Any,
eval_dataset: Any,
cfg: DictConfig,
device: int = 0,
train_epoch_steps: int = 0,
eval_epoch_steps: int = 0,
total_train_steps: int = 0,
warmup_steps: int = 0,
) -> None:
"""Training and evaluation for robot visual foundation model (rvfm).
Args:
rvfm (nn.Module): model to train.
target_model_names (list[str]): list of teacher model names.
optimizer (torch.optim.Optimizer): optimizer.
lr_scheduler (LRScheduler): learning rate scheduler.
train_dataset (Any): train dataset.
eval_dataset (Any): eval dataset.
        cfg (DictConfig): training config.
device (int, optional): device (of this process). Defaults to 0.
train_epoch_steps (int, optional): steps per training epoch. Defaults to 0.
eval_epoch_steps (int, optional): steps per eval epoch. Defaults to 0.
total_train_steps (int, optional): total training steps. Defaults to 0.
warmup_steps (int, optional): warmup steps. Defaults to 0.
"""
epochs = cfg.training.epochs
steps = 0
    # re-create and wrap the loaders each epoch so synchronized dataloaders are easy to handle
for ep in range(epochs):
train_loaders = get_frame_dataloader(
train_dataset,
batch_size=cfg.training.batch_size,
pin_memory=True,
num_workers=cfg.training.num_workers,
shuffle=cfg.dataset.shuffle,
shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
seed=cfg.seed + device * 100 + ep, # either cfg.seed or cfg.seed + rank
)
eval_loaders = get_frame_dataloader(
eval_dataset,
batch_size=cfg.training.batch_size,
pin_memory=True,
num_workers=cfg.training.num_workers,
shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
seed=cfg.seed, # either cfg.seed or cfg.seed + rank
)
train_iter = get_frame_iterator(train_loaders)
metric_meters = create_meters(target_model_names)
rvfm.train()
train_tqdm = tqdm(range(train_epoch_steps), ncols=80) if device == 0 else range(train_epoch_steps)
for _ in train_tqdm:
try:
batch = next(train_iter)
except StopIteration:
train_iter = get_frame_iterator(train_loaders)
batch = next(train_iter)
images_batch = batch["image"].to(device, non_blocking=True)
if cfg.training.random_target_models > 0:
                batch_target_model_names = random.sample(
                    target_model_names, cfg.training.random_target_models
                )
else:
batch_target_model_names = target_model_names
target_features_batch = {}
for t in batch_target_model_names:
base_name = t.replace("_cls", "")
cls = True if "_cls" in t else False
if cls:
target_features_batch[t] = batch[base_name]["cls"].to(device, non_blocking=True).float()
else:
target_features_batch[t] = batch[base_name]["embedding"].to(device, non_blocking=True).float()
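            # One forward pass predicts features for every teacher; get_loss on
            # the wrapped module returns a dict of losses keyed by type
            # ("mse_loss", "cos_loss", "l1_loss") over the provided targets.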
pred = rvfm(images_batch)
losses = rvfm.module.get_loss(pred, target_features_batch)
if cfg.training.main_loss == "mse" or cfg.training.main_loss is None:
main_loss = losses["mse_loss"]
elif cfg.training.main_loss == "cos_l1":
main_loss = 0.9 * losses["cos_loss"] + 0.1 * losses["l1_loss"]
optimizer.zero_grad()
main_loss.backward()
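            # Optionally clip gradient norms, with a separate (typically tighter)
            # bound while the learning rate is still warming up.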
if cfg.training.grad_clip:
nn.utils.clip_grad_norm_(
rvfm.parameters(),
cfg.training.grad_clip_norm_warmup if steps < warmup_steps else cfg.training.grad_clip_norm,
)
optimizer.step()
lr_scheduler.step()
steps += 1
batch_size = images_batch.size(0)
log_metrics(
metric_meters,
target_model_names=target_model_names,
device=device,
batch_size=batch_size,
mode="train",
upload_wandb=True,
main_loss=main_loss,
**losses,
)
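            # Optionally freeze the translator heads once the configured fraction
            # of total training steps has elapsed.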
if cfg.training.freeze_translator:
if steps == int(cfg.training.freeze_translator_start_steps_ratio * total_train_steps):
rvfm.module.freeze_translator()
            if steps % cfg.logging.save_ckpt_interval == 0:
                if device == 0:
                    model_save_fn = f"{cfg.logging.run_identifier_prefix}_step{steps:08d}.pth"
                    save_path = osp.join(cfg.logging.model_path, model_save_fn)
                    torch.save(rvfm.module.state_dict(), save_path)
                dist.barrier()
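        # End of epoch: evaluate on the held-out split, upload the averaged
        # metrics to wandb, and checkpoint from rank 0.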
rvfm.eval()
eval_iter = get_frame_iterator(eval_loaders)
eval_tqdm = tqdm(range(eval_epoch_steps), ncols=80) if device == 0 else range(eval_epoch_steps)
with torch.no_grad():
for _ in eval_tqdm:
                try:
                    batch = next(eval_iter)
                except StopIteration:
                    break
                images_batch = batch["image"].to(device, non_blocking=True)
target_features_batch = {}
for t in target_model_names:
                    base_name = t.replace("_cls", "")
                    key = "cls" if "_cls" in t else "embedding"
                    target_features_batch[t] = batch[base_name][key].to(device, non_blocking=True).float()
pred = rvfm(images_batch)
losses = rvfm.module.get_loss(pred, target_features_batch)
if cfg.training.main_loss == "mse" or cfg.training.main_loss is None:
main_loss = losses["mse_loss"]
elif cfg.training.main_loss == "cos_l1":
main_loss = 0.9 * losses["cos_loss"] + 0.1 * losses["l1_loss"]
batch_size = images_batch.size(0)
log_metrics(
metric_meters,
target_model_names=target_model_names,
device=device,
batch_size=batch_size,
mode="eval",
upload_wandb=False,
main_loss=main_loss,
**losses,
)
log_metrics(
metric_meters,
mode="eval",
upload_wandb=True,
only_upload=True,
target_model_names=target_model_names,
device=device,
)
if device == 0:
model_save_fn = f"{cfg.logging.run_identifier_prefix}_step{steps:08d}.pth"
save_path = osp.join(cfg.logging.model_path, model_save_fn)
torch.save(rvfm.module.state_dict(), save_path)
dist.barrier()
def ddp_setup() -> None:
"""Initialize stuff for DDP."""
dist.init_process_group("nccl")
def ddp_cleanup() -> None:
"""Clean up stuff for DDP."""
dist.destroy_process_group()
def ddp_main(cfg: DictConfig) -> None:
"""Entry point of DDP.
Args:
cfg (DictConfig): settings for training.
"""
ddp_setup()
rank, world_size = dist.get_rank(), dist.get_world_size()
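    # Resolve the teacher (target) model list: use the configured names when
    # provided, otherwise every model with a registered feature size.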
target_model_names = (
cfg.training.target_models.target_model_names
if len(cfg.training.target_models.target_model_names) > 0
else list(MODEL_FEATURE_SIZES.keys())
)
target_model_names = [t for t in target_model_names if "llava" not in t] # llava is currently not supported
target_feature_sizes = {t: get_model_feature_size(t, keep_spatial=True) for t in target_model_names}
target_model_names_wocls = target_model_names[:]
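    # When distilling CLS tokens, add an extra "<model>_cls" target for each
    # ViT-style teacher; its feature size is the leading (channel) dimension of
    # the spatial feature shape.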
if hasattr(cfg.training, "distill_cls") and cfg.training.distill_cls == True:
target_model_names_copy = target_model_names[:]
for t in target_model_names:
if "google/vit" in t or "facebook/dino" in t or "openai/clip" in t:
target_feature_sizes[t+"_cls"] = get_model_feature_size(t, keep_spatial=True)[:1]
target_model_names_copy.append(t+"_cls")
target_model_names = target_model_names_copy
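    # Build the student: a shared backbone with one translator head per teacher,
    # moved to this rank's GPU and wrapped in DDP.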
rvfm = RobotVisionFM(
translator=cfg.model.translator.type,
translator_kwargs=cfg.model.translator.kwargs,
target_feature_sizes=target_feature_sizes,
target_loss_weights=cfg.training.target_models.target_model_weights,
**cfg.model.backbone,
)
rvfm.to(rank)
rvfm_ddp = DDP(rvfm, device_ids=[rank], find_unused_parameters=False)
image_transform: Compose | Callable = totensor # currently just ndarray to tensor
train_dataset, train_dataset_expected_length = get_image_video_dataset(
dataset_root=cfg.dataset.dataset_root,
dataset_mix=cfg.dataset.dataset_mix,
split="train",
dataset_ratio=cfg.dataset.dataset_ratio,
feature_models=target_model_names_wocls,
image_transform=image_transform,
feature_norm=cfg.dataset.feature_norm,
rank=rank,
world_size=world_size,
shuffle=cfg.dataset.shuffle,
seed=cfg.seed,
shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
num_workers=cfg.training.num_workers,
)
eval_dataset, eval_dataset_expected_length = get_image_video_dataset(
dataset_root=cfg.dataset.dataset_root,
dataset_mix=cfg.dataset.dataset_mix,
split="val",
dataset_ratio=0.1,
feature_models=target_model_names_wocls,
image_transform=image_transform,
feature_norm=cfg.dataset.feature_norm,
rank=rank,
world_size=world_size,
shuffle=False,
seed=cfg.seed,
shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
num_workers=cfg.training.num_workers,
)
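    # Steps per epoch: the expected dataset length split across ranks, then
    # batched (ceil, so a trailing partial batch still counts as a step).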
train_epoch_steps = math.ceil(train_dataset_expected_length / cfg.training.batch_size / world_size)
eval_epoch_steps = math.ceil(eval_dataset_expected_length / cfg.training.batch_size / world_size)
total_train_steps = train_epoch_steps * cfg.training.epochs
rvfm_param_groups = param_groups_weight_decay(rvfm_ddp, cfg.training.weight_decay)
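    # Linear LR scaling: scale base_lr by the ratio of the effective global batch
    # size to the reference batch size the base LR was tuned for.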
lr = cfg.training.base_lr * (
(cfg.training.batch_size * world_size) / (cfg.training.base_batch_size * cfg.training.base_world_size)
)
optimizer = hydra.utils.instantiate(cfg.training.optimizer, rvfm_param_groups, lr=lr)
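    # Per the parameter names, the scheduler combines a linear warmup over the
    # first warm_up_steps_ratio of training with cosine annealing (T_0) over the
    # remaining steps.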
lr_scheduler = hydra.utils.instantiate(
cfg.training.lr_scheduler,
optimizer=optimizer,
warm_up_steps=int(cfg.training.warm_up_steps_ratio * total_train_steps),
cos_lrs_T_0=int(total_train_steps * (1 - cfg.training.warm_up_steps_ratio)),
)
if rank == 0:
print(OmegaConf.to_yaml(cfg))
wandb.init(project=cfg.logging.project, name=cfg.logging.run_identifier_prefix, config=OmegaConf.to_object(cfg))
train(
rvfm_ddp,
target_model_names,
optimizer,
lr_scheduler,
train_dataset,
eval_dataset,
cfg=cfg,
device=rank,
train_epoch_steps=train_epoch_steps,
eval_epoch_steps=eval_epoch_steps,
total_train_steps=total_train_steps,
warmup_steps=int(cfg.training.warm_up_steps_ratio * total_train_steps),
)
ddp_cleanup()
@hydra.main(version_base=None, config_path="../../configs", config_name="train_rvfm_imagenet")
def main(cfg: DictConfig) -> None:
"""Main. Dealing with arguments and call DDP."""
backbone_fn = f"_{cfg.model.backbone.backbone.replace('/', '-')}"
notes_fn = f"_{cfg.logging.notes}" if cfg.logging.notes else ""
translator_fn = f"_{cfg.model.translator.type}"
pretrained_fn = "_pretrained" if cfg.model.backbone.pretrained else ""
dp_fn = f"_dp{cfg.dataset.dataset_ratio:.3f}"
cfg.logging.run_identifier_prefix = f"rvfm{dp_fn}{backbone_fn}{translator_fn}{pretrained_fn}{notes_fn}"
seed_everything(cfg.seed)
ddp_main(cfg)
if __name__ == "__main__":
main()