Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. | |
| """ | |
| Training script for theia, also called robot visual foundation model (RVFM) in | |
| the code. | |
| This training script uses hydra. To change configurations, go to theia/configs. | |
| """ | |
| import math | |
| import os.path as osp | |
| import random | |
| import warnings | |
| from typing import Any, Callable | |
| import hydra | |
| import wandb | |
| import torch | |
| import torch.nn as nn | |
| import torch.distributed as dist | |
| from torch.optim.lr_scheduler import LRScheduler | |
| from torchvision.transforms.v2 import Compose | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from tqdm import tqdm | |
| from omegaconf import DictConfig, OmegaConf | |
| from theia.models.rvfm import RobotVisionFM | |
| from theia.optimizers.utils import param_groups_weight_decay | |
| from theia.utils.logging import create_meters, log_metrics | |
| from theia.utils.seed import seed_everything | |
| from theia.foundation_models.common import MODEL_FEATURE_SIZES, get_model_feature_size | |
| from theia.dataset.data_utils import get_frame_dataloader, get_frame_iterator, get_image_video_dataset | |
| from theia.dataset.oxe.oxe_transforms import totensor | |
| warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") | |
def _gather_target_features(batch: Any, model_names: list[str], device: int) -> dict[str, Any]:
    """Pick the teacher features named in `model_names` out of `batch` and move them to `device`.

    A name containing "_cls" reads the "cls" entry of its base model (the name with
    "_cls" removed); every other name reads the "embedding" entry. Features are cast
    to float.
    """
    features = {}
    for name in model_names:
        base_name = name.replace("_cls", "")
        feature_key = "cls" if "_cls" in name else "embedding"
        features[name] = batch[base_name][feature_key].to(device, non_blocking=True).float()
    return features


def _select_main_loss(losses: Any, main_loss_name: Any) -> Any:
    """Return the loss to optimize/report for the configured `main_loss_name`.

    Raises:
        ValueError: if `main_loss_name` is not "mse", None or "cos_l1".
    """
    if main_loss_name == "mse" or main_loss_name is None:
        return losses["mse_loss"]
    if main_loss_name == "cos_l1":
        return 0.9 * losses["cos_loss"] + 0.1 * losses["l1_loss"]
    # Fail loudly instead of hitting an unbound (or stale) `main_loss` further down.
    raise ValueError(f"Unsupported cfg.training.main_loss: {main_loss_name!r} (expected 'mse', 'cos_l1' or None)")


def _save_checkpoint(rvfm: nn.Module, cfg: DictConfig, steps: int) -> None:
    """Write the state dict of the wrapped model (`rvfm.module`) to a step-stamped file."""
    model_save_fn = f"{cfg.logging.run_identifier_prefix}_step{steps:08d}.pth"
    save_path = osp.join(cfg.logging.model_path, model_save_fn)
    torch.save(rvfm.module.state_dict(), save_path)


def train(
    rvfm: nn.Module,
    target_model_names: list[str],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: LRScheduler,
    train_dataset: Any,
    eval_dataset: Any,
    cfg: DictConfig,
    device: int = 0,
    train_epoch_steps: int = 0,
    eval_epoch_steps: int = 0,
    total_train_steps: int = 0,
    warmup_steps: int = 0,
) -> None:
    """Training and evaluation for robot visual foundation model (rvfm).

    Each epoch runs `train_epoch_steps` optimization steps followed by
    `eval_epoch_steps` evaluation steps. The process with `device == 0` shows
    progress bars and writes checkpoints.

    Args:
        rvfm (nn.Module): model to train. It is accessed through `rvfm.module`, so a
            wrapper exposing `.module` (DDP in this file) is expected.
        target_model_names (list[str]): list of teacher model names.
        optimizer (torch.optim.Optimizer): optimizer.
        lr_scheduler (LRScheduler): learning rate scheduler, stepped once per training step.
        train_dataset (Any): train dataset.
        eval_dataset (Any): eval dataset.
        cfg (DictConfig): train config
        device (int, optional): device (of this process). Defaults to 0.
        train_epoch_steps (int, optional): steps per training epoch. Defaults to 0.
        eval_epoch_steps (int, optional): steps per eval epoch. Defaults to 0.
        total_train_steps (int, optional): total training steps. Defaults to 0.
        warmup_steps (int, optional): warmup steps. Defaults to 0.

    Raises:
        ValueError: if `cfg.training.main_loss` is not "mse", None or "cos_l1".
    """
    epochs = cfg.training.epochs
    steps = 0
    for ep in range(epochs):
        # Loaders are rebuilt every epoch; the training seed differs per process and per epoch.
        train_loaders = get_frame_dataloader(
            train_dataset,
            batch_size=cfg.training.batch_size,
            pin_memory=True,
            num_workers=cfg.training.num_workers,
            shuffle=cfg.dataset.shuffle,
            shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
            seed=cfg.seed + device * 100 + ep,
        )
        eval_loaders = get_frame_dataloader(
            eval_dataset,
            batch_size=cfg.training.batch_size,
            pin_memory=True,
            num_workers=cfg.training.num_workers,
            shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
            seed=cfg.seed,
        )
        train_iter = get_frame_iterator(train_loaders)
        metric_meters = create_meters(target_model_names)

        rvfm.train()
        train_tqdm = tqdm(range(train_epoch_steps), ncols=80) if device == 0 else range(train_epoch_steps)
        for _ in train_tqdm:
            try:
                batch = next(train_iter)
            except StopIteration:
                # The loader ran dry before `train_epoch_steps` was reached: start it over.
                train_iter = get_frame_iterator(train_loaders)
                batch = next(train_iter)

            images_batch = batch["image"].to(device, non_blocking=True)
            if cfg.training.random_target_models > 0:
                # NOTE(review): the sample size is fixed at 2, so `random_target_models`
                # only acts as an on/off switch here -- confirm that this is intended.
                batch_target_model_names = random.sample(target_model_names, 2)
            else:
                batch_target_model_names = target_model_names
            target_features_batch = _gather_target_features(batch, batch_target_model_names, device)

            pred = rvfm(images_batch)
            losses = rvfm.module.get_loss(pred, target_features_batch)
            main_loss = _select_main_loss(losses, cfg.training.main_loss)

            optimizer.zero_grad()
            main_loss.backward()
            if cfg.training.grad_clip:
                # A separate clipping norm applies while still inside the warm-up phase.
                max_norm = cfg.training.grad_clip_norm_warmup if steps < warmup_steps else cfg.training.grad_clip_norm
                nn.utils.clip_grad_norm_(rvfm.parameters(), max_norm)
            optimizer.step()
            lr_scheduler.step()
            steps += 1

            log_metrics(
                metric_meters,
                target_model_names=target_model_names,
                device=device,
                batch_size=images_batch.size(0),
                mode="train",
                upload_wandb=True,
                main_loss=main_loss,
                **losses,
            )

            # Freeze the translator once, at the configured fraction of the total training steps.
            if cfg.training.freeze_translator and steps == int(
                cfg.training.freeze_translator_start_steps_ratio * total_train_steps
            ):
                rvfm.module.freeze_translator()

            if steps % cfg.logging.save_ckpt_interval == 0 and device == 0:
                _save_checkpoint(rvfm, cfg, steps)

        dist.barrier()

        rvfm.eval()
        eval_iter = get_frame_iterator(eval_loaders)
        eval_tqdm = tqdm(range(eval_epoch_steps), ncols=80) if device == 0 else range(eval_epoch_steps)
        with torch.no_grad():
            for _ in eval_tqdm:
                # Unlike the training loop, there is no restart-on-StopIteration fallback here.
                batch = next(eval_iter)
                # Move images to this process' device explicitly, as the training loop does.
                images_batch = batch["image"].to(device, non_blocking=True)
                target_features_batch = _gather_target_features(batch, target_model_names, device)

                pred = rvfm(images_batch)
                losses = rvfm.module.get_loss(pred, target_features_batch)
                main_loss = _select_main_loss(losses, cfg.training.main_loss)

                # Accumulate only; the upload happens once after the eval loop.
                log_metrics(
                    metric_meters,
                    target_model_names=target_model_names,
                    device=device,
                    batch_size=images_batch.size(0),
                    mode="eval",
                    upload_wandb=False,
                    main_loss=main_loss,
                    **losses,
                )

        log_metrics(
            metric_meters,
            mode="eval",
            upload_wandb=True,
            only_upload=True,
            target_model_names=target_model_names,
            device=device,
        )

    # NOTE(review): the nesting of this final save/barrier was reconstructed; it is
    # assumed to run once after all epochs rather than once per epoch -- confirm.
    if device == 0:
        _save_checkpoint(rvfm, cfg, steps)
    dist.barrier()
def ddp_setup() -> None:
    """Create the default torch.distributed process group using the "nccl" backend."""
    dist.init_process_group(backend="nccl")
def ddp_cleanup() -> None:
    """Clean up stuff for DDP.

    Destroys the default process group if one exists. Calling this when no group
    was initialized (or calling it twice) is a no-op, so it is safe to use in
    error-handling / `finally` paths.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def _resolve_target_models(cfg: DictConfig) -> tuple[list[str], list[str], dict[str, Any]]:
    """Work out the teacher models to distill and the feature size of every target.

    Returns:
        A tuple `(base_names, all_names, feature_sizes)`: the teacher names without
        any "_cls" targets, the teacher names including the "_cls" targets (when
        `cfg.training.distill_cls` is set), and the feature size of each entry in
        `all_names`.
    """
    configured = cfg.training.target_models.target_model_names
    # An empty list in the config means "use every known teacher model".
    base_names = configured if len(configured) > 0 else list(MODEL_FEATURE_SIZES.keys())
    base_names = [t for t in base_names if "llava" not in t]  # llava is currently not supported
    feature_sizes = {t: get_model_feature_size(t, keep_spatial=True) for t in base_names}

    all_names = base_names[:]
    # `distill_cls` is optional in the config; a missing key is treated as disabled.
    if getattr(cfg.training, "distill_cls", False):
        for t in base_names:
            if "google/vit" in t or "facebook/dino" in t or "openai/clip" in t:
                # The "_cls" target keeps only the leading element of the teacher's feature size.
                feature_sizes[t + "_cls"] = get_model_feature_size(t, keep_spatial=True)[:1]
                all_names.append(t + "_cls")
    return base_names, all_names, feature_sizes


def _build_and_train(cfg: DictConfig) -> None:
    """Build model, datasets, optimizer and scheduler for this process, then run `train`."""
    rank, world_size = dist.get_rank(), dist.get_world_size()
    target_model_names_wocls, target_model_names, target_feature_sizes = _resolve_target_models(cfg)

    rvfm = RobotVisionFM(
        translator=cfg.model.translator.type,
        translator_kwargs=cfg.model.translator.kwargs,
        target_feature_sizes=target_feature_sizes,
        target_loss_weights=cfg.training.target_models.target_model_weights,
        **cfg.model.backbone,
    )
    # NOTE(review): the global rank doubles as the CUDA device index here, which
    # presumably assumes a single node -- confirm before multi-node use.
    rvfm.to(rank)
    rvfm_ddp = DDP(rvfm, device_ids=[rank], find_unused_parameters=False)

    image_transform: Compose | Callable = totensor  # currently just ndarray to tensor

    # Datasets are requested with the names that exclude "_cls" targets.
    train_dataset, train_dataset_expected_length = get_image_video_dataset(
        dataset_root=cfg.dataset.dataset_root,
        dataset_mix=cfg.dataset.dataset_mix,
        split="train",
        dataset_ratio=cfg.dataset.dataset_ratio,
        feature_models=target_model_names_wocls,
        image_transform=image_transform,
        feature_norm=cfg.dataset.feature_norm,
        rank=rank,
        world_size=world_size,
        shuffle=cfg.dataset.shuffle,
        seed=cfg.seed,
        shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
        num_workers=cfg.training.num_workers,
    )
    eval_dataset, eval_dataset_expected_length = get_image_video_dataset(
        dataset_root=cfg.dataset.dataset_root,
        dataset_mix=cfg.dataset.dataset_mix,
        split="val",
        dataset_ratio=0.1,
        feature_models=target_model_names_wocls,
        image_transform=image_transform,
        feature_norm=cfg.dataset.feature_norm,
        rank=rank,
        world_size=world_size,
        shuffle=False,
        seed=cfg.seed,
        shuffle_buffer_size=cfg.dataset.shuffle_buffer_size,
        num_workers=cfg.training.num_workers,
    )

    # Steps per epoch for one process: expected dataset length split over batch size and world size.
    train_epoch_steps = math.ceil(train_dataset_expected_length / cfg.training.batch_size / world_size)
    eval_epoch_steps = math.ceil(eval_dataset_expected_length / cfg.training.batch_size / world_size)
    total_train_steps = train_epoch_steps * cfg.training.epochs
    warmup_steps = int(cfg.training.warm_up_steps_ratio * total_train_steps)

    rvfm_param_groups = param_groups_weight_decay(rvfm_ddp, cfg.training.weight_decay)
    # Scale the base learning rate by the ratio of the effective batch size to the reference one.
    lr = cfg.training.base_lr * (
        (cfg.training.batch_size * world_size) / (cfg.training.base_batch_size * cfg.training.base_world_size)
    )
    optimizer = hydra.utils.instantiate(cfg.training.optimizer, rvfm_param_groups, lr=lr)
    lr_scheduler = hydra.utils.instantiate(
        cfg.training.lr_scheduler,
        optimizer=optimizer,
        warm_up_steps=warmup_steps,
        cos_lrs_T_0=int(total_train_steps * (1 - cfg.training.warm_up_steps_ratio)),
    )

    if rank == 0:
        print(OmegaConf.to_yaml(cfg))
        wandb.init(project=cfg.logging.project, name=cfg.logging.run_identifier_prefix, config=OmegaConf.to_object(cfg))

    train(
        rvfm_ddp,
        target_model_names,
        optimizer,
        lr_scheduler,
        train_dataset,
        eval_dataset,
        cfg=cfg,
        device=rank,
        train_epoch_steps=train_epoch_steps,
        eval_epoch_steps=eval_epoch_steps,
        total_train_steps=total_train_steps,
        warmup_steps=warmup_steps,
    )


def ddp_main(cfg: DictConfig) -> None:
    """Entry point of DDP.

    Initializes the process group, runs model/dataset construction and training
    for this process, and always tears the process group down afterwards, even
    when construction or training raises.

    Args:
        cfg (DictConfig): settings for training.
    """
    ddp_setup()
    try:
        _build_and_train(cfg)
    finally:
        ddp_cleanup()
# No config path/name is baked in here: pass them with Hydra's `--config-path` and
# `--config-name` command-line flags.
# TODO(review): pin the project's default config_path/config_name in this decorator.
@hydra.main(version_base=None)
def main(cfg: DictConfig) -> None:
    """Main. Dealing with arguments and call DDP.

    Builds the run identifier from the dataset ratio, backbone, translator type,
    the pretrained flag and optional notes, stores it in
    `cfg.logging.run_identifier_prefix`, seeds everything with `cfg.seed`, and
    hands over to `ddp_main`.

    Args:
        cfg (DictConfig): training config. Hydra supplies it when `main()` is
            called without arguments, as the script entry point does.
    """
    backbone_fn = f"_{cfg.model.backbone.backbone.replace('/', '-')}"
    notes_fn = f"_{cfg.logging.notes}" if cfg.logging.notes else ""
    translator_fn = f"_{cfg.model.translator.type}"
    pretrained_fn = "_pretrained" if cfg.model.backbone.pretrained else ""
    dp_fn = f"_dp{cfg.dataset.dataset_ratio:.3f}"
    cfg.logging.run_identifier_prefix = f"rvfm{dp_fn}{backbone_fn}{translator_fn}{pretrained_fn}{notes_fn}"
    seed_everything(cfg.seed)
    ddp_main(cfg)


if __name__ == "__main__":
    main()