Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
230e1e3
1
Parent(s):
1b3e11b
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +6 -5
- app.py +131 -4
- imaginaire/.DS_Store +0 -0
- imaginaire/__init__.py +14 -0
- imaginaire/__pycache__/__init__.cpython-310.pyc +0 -0
- imaginaire/__pycache__/__init__.cpython-39.pyc +0 -0
- imaginaire/callbacks/__init__.py +14 -0
- imaginaire/callbacks/every_n.py +84 -0
- imaginaire/callbacks/manual_gc.py +49 -0
- imaginaire/config.py +410 -0
- imaginaire/lazy_config/__init__.py +73 -0
- imaginaire/lazy_config/__pycache__/__init__.cpython-310.pyc +0 -0
- imaginaire/lazy_config/__pycache__/file_io.cpython-310.pyc +0 -0
- imaginaire/lazy_config/__pycache__/instantiate.cpython-310.pyc +0 -0
- imaginaire/lazy_config/__pycache__/lazy.cpython-310.pyc +0 -0
- imaginaire/lazy_config/__pycache__/omegaconf_patch.cpython-310.pyc +0 -0
- imaginaire/lazy_config/__pycache__/registry.cpython-310.pyc +0 -0
- imaginaire/lazy_config/file_io.py +24 -0
- imaginaire/lazy_config/instantiate.py +119 -0
- imaginaire/lazy_config/lazy.py +442 -0
- imaginaire/lazy_config/omegaconf_patch.py +65 -0
- imaginaire/lazy_config/registry.py +74 -0
- imaginaire/model.py +137 -0
- imaginaire/trainer.py +322 -0
- imaginaire/utils/.DS_Store +0 -0
- imaginaire/utils/__init__.py +14 -0
- imaginaire/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- imaginaire/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- imaginaire/utils/__pycache__/device.cpython-310.pyc +0 -0
- imaginaire/utils/__pycache__/distributed.cpython-310.pyc +0 -0
- imaginaire/utils/__pycache__/io.cpython-310.pyc +0 -0
- imaginaire/utils/__pycache__/io.cpython-39.pyc +0 -0
- imaginaire/utils/__pycache__/log.cpython-310.pyc +0 -0
- imaginaire/utils/__pycache__/log.cpython-39.pyc +0 -0
- imaginaire/utils/__pycache__/misc.cpython-310.pyc +0 -0
- imaginaire/utils/callback.py +518 -0
- imaginaire/utils/checkpointer.py +282 -0
- imaginaire/utils/config_helper.py +201 -0
- imaginaire/utils/device.py +39 -0
- imaginaire/utils/distributed.py +444 -0
- imaginaire/utils/easy_io/__init__.py +14 -0
- imaginaire/utils/easy_io/__pycache__/__init__.cpython-310.pyc +0 -0
- imaginaire/utils/easy_io/__pycache__/easy_io.cpython-310.pyc +0 -0
- imaginaire/utils/easy_io/__pycache__/file_client.cpython-310.pyc +0 -0
- imaginaire/utils/easy_io/backends/__init__.py +28 -0
- imaginaire/utils/easy_io/backends/__pycache__/__init__.cpython-310.pyc +0 -0
- imaginaire/utils/easy_io/backends/__pycache__/base_backend.cpython-310.pyc +0 -0
- imaginaire/utils/easy_io/backends/__pycache__/http_backend.cpython-310.pyc +0 -0
- imaginaire/utils/easy_io/backends/__pycache__/local_backend.cpython-310.pyc +0 -0
- imaginaire/utils/easy_io/backends/__pycache__/registry_utils.cpython-310.pyc +0 -0
README.md
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
short_description: rCM model for Wan2.1
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
title: rCM-Wan
|
| 3 |
+
emoji: 🦀
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
short_description: rCM model for Wan2.1
|
| 12 |
---
|
| 13 |
|
| 14 |
+
This demo uses the unofficial rCM models for Wan from [worstcoder/rcm-Wan](https://huggingface.co/worstcoder/rcm-Wan).
|
app.py
CHANGED
|
@@ -1,7 +1,134 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
+
import time
|
| 4 |
+
import os
|
| 5 |
+
import requests
|
| 6 |
+
from wan2pt1_t2v_rcm_infer import inference, prepare_models
|
| 7 |
+
from huggingface_hub import hf_hub_download
|
| 8 |
+
import random
|
| 9 |
+
from types import SimpleNamespace
|
| 10 |
+
import gc
|
| 11 |
+
import torch
|
| 12 |
|
| 13 |
+
import flash_attn
|
| 14 |
+
print("flash_attn version: ", flash_attn.__version__)
|
| 15 |
|
| 16 |
+
dit_path_1p3B = hf_hub_download(
|
| 17 |
+
repo_id="worstcoder/rcm-Wan",
|
| 18 |
+
filename="rCM_Wan2.1_T2V_1.3B_480p.pt",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
dit_path_14B = hf_hub_download(
|
| 22 |
+
repo_id="worstcoder/rcm-Wan",
|
| 23 |
+
filename="rCM_Wan2.1_T2V_14B_480p.pt",
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
vae_path = hf_hub_download(
|
| 27 |
+
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
|
| 28 |
+
filename="Wan2.1_VAE.pth"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
text_encoder_path = hf_hub_download(
|
| 32 |
+
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
|
| 33 |
+
filename="models_t5_umt5-xxl-enc-bf16.pth"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
net_1p3B, net_14B, tokenizer, t5_encoder = prepare_models(dit_path_1p3B, dit_path_14B, vae_path, text_encoder_path)
|
| 37 |
+
print("Loaded models")
|
| 38 |
+
gc.collect()
|
| 39 |
+
|
| 40 |
+
def random_seed():
|
| 41 |
+
return random.randint(0, 2**32 - 1)
|
| 42 |
+
|
| 43 |
+
@spaces.GPU(duration=120)
|
| 44 |
+
def generate_videos(prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed):
|
| 45 |
+
if seed is None:
|
| 46 |
+
seed = random.randint(0, 2**32 - 1)
|
| 47 |
+
|
| 48 |
+
args = SimpleNamespace(
|
| 49 |
+
prompt=prompt,
|
| 50 |
+
model_size=model_size,
|
| 51 |
+
num_steps=num_steps,
|
| 52 |
+
num_samples=num_samples,
|
| 53 |
+
sigma_max=sigma_max,
|
| 54 |
+
num_frames=77,
|
| 55 |
+
resolution="480p",
|
| 56 |
+
aspect_ratio=aspect_ratio,
|
| 57 |
+
seed=seed,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
video_list = inference(args, net_1p3B, net_14B, tokenizer, t5_encoder)
|
| 62 |
+
|
| 63 |
+
if aspect_ratio == "16:9":
|
| 64 |
+
return video_list, None
|
| 65 |
+
else:
|
| 66 |
+
return None, video_list
|
| 67 |
+
|
| 68 |
+
def update_num_samples(model_choice):
|
| 69 |
+
if model_choice == "rCM-Wan2.1-T2V-1.3B-480p":
|
| 70 |
+
options = [1, 2, 3, 4]
|
| 71 |
+
else:
|
| 72 |
+
options = [1, 2, 3]
|
| 73 |
+
return gr.Dropdown(choices=options, value=options[0], label="num_samples")
|
| 74 |
+
|
| 75 |
+
with gr.Blocks() as demo:
|
| 76 |
+
gr.Markdown("## rCM model for Wan")
|
| 77 |
+
|
| 78 |
+
examples = [
|
| 79 |
+
["A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."],
|
| 80 |
+
["A close-up shot captures a steaming hot pot brimming with vegetables and dumplings, set on a rustic wooden table. The camera focuses on the bubbling broth as a woman, dressed in a light, patterned blouse, reaches in with chopsticks to lift a tender leaf of cabbage from the simmering mixture. Steam rises around her as she leans back slightly, her warm smile reflecting satisfaction and joy. Her movements are smooth and deliberate, showcasing her comfort and familiarity with the dining process. The background includes a small bowl of dipping sauce and a clay pot, adding to the cozy, communal dining atmosphere."],
|
| 81 |
+
["A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond."]
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
with gr.Row():
|
| 85 |
+
with gr.Column(scale=1):
|
| 86 |
+
with gr.Row():
|
| 87 |
+
prompt = gr.Textbox(label="Text prompt", placeholder="Text prompt for videos")
|
| 88 |
+
model_size = gr.Radio(
|
| 89 |
+
["rCM-Wan2.1-T2V-1.3B-480p", "rCM-Wan2.1-T2V-14B-480p"],
|
| 90 |
+
value="rCM-Wan2.1-T2V-1.3B-480p",
|
| 91 |
+
label="Model"
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
with gr.Row():
|
| 95 |
+
num_samples = gr.Dropdown([1, 2, 3, 4], value=1, label="num_samples")
|
| 96 |
+
aspect_ratio = gr.Radio(["16:9", "9:16"], value="16:9", label="aspect_ratio")
|
| 97 |
+
sigma_max = gr.Dropdown([40, 80, 120, 200, 400, 800, 1600], value=80, label="sigma_max")
|
| 98 |
+
|
| 99 |
+
with gr.Row():
|
| 100 |
+
num_steps = gr.Slider(1, 4, value=4, step=1, label="num_steps")
|
| 101 |
+
seed = gr.Number(label="seed", value=random_seed(), interactive=True)
|
| 102 |
+
|
| 103 |
+
with gr.Row():
|
| 104 |
+
regenerate_btn = gr.Button("New Seed")
|
| 105 |
+
run_btn = gr.Button("Generate Videos")
|
| 106 |
+
|
| 107 |
+
with gr.Row():
|
| 108 |
+
gr.Examples(
|
| 109 |
+
examples,
|
| 110 |
+
inputs=[prompt],
|
| 111 |
+
label="Example prompts"
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
with gr.Column(scale=1):
|
| 115 |
+
video_16_9 = gr.Video(label="Videos 16:9", width=832)
|
| 116 |
+
video_9_16 = gr.Video(label="Videos 9:16", width=480, visible=False)
|
| 117 |
+
|
| 118 |
+
def show_video(aspect):
|
| 119 |
+
if aspect == "16:9":
|
| 120 |
+
return gr.update(visible=True), gr.update(visible=False, value=None)
|
| 121 |
+
else:
|
| 122 |
+
return gr.update(visible=False, value=None), gr.update(visible=True)
|
| 123 |
+
|
| 124 |
+
model_size.change(fn=update_num_samples, inputs=model_size, outputs=num_samples)
|
| 125 |
+
aspect_ratio.change(show_video, inputs=aspect_ratio, outputs=[video_16_9, video_9_16])
|
| 126 |
+
regenerate_btn.click(fn=random_seed, outputs=seed)
|
| 127 |
+
|
| 128 |
+
run_btn.click(
|
| 129 |
+
fn=generate_videos,
|
| 130 |
+
inputs=[prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed],
|
| 131 |
+
outputs=[video_16_9, video_9_16],
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
demo.launch()
|
imaginaire/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
imaginaire/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
imaginaire/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (130 Bytes). View file
|
|
|
imaginaire/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
imaginaire/callbacks/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
imaginaire/callbacks/every_n.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from abc import abstractmethod
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from imaginaire.model import ImaginaireModel
|
| 21 |
+
from imaginaire.trainer import ImaginaireTrainer
|
| 22 |
+
from imaginaire.utils import distributed, log
|
| 23 |
+
from imaginaire.utils.callback import Callback
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class EveryN(Callback):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
every_n: int | None = None,
|
| 30 |
+
step_size: int = 1,
|
| 31 |
+
barrier_after_run: bool = True,
|
| 32 |
+
run_at_start: bool = False,
|
| 33 |
+
) -> None:
|
| 34 |
+
"""Constructor for `EveryN`.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
every_n (int): Frequency with which callback is run during training.
|
| 38 |
+
step_size (int): Size of iteration step count. Default 1.
|
| 39 |
+
barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts.
|
| 40 |
+
run_at_start (bool): Whether to run at the beginning of training. Default False.
|
| 41 |
+
"""
|
| 42 |
+
self.every_n = every_n
|
| 43 |
+
if self.every_n == 0:
|
| 44 |
+
log.warning(
|
| 45 |
+
f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.step_size = step_size
|
| 49 |
+
self.barrier_after_run = barrier_after_run
|
| 50 |
+
self.run_at_start = run_at_start
|
| 51 |
+
|
| 52 |
+
def on_training_step_end(
|
| 53 |
+
self,
|
| 54 |
+
model: ImaginaireModel,
|
| 55 |
+
data_batch: dict[str, torch.Tensor],
|
| 56 |
+
output_batch: dict[str, torch.Tensor],
|
| 57 |
+
loss: torch.Tensor,
|
| 58 |
+
iteration: int = 0,
|
| 59 |
+
) -> None:
|
| 60 |
+
# every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training
|
| 61 |
+
if self.every_n != 0:
|
| 62 |
+
trainer = self.trainer
|
| 63 |
+
global_step = iteration // self.step_size
|
| 64 |
+
should_run = (iteration == 1 and self.run_at_start) or (
|
| 65 |
+
global_step % self.every_n == 0
|
| 66 |
+
) # (self.every_n - 1)
|
| 67 |
+
if should_run:
|
| 68 |
+
log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}")
|
| 69 |
+
self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration)
|
| 70 |
+
log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}")
|
| 71 |
+
# add necessary barrier to avoid timeout
|
| 72 |
+
if self.barrier_after_run:
|
| 73 |
+
distributed.barrier()
|
| 74 |
+
|
| 75 |
+
@abstractmethod
|
| 76 |
+
def every_n_impl(
|
| 77 |
+
self,
|
| 78 |
+
trainer: ImaginaireTrainer,
|
| 79 |
+
model: ImaginaireModel,
|
| 80 |
+
data_batch: dict[str, torch.Tensor],
|
| 81 |
+
output_batch: dict[str, torch.Tensor],
|
| 82 |
+
loss: torch.Tensor,
|
| 83 |
+
iteration: int,
|
| 84 |
+
) -> None: ...
|
imaginaire/callbacks/manual_gc.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import gc
|
| 17 |
+
|
| 18 |
+
from imaginaire.callbacks.every_n import EveryN
|
| 19 |
+
from imaginaire.utils import log
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ManualGarbageCollection(EveryN):
|
| 23 |
+
"""
|
| 24 |
+
Disable auto gc and manually trigger garbage collection every N iterations
|
| 25 |
+
It is super useful for large scale training to reduce gpu sync time!
|
| 26 |
+
Can reach 50% speedup.
|
| 27 |
+
|
| 28 |
+
It is important to note that this callback only disables gc in main process and have auto gc enabled in subprocesses.
|
| 29 |
+
|
| 30 |
+
We start disable gc after warm_up iterations to avoid disabling gc in subprocesses, such as dataloader, which can cause OOM
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, *args, warm_up: int = 5, **kwargs):
|
| 34 |
+
kwargs["barrier_after_run"] = False
|
| 35 |
+
super().__init__(*args, **kwargs)
|
| 36 |
+
|
| 37 |
+
self.counter = 0
|
| 38 |
+
self.warm = warm_up
|
| 39 |
+
|
| 40 |
+
def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
|
| 41 |
+
del trainer, model, data_batch, output_batch, loss
|
| 42 |
+
self.counter += 1
|
| 43 |
+
if self.counter < self.warm:
|
| 44 |
+
return
|
| 45 |
+
if self.counter == self.warm:
|
| 46 |
+
gc.disable()
|
| 47 |
+
log.critical("Garbage collection disabled")
|
| 48 |
+
|
| 49 |
+
gc.collect(1)
|
imaginaire/config.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Training config system for Imaginare4"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
from typing import Any, TypeVar
|
| 22 |
+
|
| 23 |
+
import attrs
|
| 24 |
+
import torch
|
| 25 |
+
import torch.utils.data
|
| 26 |
+
|
| 27 |
+
from imaginaire.model import ImaginaireModel
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from megatron.core import ModelParallelConfig
|
| 31 |
+
|
| 32 |
+
USE_MEGATRON = True
|
| 33 |
+
except ImportError:
|
| 34 |
+
USE_MEGATRON = False
|
| 35 |
+
print("Megatron-core is not installed.")
|
| 36 |
+
|
| 37 |
+
import builtins
|
| 38 |
+
|
| 39 |
+
from imaginaire.lazy_config import LazyCall as L
|
| 40 |
+
from imaginaire.lazy_config import LazyDict
|
| 41 |
+
from imaginaire.utils import callback, distributed
|
| 42 |
+
from imaginaire.utils.misc import Color
|
| 43 |
+
|
| 44 |
+
T = TypeVar("T")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _is_attrs_instance(obj: object) -> bool:
|
| 48 |
+
"""
|
| 49 |
+
Helper function to check if an object is an instance of an attrs-defined class.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
obj: The object to check.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
bool: True if the object is an instance of an attrs-defined class, False otherwise.
|
| 56 |
+
"""
|
| 57 |
+
return hasattr(obj, "__attrs_attrs__")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def make_freezable(cls: T) -> T:
|
| 61 |
+
"""
|
| 62 |
+
A decorator that adds the capability to freeze instances of an attrs-defined class.
|
| 63 |
+
|
| 64 |
+
NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
|
| 65 |
+
to hack on a "_is_frozen" attribute.
|
| 66 |
+
|
| 67 |
+
This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
|
| 68 |
+
Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
|
| 69 |
+
any attrs-defined objects that are attributes of the class.
|
| 70 |
+
|
| 71 |
+
Usage:
|
| 72 |
+
@make_freezable
|
| 73 |
+
@attrs.define(slots=False)
|
| 74 |
+
class MyClass:
|
| 75 |
+
attribute1: int
|
| 76 |
+
attribute2: str
|
| 77 |
+
|
| 78 |
+
obj = MyClass(1, 'a')
|
| 79 |
+
obj.freeze() # Freeze the instance
|
| 80 |
+
obj.attribute1 = 2 # Raises AttributeError
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
cls: The class to be decorated.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
The decorated class with added freezing capability.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
if not hasattr(cls, "__dict__"):
|
| 90 |
+
raise TypeError(
|
| 91 |
+
"make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
|
| 92 |
+
"class was defined with `@attrs.define(slots=False)`"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
original_setattr = cls.__setattr__
|
| 96 |
+
|
| 97 |
+
def setattr_override(self, key, value) -> None:
|
| 98 |
+
"""
|
| 99 |
+
Override __setattr__ to allow modifications during initialization
|
| 100 |
+
and prevent modifications once the instance is frozen.
|
| 101 |
+
"""
|
| 102 |
+
if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
|
| 103 |
+
raise AttributeError("Cannot modify frozen instance")
|
| 104 |
+
original_setattr(self, key, value) # type: ignore
|
| 105 |
+
|
| 106 |
+
cls.__setattr__ = setattr_override # type: ignore
|
| 107 |
+
|
| 108 |
+
def freeze(self: object) -> None:
|
| 109 |
+
"""
|
| 110 |
+
Freeze the instance and all its attrs-defined attributes.
|
| 111 |
+
"""
|
| 112 |
+
for _, value in attrs.asdict(self, recurse=False).items():
|
| 113 |
+
if _is_attrs_instance(value) and hasattr(value, "freeze"):
|
| 114 |
+
value.freeze()
|
| 115 |
+
self._is_frozen = True # type: ignore
|
| 116 |
+
|
| 117 |
+
cls.freeze = freeze # type: ignore
|
| 118 |
+
|
| 119 |
+
return cls
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
|
| 123 |
+
"""
|
| 124 |
+
Recursively pretty prints attrs objects with color.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
assert attrs.has(obj.__class__)
|
| 128 |
+
|
| 129 |
+
lines: list[str] = []
|
| 130 |
+
for attribute in attrs.fields(obj.__class__):
|
| 131 |
+
value = getattr(obj, attribute.name)
|
| 132 |
+
if attrs.has(value.__class__):
|
| 133 |
+
if use_color:
|
| 134 |
+
lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
|
| 135 |
+
else:
|
| 136 |
+
lines.append(" " * indent + "* " + attribute.name + ":")
|
| 137 |
+
lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
|
| 138 |
+
else:
|
| 139 |
+
if use_color:
|
| 140 |
+
lines.append(
|
| 141 |
+
" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
|
| 142 |
+
)
|
| 143 |
+
else:
|
| 144 |
+
lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
|
| 145 |
+
return "\n".join(lines)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def pretty_print_overrides(overrides: list[str] | None = None, use_color: bool = False) -> str:
|
| 149 |
+
"""
|
| 150 |
+
Pretty prints overrides.
|
| 151 |
+
"""
|
| 152 |
+
|
| 153 |
+
lines: list[str] = []
|
| 154 |
+
lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
|
| 155 |
+
for override in overrides:
|
| 156 |
+
if override == "--":
|
| 157 |
+
continue
|
| 158 |
+
if override.startswith("~"):
|
| 159 |
+
attribute_name = override[1:]
|
| 160 |
+
attribute_value = None
|
| 161 |
+
else:
|
| 162 |
+
attribute_name, attribute_value = override.split("=")
|
| 163 |
+
if use_color:
|
| 164 |
+
lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
|
| 165 |
+
else:
|
| 166 |
+
lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))
|
| 167 |
+
|
| 168 |
+
return "\n".join(lines)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@make_freezable
|
| 172 |
+
@attrs.define(slots=False) # slots=False is required for make_freezable. See the make_freezable notes for more info.
|
| 173 |
+
class ObjectStoreConfig:
|
| 174 |
+
# Whether the file I/O is from object store instead of local disk.
|
| 175 |
+
enabled: bool = False
|
| 176 |
+
# Path to the object store credentials file.
|
| 177 |
+
credentials: str = ""
|
| 178 |
+
# Object store bucket to read from / write to the objects.
|
| 179 |
+
bucket: str = ""
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
@make_freezable
|
| 183 |
+
@attrs.define(slots=False)
|
| 184 |
+
class JobConfig:
|
| 185 |
+
# Project name.
|
| 186 |
+
project: str = ""
|
| 187 |
+
# Experiment name.
|
| 188 |
+
group: str = ""
|
| 189 |
+
# Run/job name.
|
| 190 |
+
name: str = ""
|
| 191 |
+
|
| 192 |
+
@property
|
| 193 |
+
def path(self) -> str:
|
| 194 |
+
return f"{self.project}/{self.group}/{self.name}"
|
| 195 |
+
|
| 196 |
+
@property
|
| 197 |
+
def path_local(self) -> str:
|
| 198 |
+
local_root = os.environ.get("IMAGINAIRE_OUTPUT_ROOT", "checkpoints")
|
| 199 |
+
return f"{local_root}/{self.path}"
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@make_freezable
|
| 203 |
+
@attrs.define(slots=False)
|
| 204 |
+
class EMAConfig:
|
| 205 |
+
# Enable tracking a set of exponential moving average (EMA) weights.
|
| 206 |
+
enabled: bool = False
|
| 207 |
+
# EMA decay rate.
|
| 208 |
+
beta: float = 0.9999
|
| 209 |
+
# Enable removing "_orig_mod-" from buffer names that is added by torch.compile
|
| 210 |
+
torch_compile_buffer_renaming: bool = False
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@make_freezable
|
| 214 |
+
@attrs.define(slots=False)
|
| 215 |
+
class PowerEMAConfig:
|
| 216 |
+
# Enable tracking a set of exponential moving average (EMA) weights.
|
| 217 |
+
enabled: bool = False
|
| 218 |
+
# EDM2 paper EMA decay rate.
|
| 219 |
+
s: float = 0.1
|
| 220 |
+
# Enable removing "_orig_mod-" from buffer names that is added by torch.compile
|
| 221 |
+
torch_compile_buffer_renaming: bool = False
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
@make_freezable
|
| 225 |
+
@attrs.define(slots=False)
|
| 226 |
+
class DDPConfig:
|
| 227 |
+
# Traverse the computation graph to find parameters that don't receive gradients.
|
| 228 |
+
find_unused_parameters: bool = False
|
| 229 |
+
# Set to True if the computation graph does not change during the whole training loop.
|
| 230 |
+
static_graph: bool = True
|
| 231 |
+
# Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere.
|
| 232 |
+
broadcast_buffers: bool = True
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@make_freezable
|
| 236 |
+
@attrs.define(slots=False)
|
| 237 |
+
class CuDNNConfig:
|
| 238 |
+
# Set to True for better reproducibility of the results (only using deterministic cudnn functions).
|
| 239 |
+
deterministic: bool = False
|
| 240 |
+
# If set to True, cudnn will benchmark several algorithms and pick the fastest one.
|
| 241 |
+
benchmark: bool = True
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@make_freezable
|
| 245 |
+
@attrs.define(slots=False)
|
| 246 |
+
class JITConfig:
|
| 247 |
+
# Enable exporting a JIT compiled model.
|
| 248 |
+
enabled: bool = False
|
| 249 |
+
# Input tensor shape, for example input.
|
| 250 |
+
input_shape: list[int] | None = None
|
| 251 |
+
# Device to compile onto.
|
| 252 |
+
device: str = "cuda"
|
| 253 |
+
# # Data type to compile onto.
|
| 254 |
+
dtype: str = "bfloat16"
|
| 255 |
+
# Strict mode for PyTorch JIT.
|
| 256 |
+
strict: bool = True
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@make_freezable
|
| 260 |
+
@attrs.define(slots=False)
|
| 261 |
+
class CheckpointConfig:
|
| 262 |
+
# possible checkpoint class
|
| 263 |
+
type: dict | None = None
|
| 264 |
+
# for dcp, whether to use async mode
|
| 265 |
+
dcp_async_mode_enabled: bool = False
|
| 266 |
+
# Save the checkpoint every N iterations.
|
| 267 |
+
save_iter: int = 999999999
|
| 268 |
+
# Path of model weights to resume the checkpoint from.
|
| 269 |
+
load_path: str = ""
|
| 270 |
+
# Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
|
| 271 |
+
load_training_state: bool = False
|
| 272 |
+
# Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored.
|
| 273 |
+
only_load_scheduler_state: bool = False
|
| 274 |
+
# Load state_dict to the models in strict mode.
|
| 275 |
+
strict_resume: bool = True
|
| 276 |
+
# Configs for JIT compiling EMA model.
|
| 277 |
+
jit: JITConfig = attrs.field(factory=JITConfig)
|
| 278 |
+
# Print detailed information during checkpoint saving/loading.
|
| 279 |
+
verbose: bool = True
|
| 280 |
+
# keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"]
|
| 281 |
+
keys_not_to_resume: list[str] = [] # noqa: RUF008
|
| 282 |
+
# Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer).
|
| 283 |
+
broadcast_via_filesystem: bool = False
|
| 284 |
+
load_ema_to_reg: bool = False
|
| 285 |
+
# In dcp planner, skip the weight shape check, load weights into the model even weight shape is different
|
| 286 |
+
dcp_allow_mismatched_size: bool = False
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@make_freezable
|
| 290 |
+
@attrs.define(slots=False)
|
| 291 |
+
class NVTXConfig:
|
| 292 |
+
"""Config for NVTX ranges used in the main training loop.
|
| 293 |
+
|
| 294 |
+
See tutorials/nanogpt for more details on how to integrate profiling into your model."""
|
| 295 |
+
|
| 296 |
+
# Enable the NVTX ranges.
|
| 297 |
+
enabled: bool = False
|
| 298 |
+
# Synchronize everything in each NVTX range.
|
| 299 |
+
cuda_synchronize: bool = False
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@make_freezable
|
| 303 |
+
@attrs.define(slots=False)
|
| 304 |
+
class Profiling:
|
| 305 |
+
enable_profiling: bool = False
|
| 306 |
+
enable_memory_snapshot: bool = False
|
| 307 |
+
profile_freq: int = 1
|
| 308 |
+
first_n_rank: int = 8 # -1 means all ranks, n means first n ranks dumpy profiling info
|
| 309 |
+
record_shape: bool = True
|
| 310 |
+
profile_memory: bool = True
|
| 311 |
+
with_stack: bool = True
|
| 312 |
+
with_modules: bool = True
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@make_freezable
|
| 316 |
+
@attrs.define(slots=False)
|
| 317 |
+
class TrainerConfig:
|
| 318 |
+
from imaginaire.trainer import ImaginaireTrainer
|
| 319 |
+
|
| 320 |
+
type: builtins.type[ImaginaireTrainer] = ImaginaireTrainer
|
| 321 |
+
# Set the callback class.
|
| 322 |
+
# Defaults to the callbacks below.
|
| 323 |
+
callbacks: LazyDict[dict[str, callback.Callback]] = LazyDict( # noqa: RUF009
|
| 324 |
+
dict(
|
| 325 |
+
ema=L(callback.EMAModelCallback)(),
|
| 326 |
+
progress_bar=L(callback.ProgressBarCallback)(),
|
| 327 |
+
)
|
| 328 |
+
)
|
| 329 |
+
# distributed parallelism strategy
|
| 330 |
+
distributed_parallelism: str = "ddp"
|
| 331 |
+
# Distributed data parallel configs.
|
| 332 |
+
ddp: DDPConfig = attrs.field(factory=DDPConfig)
|
| 333 |
+
# cuDNN configs.
|
| 334 |
+
cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
|
| 335 |
+
# Set the random seed.
|
| 336 |
+
seed: int = 0
|
| 337 |
+
# Gradient scaler arguments (for torch.amp.GradScaler).
|
| 338 |
+
grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
|
| 339 |
+
# Maximum number of iterations to train the model.
|
| 340 |
+
max_iter: int = 999999999
|
| 341 |
+
# Maximum number of iterations to validate the model. If None, validate on the entire dataset.
|
| 342 |
+
max_val_iter: int | None = None
|
| 343 |
+
# How often we log the training stats.
|
| 344 |
+
logging_iter: int = 100
|
| 345 |
+
# Whether we want to run the validation routines.
|
| 346 |
+
run_validation: bool = True
|
| 347 |
+
# How often we evaluate on the validation set.
|
| 348 |
+
validation_iter: int = 999999999
|
| 349 |
+
# Kill the process after N seconds since the last iteration (usually means dead job).
|
| 350 |
+
timeout_period: int = 999999999
|
| 351 |
+
# Tensor memory organization format.
|
| 352 |
+
memory_format: torch.memory_format = torch.preserve_format
|
| 353 |
+
# Gradient accumulation (update step every N iteration).
|
| 354 |
+
grad_accum_iter: int = 1
|
| 355 |
+
# Profiling config
|
| 356 |
+
profiling: Profiling = attrs.field(factory=Profiling)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@make_freezable
|
| 360 |
+
@attrs.define(slots=False)
|
| 361 |
+
class Config:
|
| 362 |
+
"""Config for an imaginaire4 job.
|
| 363 |
+
|
| 364 |
+
See /README.md/Configuration System for more info.
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
# Model configs.
|
| 368 |
+
model: LazyDict[ImaginaireModel]
|
| 369 |
+
# Optimizer configs.
|
| 370 |
+
optimizer: LazyDict[torch.optim.Optimizer]
|
| 371 |
+
# Scheduler configs.
|
| 372 |
+
scheduler: LazyDict[torch.optim.lr_scheduler.LRScheduler]
|
| 373 |
+
# Training data configs.
|
| 374 |
+
dataloader_train: LazyDict[torch.utils.data.DataLoader]
|
| 375 |
+
# Validation data configs.
|
| 376 |
+
dataloader_val: LazyDict[torch.utils.data.DataLoader]
|
| 377 |
+
|
| 378 |
+
# Training job configs.
|
| 379 |
+
job: JobConfig = attrs.field(factory=JobConfig)
|
| 380 |
+
|
| 381 |
+
# Trainer configs.
|
| 382 |
+
trainer: TrainerConfig = attrs.field(factory=TrainerConfig)
|
| 383 |
+
|
| 384 |
+
if USE_MEGATRON:
|
| 385 |
+
# Megatron-Core configs
|
| 386 |
+
model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
|
| 387 |
+
else:
|
| 388 |
+
model_parallel: None = None
|
| 389 |
+
|
| 390 |
+
# Checkpointer configs.
|
| 391 |
+
checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)
|
| 392 |
+
|
| 393 |
+
def pretty_print(self, use_color: bool = False) -> str:
|
| 394 |
+
return _pretty_print_attrs_instance(self, 0, use_color)
|
| 395 |
+
|
| 396 |
+
def to_dict(self) -> dict[str, Any]:
|
| 397 |
+
return attrs.asdict(self)
|
| 398 |
+
|
| 399 |
+
def validate(self) -> None:
|
| 400 |
+
"""Validate that the config has all required fields."""
|
| 401 |
+
|
| 402 |
+
# broadcast job.name across all ranks to make sure it is consistent
|
| 403 |
+
# otherwise, unaligned job names leads unaligned path to save checkpoints
|
| 404 |
+
job_name_tensor = torch.ByteTensor(bytearray(self.job.name, "utf-8")).cuda()
|
| 405 |
+
distributed.broadcast(job_name_tensor, 0)
|
| 406 |
+
self.job.name = job_name_tensor.cpu().numpy().tobytes().decode("utf-8")
|
| 407 |
+
|
| 408 |
+
assert self.job.project != ""
|
| 409 |
+
assert self.job.group != ""
|
| 410 |
+
assert self.job.name != ""
|
imaginaire/lazy_config/__init__.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
from omegaconf import OmegaConf
|
| 19 |
+
|
| 20 |
+
from imaginaire.lazy_config.instantiate import instantiate
|
| 21 |
+
from imaginaire.lazy_config.lazy import LazyCall, LazyConfig, LazyDict
|
| 22 |
+
from imaginaire.lazy_config.omegaconf_patch import to_object
|
| 23 |
+
|
| 24 |
+
OmegaConf.to_object = to_object
|
| 25 |
+
|
| 26 |
+
PLACEHOLDER = None
|
| 27 |
+
|
| 28 |
+
__all__ = ["PLACEHOLDER", "LazyCall", "LazyConfig", "LazyDict", "instantiate"]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def fixup_module_metadata(module_name, namespace, keys=None):
|
| 35 |
+
"""
|
| 36 |
+
Fix the __qualname__ of module members to be their exported api name, so
|
| 37 |
+
when they are referenced in docs, sphinx can find them. Reference:
|
| 38 |
+
https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
|
| 39 |
+
"""
|
| 40 |
+
if not DOC_BUILDING:
|
| 41 |
+
return
|
| 42 |
+
seen_ids = set()
|
| 43 |
+
|
| 44 |
+
def fix_one(qualname, name, obj):
|
| 45 |
+
# avoid infinite recursion (relevant when using
|
| 46 |
+
# typing.Generic, for example)
|
| 47 |
+
if id(obj) in seen_ids:
|
| 48 |
+
return
|
| 49 |
+
seen_ids.add(id(obj))
|
| 50 |
+
|
| 51 |
+
mod = getattr(obj, "__module__", None)
|
| 52 |
+
if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
|
| 53 |
+
obj.__module__ = module_name
|
| 54 |
+
# Modules, unlike everything else in Python, put fully-qualitied
|
| 55 |
+
# names into their __name__ attribute. We check for "." to avoid
|
| 56 |
+
# rewriting these.
|
| 57 |
+
if hasattr(obj, "__name__") and "." not in obj.__name__:
|
| 58 |
+
obj.__name__ = name
|
| 59 |
+
obj.__qualname__ = qualname
|
| 60 |
+
if isinstance(obj, type):
|
| 61 |
+
for attr_name, attr_value in obj.__dict__.items():
|
| 62 |
+
fix_one(objname + "." + attr_name, attr_name, attr_value)
|
| 63 |
+
|
| 64 |
+
if keys is None:
|
| 65 |
+
keys = namespace.keys()
|
| 66 |
+
for objname in keys:
|
| 67 |
+
if not objname.startswith("_"):
|
| 68 |
+
obj = namespace[objname]
|
| 69 |
+
fix_one(objname, objname, obj)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
fixup_module_metadata(__name__, globals(), __all__)
|
| 73 |
+
del fixup_module_metadata
|
imaginaire/lazy_config/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.67 kB). View file
|
|
|
imaginaire/lazy_config/__pycache__/file_io.cpython-310.pyc
ADDED
|
Binary file (387 Bytes). View file
|
|
|
imaginaire/lazy_config/__pycache__/instantiate.cpython-310.pyc
ADDED
|
Binary file (3.22 kB). View file
|
|
|
imaginaire/lazy_config/__pycache__/lazy.cpython-310.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
imaginaire/lazy_config/__pycache__/omegaconf_patch.cpython-310.pyc
ADDED
|
Binary file (2.13 kB). View file
|
|
|
imaginaire/lazy_config/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
imaginaire/lazy_config/file_io.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
|
| 17 |
+
from iopath.common.file_io import PathManager as PathManagerBase
|
| 18 |
+
|
| 19 |
+
__all__ = ["PathHandler", "PathManager"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
PathManager = PathManagerBase()
|
| 23 |
+
PathManager.register_handler(HTTPURLHandler())
|
| 24 |
+
PathManager.register_handler(OneDrivePathHandler())
|
imaginaire/lazy_config/instantiate.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import collections.abc as abc
|
| 17 |
+
import dataclasses
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import attrs
|
| 21 |
+
|
| 22 |
+
from imaginaire.lazy_config.registry import _convert_target_to_string, locate
|
| 23 |
+
from imaginaire.utils import log
|
| 24 |
+
|
| 25 |
+
__all__ = ["dump_dataclass", "instantiate"]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def is_dataclass_or_attrs(target):
|
| 29 |
+
return dataclasses.is_dataclass(target) or attrs.has(target)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def dump_dataclass(obj: Any):
|
| 33 |
+
"""
|
| 34 |
+
Dump a dataclass recursively into a dict that can be later instantiated.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
obj: a dataclass object
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
dict
|
| 41 |
+
"""
|
| 42 |
+
assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), (
|
| 43 |
+
"dump_dataclass() requires an instance of a dataclass."
|
| 44 |
+
)
|
| 45 |
+
ret = {"_target_": _convert_target_to_string(type(obj))}
|
| 46 |
+
for f in dataclasses.fields(obj):
|
| 47 |
+
v = getattr(obj, f.name)
|
| 48 |
+
if dataclasses.is_dataclass(v):
|
| 49 |
+
v = dump_dataclass(v)
|
| 50 |
+
if isinstance(v, (list, tuple)):
|
| 51 |
+
v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
|
| 52 |
+
ret[f.name] = v
|
| 53 |
+
return ret
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def instantiate(cfg, *args, **kwargs):
|
| 57 |
+
"""
|
| 58 |
+
Recursively instantiate objects defined in dictionaries by
|
| 59 |
+
"_target_" and arguments.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
cfg: a dict-like object with "_target_" that defines the caller, and
|
| 63 |
+
other keys that define the arguments
|
| 64 |
+
args: Optional positional parameters pass-through.
|
| 65 |
+
kwargs: Optional named parameters pass-through.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
object instantiated by cfg
|
| 69 |
+
"""
|
| 70 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 71 |
+
|
| 72 |
+
if isinstance(cfg, ListConfig):
|
| 73 |
+
lst = [instantiate(x) for x in cfg]
|
| 74 |
+
return ListConfig(lst, flags={"allow_objects": True})
|
| 75 |
+
if isinstance(cfg, list):
|
| 76 |
+
# Specialize for list, because many classes take
|
| 77 |
+
# list[objects] as arguments, such as ResNet, DatasetMapper
|
| 78 |
+
return [instantiate(x) for x in cfg]
|
| 79 |
+
|
| 80 |
+
# If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
|
| 81 |
+
# instantiate it to the actual dataclass.
|
| 82 |
+
if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
|
| 83 |
+
return OmegaConf.to_object(cfg)
|
| 84 |
+
|
| 85 |
+
if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
|
| 86 |
+
# conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
|
| 87 |
+
# but faster: https://github.com/facebookresearch/hydra/issues/1200
|
| 88 |
+
is_recursive = getattr(cfg, "_recursive_", True)
|
| 89 |
+
if is_recursive:
|
| 90 |
+
cfg = {k: instantiate(v) for k, v in cfg.items()}
|
| 91 |
+
else:
|
| 92 |
+
cfg = {k: v for k, v in cfg.items()}
|
| 93 |
+
# pop the _recursive_ key to avoid passing it as a parameter
|
| 94 |
+
if "_recursive_" in cfg:
|
| 95 |
+
cfg.pop("_recursive_")
|
| 96 |
+
cls = cfg.pop("_target_")
|
| 97 |
+
cls = instantiate(cls)
|
| 98 |
+
|
| 99 |
+
if isinstance(cls, str):
|
| 100 |
+
cls_name = cls
|
| 101 |
+
cls = locate(cls_name)
|
| 102 |
+
assert cls is not None, cls_name
|
| 103 |
+
else:
|
| 104 |
+
try:
|
| 105 |
+
cls_name = cls.__module__ + "." + cls.__qualname__
|
| 106 |
+
except Exception:
|
| 107 |
+
# target could be anything, so the above could fail
|
| 108 |
+
cls_name = str(cls)
|
| 109 |
+
assert callable(cls), f"_target_ {cls} does not define a callable object"
|
| 110 |
+
try:
|
| 111 |
+
# override config with kwargs
|
| 112 |
+
instantiate_kwargs = {}
|
| 113 |
+
instantiate_kwargs.update(cfg)
|
| 114 |
+
instantiate_kwargs.update(kwargs)
|
| 115 |
+
return cls(*args, **instantiate_kwargs)
|
| 116 |
+
except TypeError:
|
| 117 |
+
log.error(f"Error when instantiating {cls_name}!")
|
| 118 |
+
raise
|
| 119 |
+
return cfg # return as-is if don't know what to do
|
imaginaire/lazy_config/lazy.py
ADDED
|
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import ast
|
| 17 |
+
import builtins
|
| 18 |
+
import collections.abc as abc
|
| 19 |
+
import importlib
|
| 20 |
+
import inspect
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
import pickle
|
| 24 |
+
import uuid
|
| 25 |
+
from collections import OrderedDict
|
| 26 |
+
from contextlib import contextmanager
|
| 27 |
+
from copy import deepcopy
|
| 28 |
+
from dataclasses import is_dataclass
|
| 29 |
+
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast
|
| 30 |
+
|
| 31 |
+
import attrs
|
| 32 |
+
import yaml
|
| 33 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 34 |
+
|
| 35 |
+
from imaginaire.utils import log
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
import dill as dill_pickle
|
| 39 |
+
except ImportError:
|
| 40 |
+
dill_pickle = None
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
import cloudpickle
|
| 44 |
+
except ImportError:
|
| 45 |
+
cloudpickle = None
|
| 46 |
+
|
| 47 |
+
from imaginaire.lazy_config.file_io import PathManager
|
| 48 |
+
from imaginaire.lazy_config.registry import _convert_target_to_string
|
| 49 |
+
|
| 50 |
+
__all__ = ["LazyCall", "LazyConfig", "LazyDict"]
|
| 51 |
+
|
| 52 |
+
T = TypeVar("T")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def sort_dict(d: dict[str, Any]) -> OrderedDict[str, Any]:
|
| 56 |
+
return OrderedDict(sorted(d.items(), key=lambda x: x[0]))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
|
| 60 |
+
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def sort_recursive(obj: dict[str, Any] | list[Any] | Any) -> OrderedDict[str, Any] | list[Any] | Any:
|
| 64 |
+
if isinstance(obj, dict):
|
| 65 |
+
return sort_dict({k: sort_recursive(v) for k, v in obj.items()})
|
| 66 |
+
elif isinstance(obj, list):
|
| 67 |
+
return [sort_recursive(item) for item in obj]
|
| 68 |
+
return obj
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
yaml.add_representer(OrderedDict, dict_representer)
|
| 72 |
+
|
| 73 |
+
OmegaConf.register_new_resolver("add", lambda *vals: sum(vals))
|
| 74 |
+
OmegaConf.register_new_resolver("subtract", lambda *vals: vals[0] - sum(vals[1:]))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_default_params(cls_or_func):
|
| 78 |
+
if callable(cls_or_func):
|
| 79 |
+
# inspect signature for function
|
| 80 |
+
signature = inspect.signature(cls_or_func)
|
| 81 |
+
else:
|
| 82 |
+
# inspect signature for class
|
| 83 |
+
signature = inspect.signature(cls_or_func.__init__)
|
| 84 |
+
params = signature.parameters
|
| 85 |
+
default_params = {
|
| 86 |
+
name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty
|
| 87 |
+
}
|
| 88 |
+
return default_params
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if TYPE_CHECKING:
|
| 92 |
+
# Have `LazyDict[T]` behave as `T`, so that attribute access works. Ideally, it
|
| 93 |
+
# would be a subclass of `T`, but this doesn't seem to be possible in the type
|
| 94 |
+
# system yet.
|
| 95 |
+
LazyDict: TypeAlias = T
|
| 96 |
+
else:
|
| 97 |
+
LazyDict = DictConfig
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class LazyCall(Generic[T]):
|
| 101 |
+
"""
|
| 102 |
+
Wrap a callable so that when it's called, the call will not be executed,
|
| 103 |
+
but returns a dict that describes the call.
|
| 104 |
+
|
| 105 |
+
LazyCall object has to be called with only keyword arguments. Positional
|
| 106 |
+
arguments are not yet supported.
|
| 107 |
+
|
| 108 |
+
Examples:
|
| 109 |
+
::
|
| 110 |
+
from detectron2.config import instantiate, LazyCall
|
| 111 |
+
|
| 112 |
+
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
|
| 113 |
+
layer_cfg.out_channels = 64 # can edit it afterwards
|
| 114 |
+
layer = instantiate(layer_cfg)
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
def __init__(self, target: type[T]):
|
| 118 |
+
if not (callable(target) or isinstance(target, (str, abc.Mapping))):
|
| 119 |
+
raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
|
| 120 |
+
self._target = target
|
| 121 |
+
|
| 122 |
+
def __call__(self, **kwargs) -> LazyDict[T]:
|
| 123 |
+
if is_dataclass(self._target) or attrs.has(self._target):
|
| 124 |
+
# omegaconf object cannot hold dataclass type
|
| 125 |
+
# https://github.com/omry/omegaconf/issues/784
|
| 126 |
+
target = _convert_target_to_string(self._target)
|
| 127 |
+
else:
|
| 128 |
+
target = self._target
|
| 129 |
+
kwargs["_target_"] = target
|
| 130 |
+
|
| 131 |
+
_final_params = get_default_params(self._target)
|
| 132 |
+
_final_params.update(kwargs)
|
| 133 |
+
|
| 134 |
+
return cast(LazyDict[T], DictConfig(content=_final_params, flags={"allow_objects": True}))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _visit_dict_config(cfg, func):
|
| 138 |
+
"""
|
| 139 |
+
Apply func recursively to all DictConfig in cfg.
|
| 140 |
+
"""
|
| 141 |
+
if isinstance(cfg, DictConfig):
|
| 142 |
+
func(cfg)
|
| 143 |
+
for v in cfg.values():
|
| 144 |
+
_visit_dict_config(v, func)
|
| 145 |
+
elif isinstance(cfg, ListConfig):
|
| 146 |
+
for v in cfg:
|
| 147 |
+
_visit_dict_config(v, func)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _validate_py_syntax(filename):
|
| 151 |
+
# see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
|
| 152 |
+
with PathManager.open(filename, "r") as f:
|
| 153 |
+
content = f.read()
|
| 154 |
+
try:
|
| 155 |
+
ast.parse(content)
|
| 156 |
+
except SyntaxError as e:
|
| 157 |
+
raise SyntaxError(f"Config file {filename} has syntax error!") from e
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _cast_to_config(obj):
|
| 161 |
+
# if given a dict, return DictConfig instead
|
| 162 |
+
if isinstance(obj, dict):
|
| 163 |
+
return DictConfig(obj, flags={"allow_objects": True})
|
| 164 |
+
return obj
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
|
| 168 |
+
"""
|
| 169 |
+
A namespace to put all imported config into.
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _random_package_name(filename):
|
| 174 |
+
# generate a random package name when loading config files
|
| 175 |
+
return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@contextmanager
|
| 179 |
+
def _patch_import():
|
| 180 |
+
"""
|
| 181 |
+
Enhance relative import statements in config files, so that they:
|
| 182 |
+
1. locate files purely based on relative location, regardless of packages.
|
| 183 |
+
e.g. you can import file without having __init__
|
| 184 |
+
2. do not cache modules globally; modifications of module states has no side effect
|
| 185 |
+
3. support other storage system through PathManager, so config files can be in the cloud
|
| 186 |
+
4. imported dict are turned into omegaconf.DictConfig automatically
|
| 187 |
+
"""
|
| 188 |
+
old_import = builtins.__import__
|
| 189 |
+
|
| 190 |
+
def find_relative_file(original_file, relative_import_path, level):
|
| 191 |
+
# NOTE: "from . import x" is not handled. Because then it's unclear
|
| 192 |
+
# if such import should produce `x` as a python module or DictConfig.
|
| 193 |
+
# This can be discussed further if needed.
|
| 194 |
+
relative_import_err = """
|
| 195 |
+
Relative import of directories is not allowed within config files.
|
| 196 |
+
Within a config file, relative import can only import other config files.
|
| 197 |
+
""".replace("\n", " ")
|
| 198 |
+
if not len(relative_import_path):
|
| 199 |
+
raise ImportError(relative_import_err)
|
| 200 |
+
|
| 201 |
+
cur_file = os.path.dirname(original_file)
|
| 202 |
+
for _ in range(level - 1):
|
| 203 |
+
cur_file = os.path.dirname(cur_file)
|
| 204 |
+
cur_name = relative_import_path.lstrip(".")
|
| 205 |
+
for part in cur_name.split("."):
|
| 206 |
+
cur_file = os.path.join(cur_file, part)
|
| 207 |
+
if not cur_file.endswith(".py"):
|
| 208 |
+
cur_file += ".py"
|
| 209 |
+
if not PathManager.isfile(cur_file):
|
| 210 |
+
cur_file_no_suffix = cur_file[: -len(".py")]
|
| 211 |
+
if PathManager.isdir(cur_file_no_suffix):
|
| 212 |
+
raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
|
| 213 |
+
else:
|
| 214 |
+
raise ImportError(
|
| 215 |
+
f"Cannot import name {relative_import_path} from {original_file}: {cur_file} does not exist."
|
| 216 |
+
)
|
| 217 |
+
return cur_file
|
| 218 |
+
|
| 219 |
+
def new_import(name, globals=None, locals=None, fromlist=(), level=0):
|
| 220 |
+
if (
|
| 221 |
+
# Only deal with relative imports inside config files
|
| 222 |
+
level != 0 and globals is not None and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
|
| 223 |
+
):
|
| 224 |
+
cur_file = find_relative_file(globals["__file__"], name, level)
|
| 225 |
+
_validate_py_syntax(cur_file)
|
| 226 |
+
spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
|
| 227 |
+
module = importlib.util.module_from_spec(spec)
|
| 228 |
+
module.__file__ = cur_file
|
| 229 |
+
with PathManager.open(cur_file) as f:
|
| 230 |
+
content = f.read()
|
| 231 |
+
exec(compile(content, cur_file, "exec"), module.__dict__)
|
| 232 |
+
for name in fromlist: # turn imported dict into DictConfig automatically
|
| 233 |
+
val = _cast_to_config(module.__dict__[name])
|
| 234 |
+
module.__dict__[name] = val
|
| 235 |
+
return module
|
| 236 |
+
return old_import(name, globals, locals, fromlist=fromlist, level=level)
|
| 237 |
+
|
| 238 |
+
builtins.__import__ = new_import
|
| 239 |
+
yield new_import
|
| 240 |
+
builtins.__import__ = old_import
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class LazyConfig:
|
| 244 |
+
"""
|
| 245 |
+
Provide methods to save, load, and overrides an omegaconf config object
|
| 246 |
+
which may contain definition of lazily-constructed objects.
|
| 247 |
+
"""
|
| 248 |
+
|
| 249 |
+
@staticmethod
|
| 250 |
+
def load_rel(filename: str, keys: None | str | tuple[str, ...] = None):
|
| 251 |
+
"""
|
| 252 |
+
Similar to :meth:`load()`, but load path relative to the caller's
|
| 253 |
+
source file.
|
| 254 |
+
|
| 255 |
+
This has the same functionality as a relative import, except that this method
|
| 256 |
+
accepts filename as a string, so more characters are allowed in the filename.
|
| 257 |
+
"""
|
| 258 |
+
caller_frame = inspect.stack()[1]
|
| 259 |
+
caller_fname = caller_frame[0].f_code.co_filename
|
| 260 |
+
assert caller_fname != "<string>", "load_rel Unable to find caller"
|
| 261 |
+
caller_dir = os.path.dirname(caller_fname)
|
| 262 |
+
filename = os.path.join(caller_dir, filename)
|
| 263 |
+
return LazyConfig.load(filename, keys)
|
| 264 |
+
|
| 265 |
+
@staticmethod
|
| 266 |
+
def load(filename: str, keys: None | str | tuple[str, ...] = None):
|
| 267 |
+
"""
|
| 268 |
+
Load a config file.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
filename: absolute path or relative path w.r.t. the current working directory
|
| 272 |
+
keys: keys to load and return. If not given, return all keys
|
| 273 |
+
(whose values are config objects) in a dict.
|
| 274 |
+
"""
|
| 275 |
+
has_keys = keys is not None
|
| 276 |
+
filename = filename.replace("/./", "/") # redundant
|
| 277 |
+
if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
|
| 278 |
+
raise ValueError(f"Config file {filename} has to be a python or yaml file.")
|
| 279 |
+
if filename.endswith(".py"):
|
| 280 |
+
_validate_py_syntax(filename)
|
| 281 |
+
|
| 282 |
+
with _patch_import():
|
| 283 |
+
# Record the filename
|
| 284 |
+
module_namespace = {
|
| 285 |
+
"__file__": filename,
|
| 286 |
+
"__package__": _random_package_name(filename),
|
| 287 |
+
}
|
| 288 |
+
with PathManager.open(filename) as f:
|
| 289 |
+
content = f.read()
|
| 290 |
+
# Compile first with filename to:
|
| 291 |
+
# 1. make filename appears in stacktrace
|
| 292 |
+
# 2. make load_rel able to find its parent's (possibly remote) location
|
| 293 |
+
exec(compile(content, filename, "exec"), module_namespace)
|
| 294 |
+
|
| 295 |
+
ret = module_namespace
|
| 296 |
+
else:
|
| 297 |
+
with PathManager.open(filename) as f:
|
| 298 |
+
obj = yaml.unsafe_load(f)
|
| 299 |
+
ret = OmegaConf.create(obj, flags={"allow_objects": True})
|
| 300 |
+
|
| 301 |
+
if has_keys:
|
| 302 |
+
if isinstance(keys, str):
|
| 303 |
+
return _cast_to_config(ret[keys])
|
| 304 |
+
else:
|
| 305 |
+
return tuple(_cast_to_config(ret[a]) for a in keys)
|
| 306 |
+
else:
|
| 307 |
+
if filename.endswith(".py"):
|
| 308 |
+
# when not specified, only load those that are config objects
|
| 309 |
+
ret = DictConfig(
|
| 310 |
+
{
|
| 311 |
+
name: _cast_to_config(value)
|
| 312 |
+
for name, value in ret.items()
|
| 313 |
+
if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_")
|
| 314 |
+
},
|
| 315 |
+
flags={"allow_objects": True},
|
| 316 |
+
)
|
| 317 |
+
return ret
|
| 318 |
+
|
| 319 |
+
@staticmethod
|
| 320 |
+
def save_pkl(cfg, filename: str) -> str:
|
| 321 |
+
"""
|
| 322 |
+
Saves a Config object to a file using pickle serialization. This method is typically used
|
| 323 |
+
when the configuration object contains complex objects, such as lambdas, that are not supported by
|
| 324 |
+
simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration
|
| 325 |
+
object before serialization to ensure that the original object remains unmodified.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
cfg: A Config object to be serialized and saved.
|
| 329 |
+
filename: The path and name of the file where the configuration should be saved. The function
|
| 330 |
+
assumes the file extension indicates a pickle format (e.g., .pkl).
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
str: The filename to which the configuration was saved. This can be used to verify the file location
|
| 334 |
+
or log the outcome.
|
| 335 |
+
|
| 336 |
+
Notes:
|
| 337 |
+
- The function logs a warning if the configuration is successfully saved using pickle.
|
| 338 |
+
- If saving fails, an error is logged with the exception details.
|
| 339 |
+
"""
|
| 340 |
+
try:
|
| 341 |
+
cfg = deepcopy(cfg)
|
| 342 |
+
except Exception:
|
| 343 |
+
pass
|
| 344 |
+
|
| 345 |
+
try:
|
| 346 |
+
with PathManager.open(filename, "wb") as f:
|
| 347 |
+
pickle.dump(cfg, f)
|
| 348 |
+
log.warning(f"Config is saved using pickle at {filename}.")
|
| 349 |
+
except Exception as e:
|
| 350 |
+
log.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead")
|
| 351 |
+
if dill_pickle:
|
| 352 |
+
try:
|
| 353 |
+
with PathManager.open(filename, "wb") as f:
|
| 354 |
+
pickle.dump(dill_pickle.dumps(cfg, recurse=True), f)
|
| 355 |
+
log.warning(f"Config is saved using dill at {filename}.")
|
| 356 |
+
except Exception as e:
|
| 357 |
+
log.error(f"Failed to save config to {filename}: {e}.")
|
| 358 |
+
if cloudpickle:
|
| 359 |
+
try:
|
| 360 |
+
with PathManager.open(filename, "wb") as f:
|
| 361 |
+
pickle.dump(cloudpickle.dumps(cfg), f)
|
| 362 |
+
log.warning(f"Config is saved using cloudpickle at {filename}.")
|
| 363 |
+
except Exception as e:
|
| 364 |
+
log.error(f"Failed to save config to {filename}: {e}.")
|
| 365 |
+
else:
|
| 366 |
+
log.error("cloudpickle is not available. Cannot save the config.")
|
| 367 |
+
raise e
|
| 368 |
+
|
| 369 |
+
return filename
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def save_yaml(cfg, filename: str) -> str:
|
| 373 |
+
"""
|
| 374 |
+
Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types.
|
| 378 |
+
filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'.
|
| 379 |
+
|
| 380 |
+
Returns:
|
| 381 |
+
str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome.
|
| 382 |
+
|
| 383 |
+
Notes:
|
| 384 |
+
- The function logs a warning if the configuration is successfully saved using YAML.
|
| 385 |
+
- If saving fails, an error is logged with the exception details.
|
| 386 |
+
"""
|
| 387 |
+
logger = logging.getLogger(__name__)
|
| 388 |
+
try:
|
| 389 |
+
cfg = deepcopy(cfg)
|
| 390 |
+
except Exception:
|
| 391 |
+
pass
|
| 392 |
+
|
| 393 |
+
# Define a function to check if an item is serializable to YAML
|
| 394 |
+
def is_serializable(item):
|
| 395 |
+
try:
|
| 396 |
+
OmegaConf.to_yaml(item)
|
| 397 |
+
return True
|
| 398 |
+
except Exception as e:
|
| 399 |
+
return False
|
| 400 |
+
|
| 401 |
+
# Function to convert unserializable items to strings
|
| 402 |
+
def serialize_config(config):
|
| 403 |
+
if isinstance(config, DictConfig):
|
| 404 |
+
for key, value in config.items():
|
| 405 |
+
if isinstance(value, (DictConfig, ListConfig)):
|
| 406 |
+
try:
|
| 407 |
+
if "_target_" in value:
|
| 408 |
+
default_params = get_default_params(value["_target_"])
|
| 409 |
+
for default_key, default_v in default_params.items():
|
| 410 |
+
if default_key not in value:
|
| 411 |
+
value[default_key] = default_v
|
| 412 |
+
except Exception as e:
|
| 413 |
+
log.error(f"Failed to add default argument values: {e}")
|
| 414 |
+
|
| 415 |
+
serialize_config(value)
|
| 416 |
+
else:
|
| 417 |
+
if not is_serializable(value) and value is not None:
|
| 418 |
+
config[key] = str(value)
|
| 419 |
+
elif isinstance(config, ListConfig):
|
| 420 |
+
for i, item in enumerate(config):
|
| 421 |
+
if isinstance(item, (DictConfig, ListConfig)):
|
| 422 |
+
serialize_config(item)
|
| 423 |
+
else:
|
| 424 |
+
if not is_serializable(item) and item is not None:
|
| 425 |
+
config[i] = str(item)
|
| 426 |
+
else:
|
| 427 |
+
raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
|
| 428 |
+
return config
|
| 429 |
+
|
| 430 |
+
# Convert Config object to a DictConfig object.
|
| 431 |
+
config_dict = attrs.asdict(cfg)
|
| 432 |
+
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
|
| 433 |
+
|
| 434 |
+
# Serialize the DictConfig object by converting non-serializable objects to strings.
|
| 435 |
+
config_omegaconf = serialize_config(config_omegaconf)
|
| 436 |
+
|
| 437 |
+
config_dict: dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)
|
| 438 |
+
sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict)
|
| 439 |
+
with open(filename, "w") as f:
|
| 440 |
+
yaml.dump(sorted_config, f, default_flow_style=False)
|
| 441 |
+
log.warning(f"Config is saved using omegaconf at {filename}.")
|
| 442 |
+
return filename
|
imaginaire/lazy_config/omegaconf_patch.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from omegaconf import OmegaConf
|
| 19 |
+
from omegaconf.base import DictKeyType, SCMode
|
| 20 |
+
from omegaconf.dictconfig import DictConfig # pragma: no cover
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def to_object(cfg: Any) -> dict[DictKeyType, Any] | list[Any] | None | str | Any:
|
| 24 |
+
"""
|
| 25 |
+
Converts an OmegaConf configuration object to a native Python container (dict or list), unless
|
| 26 |
+
the configuration is specifically created by LazyCall, in which case the original configuration
|
| 27 |
+
is returned directly.
|
| 28 |
+
|
| 29 |
+
This function serves as a modification of the original `to_object` method from OmegaConf,
|
| 30 |
+
preventing DictConfig objects created by LazyCall from being automatically converted to Python
|
| 31 |
+
dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
|
| 32 |
+
structure and behavior.
|
| 33 |
+
|
| 34 |
+
Differences from OmegaConf's original `to_object`:
|
| 35 |
+
- Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.
|
| 36 |
+
|
| 37 |
+
Reference:
|
| 38 |
+
- Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
cfg (Any): The OmegaConf configuration object to convert.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
|
| 45 |
+
`cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.
|
| 46 |
+
|
| 47 |
+
Examples:
|
| 48 |
+
>>> cfg = DictConfig({"key": "value", "_target_": "Model"})
|
| 49 |
+
>>> to_object(cfg)
|
| 50 |
+
DictConfig({"key": "value", "_target_": "Model"})
|
| 51 |
+
|
| 52 |
+
>>> cfg = DictConfig({"list": [1, 2, 3]})
|
| 53 |
+
>>> to_object(cfg)
|
| 54 |
+
{'list': [1, 2, 3]}
|
| 55 |
+
"""
|
| 56 |
+
if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
|
| 57 |
+
return cfg
|
| 58 |
+
|
| 59 |
+
return OmegaConf.to_container(
|
| 60 |
+
cfg=cfg,
|
| 61 |
+
resolve=True,
|
| 62 |
+
throw_on_missing=True,
|
| 63 |
+
enum_to_str=False,
|
| 64 |
+
structured_config_mode=SCMode.INSTANTIATE,
|
| 65 |
+
)
|
imaginaire/lazy_config/registry.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import pydoc
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
from fvcore.common.registry import Registry # for backward compatibility.
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
``Registry`` and `locate` provide ways to map a string (typically found
|
| 23 |
+
in config files) to callable objects.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
__all__ = ["Registry", "locate"]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _convert_target_to_string(t: Any) -> str:
|
| 30 |
+
"""
|
| 31 |
+
Inverse of ``locate()``.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
t: any object with ``__module__`` and ``__qualname__``
|
| 35 |
+
"""
|
| 36 |
+
module, qualname = t.__module__, t.__qualname__
|
| 37 |
+
|
| 38 |
+
# Compress the path to this object, e.g. ``module.submodule._impl.class``
|
| 39 |
+
# may become ``module.submodule.class``, if the later also resolves to the same
|
| 40 |
+
# object. This simplifies the string, and also is less affected by moving the
|
| 41 |
+
# class implementation.
|
| 42 |
+
module_parts = module.split(".")
|
| 43 |
+
for k in range(1, len(module_parts)):
|
| 44 |
+
prefix = ".".join(module_parts[:k])
|
| 45 |
+
candidate = f"{prefix}.{qualname}"
|
| 46 |
+
try:
|
| 47 |
+
if locate(candidate) is t:
|
| 48 |
+
return candidate
|
| 49 |
+
except ImportError:
|
| 50 |
+
pass
|
| 51 |
+
return f"{module}.{qualname}"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def locate(name: str) -> Any:
|
| 55 |
+
"""
|
| 56 |
+
Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
|
| 57 |
+
such as "module.submodule.class_name".
|
| 58 |
+
|
| 59 |
+
Raise Exception if it cannot be found.
|
| 60 |
+
"""
|
| 61 |
+
obj = pydoc.locate(name)
|
| 62 |
+
|
| 63 |
+
# Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
|
| 64 |
+
# by pydoc.locate. Try a private function from hydra.
|
| 65 |
+
if obj is None:
|
| 66 |
+
try:
|
| 67 |
+
# from hydra.utils import get_method - will print many errors
|
| 68 |
+
from hydra.utils import _locate
|
| 69 |
+
except ImportError as e:
|
| 70 |
+
raise ImportError(f"Cannot dynamically locate object {name}!") from e
|
| 71 |
+
else:
|
| 72 |
+
obj = _locate(name) # it raises if fails
|
| 73 |
+
|
| 74 |
+
return obj
|
imaginaire/model.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from imaginaire.lazy_config import LazyDict, instantiate
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ImaginaireModel(torch.nn.Module):
|
| 24 |
+
"""The base model class of Imaginaire. It is inherited from torch.nn.Module.
|
| 25 |
+
|
| 26 |
+
All models in Imaginaire should inherit ImaginaireModel. It should include the implementions for all the
|
| 27 |
+
computation graphs. All inheriting child classes should implement the following methods:
|
| 28 |
+
- training_step(): The training step of the model, including the loss computation.
|
| 29 |
+
- validation_step(): The validation step of the model, including the loss computation.
|
| 30 |
+
- forward(): The computation graph for model inference.
|
| 31 |
+
The following methods have default implementations in ImaginaireModel:
|
| 32 |
+
- init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self) -> None:
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
def init_optimizer_scheduler(
|
| 39 |
+
self,
|
| 40 |
+
optimizer_config: LazyDict[torch.optim.Optimizer],
|
| 41 |
+
scheduler_config: LazyDict[torch.optim.lr_scheduler.LRScheduler],
|
| 42 |
+
) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
|
| 43 |
+
"""Creates the optimizer and scheduler for the model.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
config_model (ModelConfig): The config object for the model.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 50 |
+
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 51 |
+
"""
|
| 52 |
+
optimizer_config.params = self.parameters()
|
| 53 |
+
optimizer = instantiate(optimizer_config)
|
| 54 |
+
scheduler_config.optimizer = optimizer
|
| 55 |
+
scheduler = instantiate(scheduler_config)
|
| 56 |
+
return optimizer, scheduler
|
| 57 |
+
|
| 58 |
+
def training_step(
|
| 59 |
+
self, data_batch: dict[str, torch.Tensor], iteration: int
|
| 60 |
+
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
|
| 61 |
+
"""The training step of the model, including the loss computation.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
|
| 65 |
+
iteration (int): Current iteration number.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
|
| 69 |
+
loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
|
| 70 |
+
"""
|
| 71 |
+
raise NotImplementedError
|
| 72 |
+
|
| 73 |
+
@torch.no_grad()
|
| 74 |
+
def validation_step(
|
| 75 |
+
self, data_batch: dict[str, torch.Tensor], iteration: int
|
| 76 |
+
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
|
| 77 |
+
"""The validation step of the model, including the loss computation.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
|
| 81 |
+
iteration (int): Current iteration number.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
|
| 85 |
+
loss (torch.Tensor): The total loss (weighted sum of various losses).
|
| 86 |
+
"""
|
| 87 |
+
raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
@torch.inference_mode()
|
| 90 |
+
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
| 91 |
+
"""The computation graph for model inference.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
*args: Whatever you decide to pass into the forward method.
|
| 95 |
+
**kwargs: Keyword arguments are also possible.
|
| 96 |
+
|
| 97 |
+
Return:
|
| 98 |
+
Your model's output.
|
| 99 |
+
"""
|
| 100 |
+
raise NotImplementedError
|
| 101 |
+
|
| 102 |
+
def on_model_init_start(self, set_barrier=False) -> None:
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
def on_model_init_end(self, set_barrier=False) -> None:
|
| 106 |
+
return
|
| 107 |
+
|
| 108 |
+
def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
|
| 109 |
+
"""The model preparation before the training is launched
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
memory_format (torch.memory_format): Memory format of the model.
|
| 113 |
+
"""
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
def on_before_zero_grad(
|
| 117 |
+
self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
|
| 118 |
+
) -> None:
|
| 119 |
+
"""Hook before zero_grad() is called.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 123 |
+
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 124 |
+
iteration (int): Current iteration number.
|
| 125 |
+
"""
|
| 126 |
+
pass
|
| 127 |
+
|
| 128 |
+
def on_after_backward(self, iteration: int = 0) -> None:
|
| 129 |
+
"""Hook after loss.backward() is called.
|
| 130 |
+
|
| 131 |
+
This method is called immediately after the backward pass, allowing for custom operations
|
| 132 |
+
or modifications to be performed on the gradients before the optimizer step.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
iteration (int): Current iteration number.
|
| 136 |
+
"""
|
| 137 |
+
pass
|
imaginaire/trainer.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import functools
|
| 17 |
+
import inspect
|
| 18 |
+
import os
|
| 19 |
+
import signal
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
import torch.utils.data
|
| 24 |
+
|
| 25 |
+
from imaginaire.utils.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from megatron.core import parallel_state
|
| 29 |
+
|
| 30 |
+
USE_MEGATRON = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
USE_MEGATRON = False
|
| 33 |
+
print("Megatron-core is not installed.")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
from imaginaire.lazy_config import LazyConfig, instantiate
|
| 37 |
+
from imaginaire.model import ImaginaireModel
|
| 38 |
+
from imaginaire.utils import callback, distributed, log, misc
|
| 39 |
+
from imaginaire.utils.checkpointer import Checkpointer
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ImaginaireTrainer:
|
| 43 |
+
"""The base trainer class of Imaginaire.
|
| 44 |
+
|
| 45 |
+
All trainers in Imaginaire should inherit ImaginaireTrainer. It contains the basic functionality for model training
|
| 46 |
+
(particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA),
|
| 47 |
+
mixed-precision training (fp16/bf16).
|
| 48 |
+
|
| 49 |
+
Attributes:
|
| 50 |
+
checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
|
| 51 |
+
training_timer (misc.Timer): Timer object to time code blocks and functions.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self, config):
|
| 55 |
+
"""Constructor of the trainer.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
config (Config): The config object for the Imaginaire codebase.
|
| 59 |
+
"""
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.config = config
|
| 62 |
+
# Set up the distributed computing environment.
|
| 63 |
+
with misc.timer("init_distributed"):
|
| 64 |
+
distributed.init()
|
| 65 |
+
# Set up parallel states.
|
| 66 |
+
if hasattr(config.model, "context_parallel_size"):
|
| 67 |
+
if config.model_parallel.context_parallel_size > 1:
|
| 68 |
+
raise ValueError(
|
| 69 |
+
"Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
|
| 70 |
+
"config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
|
| 71 |
+
)
|
| 72 |
+
else:
|
| 73 |
+
log.critical(
|
| 74 |
+
"Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
|
| 75 |
+
)
|
| 76 |
+
config.model_parallel.context_parallel_size = config.model.context_parallel_size
|
| 77 |
+
if USE_MEGATRON:
|
| 78 |
+
if (
|
| 79 |
+
"create_gloo_process_groups"
|
| 80 |
+
in inspect.signature(parallel_state.initialize_model_parallel).parameters
|
| 81 |
+
):
|
| 82 |
+
parallel_state.initialize_model_parallel(
|
| 83 |
+
pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
|
| 84 |
+
tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
|
| 85 |
+
context_parallel_size=config.model_parallel.context_parallel_size,
|
| 86 |
+
create_gloo_process_groups=False,
|
| 87 |
+
)
|
| 88 |
+
else:
|
| 89 |
+
parallel_state.initialize_model_parallel(
|
| 90 |
+
pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
|
| 91 |
+
tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
|
| 92 |
+
context_parallel_size=config.model_parallel.context_parallel_size,
|
| 93 |
+
)
|
| 94 |
+
# `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
|
| 95 |
+
# It is not part of the original `parallel_state` API, so we need to set it manually.
|
| 96 |
+
parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
|
| 97 |
+
if parallel_state.sequence_parallel:
|
| 98 |
+
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
| 99 |
+
|
| 100 |
+
# Create the local job directory, save the config file, and pipe to a local log.
|
| 101 |
+
if distributed.is_rank0():
|
| 102 |
+
os.makedirs(config.job.path_local, exist_ok=True)
|
| 103 |
+
# Save the config as .pkl for reproducibility.
|
| 104 |
+
LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
|
| 105 |
+
# Save the config as .yaml for reading or parsing experiment hyperparameters.
|
| 106 |
+
LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
|
| 107 |
+
dist.barrier()
|
| 108 |
+
log.init_loguru_file(f"{config.job.path_local}/stdout.log")
|
| 109 |
+
if distributed.is_rank0():
|
| 110 |
+
# Print important environment variables and the effective config.
|
| 111 |
+
log.info("Config:\n" + config.pretty_print(use_color=True))
|
| 112 |
+
misc.print_environ_variables(["TORCH_HOME", "IMAGINAIRE_OUTPUT_ROOT"])
|
| 113 |
+
# Set the random seed. If multi-GPU, different ranks are set with different seeds.
|
| 114 |
+
misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
|
| 115 |
+
# Initialize cuDNN.
|
| 116 |
+
torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
|
| 117 |
+
torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
|
| 118 |
+
# Floating-point precision settings.
|
| 119 |
+
torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
|
| 120 |
+
# Initialize the callback functions.
|
| 121 |
+
self.callbacks = callback.CallBackGroup(config=config, trainer=self)
|
| 122 |
+
# Initialize the model checkpointer.
|
| 123 |
+
if config.checkpoint.type is None:
|
| 124 |
+
self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
|
| 125 |
+
else:
|
| 126 |
+
self.checkpointer: Checkpointer = instantiate(
|
| 127 |
+
config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
|
| 128 |
+
)
|
| 129 |
+
# Initialize the timer for speed benchmarking.
|
| 130 |
+
self.training_timer = misc.TrainingTimer()
|
| 131 |
+
# Send a TimeoutError if a training step takes over timeout_period seconds.
|
| 132 |
+
signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore
|
| 133 |
+
|
| 134 |
+
def train(
|
| 135 |
+
self,
|
| 136 |
+
model: ImaginaireModel,
|
| 137 |
+
dataloader_train: torch.utils.data.DataLoader,
|
| 138 |
+
dataloader_val: torch.utils.data.DataLoader,
|
| 139 |
+
) -> None:
|
| 140 |
+
"""The training function.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
model (ImaginaireModel): The PyTorch model.
|
| 144 |
+
dataloader_train (torch.utils.data.DataLoader): The training data loader.
|
| 145 |
+
dataloader_val (torch.utils.data.DataLoader): The validation data loader.
|
| 146 |
+
"""
|
| 147 |
+
# Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models.
|
| 148 |
+
model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore
|
| 149 |
+
model.on_train_start(self.config.trainer.memory_format)
|
| 150 |
+
|
| 151 |
+
# Initialize the optimizer, scheduler, and grad_scaler.
|
| 152 |
+
self.callbacks.on_optimizer_init_start()
|
| 153 |
+
optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
|
| 154 |
+
grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
|
| 155 |
+
self.callbacks.on_optimizer_init_end()
|
| 156 |
+
# Load the model checkpoint and get the starting iteration number.
|
| 157 |
+
iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
|
| 158 |
+
grad_accum_iter = 0
|
| 159 |
+
log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
|
| 160 |
+
if self.config.trainer.distributed_parallelism == "ddp":
|
| 161 |
+
# Create a DDP model wrapper.
|
| 162 |
+
model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
|
| 163 |
+
elif self.config.trainer.distributed_parallelism == "fsdp":
|
| 164 |
+
model_ddp = model
|
| 165 |
+
else:
|
| 166 |
+
raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
|
| 167 |
+
log.info("Starting training...")
|
| 168 |
+
self.callbacks.on_train_start(model, iteration=iteration)
|
| 169 |
+
# Initial validation.
|
| 170 |
+
if self.config.trainer.run_validation and iteration == 0:
|
| 171 |
+
self.validate(model, dataloader_val, iteration=iteration)
|
| 172 |
+
log.info("Initial validation done.")
|
| 173 |
+
_end_training = False
|
| 174 |
+
with (
|
| 175 |
+
maybe_enable_profiling(self.config, global_step=iteration) as torch_profiler,
|
| 176 |
+
maybe_enable_memory_snapshot(self.config, global_step=iteration) as memory_profiler,
|
| 177 |
+
):
|
| 178 |
+
while True:
|
| 179 |
+
dataloader_train_iter = iter(dataloader_train)
|
| 180 |
+
while True:
|
| 181 |
+
self.callbacks.on_before_dataloading(iteration)
|
| 182 |
+
try:
|
| 183 |
+
with self.training_timer("dataloader_train"):
|
| 184 |
+
data_batch = next(dataloader_train_iter)
|
| 185 |
+
except StopIteration:
|
| 186 |
+
break
|
| 187 |
+
finally:
|
| 188 |
+
self.callbacks.on_after_dataloading(iteration)
|
| 189 |
+
# If max_iter is reached, exit the training loop.
|
| 190 |
+
if iteration >= self.config.trainer.max_iter:
|
| 191 |
+
_end_training = True
|
| 192 |
+
break
|
| 193 |
+
# Move all tensors in the data batch to GPU device.
|
| 194 |
+
data_batch = misc.to(data_batch, device="cuda")
|
| 195 |
+
# The actual training step.
|
| 196 |
+
self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
|
| 197 |
+
self.callbacks.on_training_step_batch_start(model, data_batch, iteration=iteration)
|
| 198 |
+
if not model.training:
|
| 199 |
+
model_ddp.train()
|
| 200 |
+
assert model_ddp.training, "model_ddp is not in training mode."
|
| 201 |
+
assert model.training, "model is not in training mode."
|
| 202 |
+
output_batch, loss, grad_accum_iter = self.training_step(
|
| 203 |
+
model_ddp,
|
| 204 |
+
optimizer,
|
| 205 |
+
scheduler,
|
| 206 |
+
grad_scaler,
|
| 207 |
+
data_batch,
|
| 208 |
+
iteration=iteration,
|
| 209 |
+
grad_accum_iter=grad_accum_iter,
|
| 210 |
+
)
|
| 211 |
+
self.callbacks.on_training_step_batch_end(
|
| 212 |
+
model, data_batch, output_batch, loss, iteration=iteration
|
| 213 |
+
)
|
| 214 |
+
# If the gradients are still being accumulated, continue to load the next training batch.
|
| 215 |
+
if grad_accum_iter != 0:
|
| 216 |
+
continue
|
| 217 |
+
# Do the following when an actual optimizer (update) step has been made.
|
| 218 |
+
iteration += 1
|
| 219 |
+
# Save checkpoint.
|
| 220 |
+
if iteration % self.config.checkpoint.save_iter == 0:
|
| 221 |
+
self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
|
| 222 |
+
self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
|
| 223 |
+
# Validation.
|
| 224 |
+
if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
|
| 225 |
+
self.validate(model, dataloader_val, iteration=iteration)
|
| 226 |
+
# This iteration is successful; reset the timeout signal.
|
| 227 |
+
signal.alarm(self.config.trainer.timeout_period)
|
| 228 |
+
if torch_profiler:
|
| 229 |
+
torch_profiler.step()
|
| 230 |
+
if memory_profiler:
|
| 231 |
+
memory_profiler.step()
|
| 232 |
+
if _end_training:
|
| 233 |
+
break
|
| 234 |
+
log.success("Done with training.")
|
| 235 |
+
if iteration % self.config.checkpoint.save_iter != 0:
|
| 236 |
+
self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
|
| 237 |
+
self.callbacks.on_train_end(model, iteration=iteration)
|
| 238 |
+
self.checkpointer.finalize()
|
| 239 |
+
distributed.barrier()
|
| 240 |
+
self.callbacks.on_app_end()
|
| 241 |
+
|
| 242 |
+
def training_step(
|
| 243 |
+
self,
|
| 244 |
+
model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
|
| 245 |
+
optimizer: torch.optim.Optimizer,
|
| 246 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 247 |
+
grad_scaler: torch.amp.GradScaler,
|
| 248 |
+
data: dict[str, torch.Tensor],
|
| 249 |
+
iteration: int = 0,
|
| 250 |
+
grad_accum_iter: int = 0,
|
| 251 |
+
) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
|
| 252 |
+
"""The training step.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare
|
| 256 |
+
module, depending on whether distributed training is enabled or not.
|
| 257 |
+
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 258 |
+
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 259 |
+
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
|
| 260 |
+
data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
|
| 261 |
+
iteration (int): Current iteration number.
|
| 262 |
+
grad_accum_iter (int): Number of gradient accumulation iterations.
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
|
| 266 |
+
loss (torch.Tensor): The total loss of the training data batch.
|
| 267 |
+
"""
|
| 268 |
+
# Only let DDP sync gradient at the last iteration of the gradient accumulation window
|
| 269 |
+
with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
|
| 270 |
+
self.callbacks.on_before_forward(iteration=iteration)
|
| 271 |
+
with self.training_timer("forward"):
|
| 272 |
+
output_batch, loss = model_ddp.training_step(data, iteration)
|
| 273 |
+
self.callbacks.on_after_forward(iteration=iteration)
|
| 274 |
+
self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
|
| 275 |
+
with self.training_timer("backward"):
|
| 276 |
+
loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
|
| 277 |
+
loss_scaled.backward()
|
| 278 |
+
if self.config.trainer.distributed_parallelism == "ddp":
|
| 279 |
+
model_ddp.module.on_after_backward()
|
| 280 |
+
else:
|
| 281 |
+
model_ddp.on_after_backward()
|
| 282 |
+
self.callbacks.on_after_backward(model_ddp, iteration=iteration)
|
| 283 |
+
grad_accum_iter += 1
|
| 284 |
+
if grad_accum_iter == self.config.trainer.grad_accum_iter:
|
| 285 |
+
with self.training_timer("optimizer_step"):
|
| 286 |
+
self.callbacks.on_before_optimizer_step(
|
| 287 |
+
model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
|
| 288 |
+
)
|
| 289 |
+
grad_scaler.step(optimizer)
|
| 290 |
+
grad_scaler.update()
|
| 291 |
+
scheduler.step()
|
| 292 |
+
self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
|
| 293 |
+
if self.config.trainer.distributed_parallelism == "ddp":
|
| 294 |
+
model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
|
| 295 |
+
else:
|
| 296 |
+
model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
|
| 297 |
+
optimizer.zero_grad(set_to_none=True)
|
| 298 |
+
grad_accum_iter = 0
|
| 299 |
+
return output_batch, loss, grad_accum_iter
|
| 300 |
+
|
| 301 |
+
@torch.no_grad()
|
| 302 |
+
def validate(self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
|
| 303 |
+
"""Validate on the full validation dataset.
|
| 304 |
+
|
| 305 |
+
Args:
|
| 306 |
+
model (ImaginaireModel): The PyTorch model.
|
| 307 |
+
dataloader_val (torch.utils.data.DataLoader): The validation data loader.
|
| 308 |
+
iteration (int): Current iteration number.
|
| 309 |
+
"""
|
| 310 |
+
log.info(f"Validating at iteration {iteration}...")
|
| 311 |
+
self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
|
| 312 |
+
model.eval()
|
| 313 |
+
# Evaluate on the full validation set.
|
| 314 |
+
with model.pipe.ema_scope(context="Validation", is_cpu=False):
|
| 315 |
+
for val_iter, data_batch in enumerate(dataloader_val):
|
| 316 |
+
if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
|
| 317 |
+
break
|
| 318 |
+
data_batch = misc.to(data_batch, device="cuda")
|
| 319 |
+
self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
|
| 320 |
+
output_batch, loss = model.validation_step(data_batch, iteration)
|
| 321 |
+
self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
|
| 322 |
+
self.callbacks.on_validation_end(model, iteration=iteration)
|
imaginaire/utils/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
imaginaire/utils/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
imaginaire/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
imaginaire/utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (151 Bytes). View file
|
|
|
imaginaire/utils/__pycache__/device.cpython-310.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
imaginaire/utils/__pycache__/distributed.cpython-310.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
imaginaire/utils/__pycache__/io.cpython-310.pyc
ADDED
|
Binary file (4.93 kB). View file
|
|
|
imaginaire/utils/__pycache__/io.cpython-39.pyc
ADDED
|
Binary file (4.9 kB). View file
|
|
|
imaginaire/utils/__pycache__/log.cpython-310.pyc
ADDED
|
Binary file (4.01 kB). View file
|
|
|
imaginaire/utils/__pycache__/log.cpython-39.pyc
ADDED
|
Binary file (4.25 kB). View file
|
|
|
imaginaire/utils/__pycache__/misc.cpython-310.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
imaginaire/utils/callback.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
import warnings
|
| 20 |
+
from collections.abc import Callable
|
| 21 |
+
from typing import TYPE_CHECKING, Any
|
| 22 |
+
|
| 23 |
+
import omegaconf
|
| 24 |
+
import torch
|
| 25 |
+
import torch.utils.data
|
| 26 |
+
import tqdm
|
| 27 |
+
|
| 28 |
+
from imaginaire.lazy_config import instantiate
|
| 29 |
+
from imaginaire.utils import distributed, log
|
| 30 |
+
from imaginaire.utils.misc import get_local_tensor_if_DTensor
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from megatron.core import parallel_state
|
| 34 |
+
except ImportError:
|
| 35 |
+
parallel_state = None
|
| 36 |
+
print("Megatron-core is not installed.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if TYPE_CHECKING:
|
| 40 |
+
from imaginaire.config import Config
|
| 41 |
+
from imaginaire.model import ImaginaireModel
|
| 42 |
+
from imaginaire.trainer import ImaginaireTrainer
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CallBackGroup:
|
| 46 |
+
"""A class for hosting a collection of callback objects.
|
| 47 |
+
|
| 48 |
+
It is used to execute callback functions of multiple callback objects with the same method name.
|
| 49 |
+
When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
|
| 50 |
+
self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.
|
| 51 |
+
|
| 52 |
+
Attributes:
|
| 53 |
+
_callbacks (list[Callback]): List of callback objects.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, config: Config, trainer: ImaginaireTrainer) -> None:
|
| 57 |
+
"""Initializes the list of callback objects.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
config (Config): The config object for the Imaginaire codebase.
|
| 61 |
+
trainer (ImaginaireTrainer): The main trainer.
|
| 62 |
+
"""
|
| 63 |
+
self._callbacks = []
|
| 64 |
+
callback_configs = config.trainer.callbacks
|
| 65 |
+
if callback_configs:
|
| 66 |
+
if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig):
|
| 67 |
+
warnings.warn(
|
| 68 |
+
"The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
|
| 69 |
+
"Please update your code",
|
| 70 |
+
DeprecationWarning,
|
| 71 |
+
stacklevel=2,
|
| 72 |
+
)
|
| 73 |
+
callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
|
| 74 |
+
for callback_name, current_callback_cfg in callback_configs.items():
|
| 75 |
+
if "_target_" not in current_callback_cfg:
|
| 76 |
+
log.critical(
|
| 77 |
+
f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}"
|
| 78 |
+
)
|
| 79 |
+
continue
|
| 80 |
+
log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}")
|
| 81 |
+
_callback = instantiate(current_callback_cfg)
|
| 82 |
+
assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
|
| 83 |
+
_callback.config = config
|
| 84 |
+
_callback.trainer = trainer
|
| 85 |
+
self._callbacks.append(_callback)
|
| 86 |
+
|
| 87 |
+
def __getattr__(self, method_name: str) -> Callable:
|
| 88 |
+
"""Loops through the callback objects to call the corresponding callback function.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
method_name (str): Callback method name.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def multi_callback_wrapper(*args, **kwargs) -> None:
|
| 95 |
+
for callback in self._callbacks:
|
| 96 |
+
assert hasattr(callback, method_name)
|
| 97 |
+
method = getattr(callback, method_name)
|
| 98 |
+
assert callable(method)
|
| 99 |
+
_ = method(*args, **kwargs)
|
| 100 |
+
|
| 101 |
+
return multi_callback_wrapper
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Callback:
|
| 105 |
+
"""The base class for all callbacks.
|
| 106 |
+
|
| 107 |
+
All callbacks should inherit from this class and adhere to the established method names and signatures.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
def __init__(self, config: Config | None = None, trainer: ImaginaireTrainer | None = None):
|
| 111 |
+
"""Initializes a Callback object.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
config (Optional[Config]): The configuration object for the Imaginaire codebase, if available.
|
| 115 |
+
trainer (Optional[ImaginaireTrainer]): The main trainer handling the training loop, if available.
|
| 116 |
+
|
| 117 |
+
Notes:
|
| 118 |
+
The config and trainer parameters are optional to maintain backward compatibility.
|
| 119 |
+
In future releases, these parameters will be removed. Upon using these parameters, a deprecation
|
| 120 |
+
warning will be issued.
|
| 121 |
+
|
| 122 |
+
"""
|
| 123 |
+
if config is not None or trainer is not None:
|
| 124 |
+
warnings.warn(
|
| 125 |
+
"The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. "
|
| 126 |
+
"Please update your code to create Callback instances without these parameters.",
|
| 127 |
+
DeprecationWarning,
|
| 128 |
+
stacklevel=2,
|
| 129 |
+
)
|
| 130 |
+
del config, trainer
|
| 131 |
+
|
| 132 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 133 |
+
pass
|
| 134 |
+
|
| 135 |
+
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 136 |
+
"""
|
| 137 |
+
Called before the training step, for each batch. This is paired with on_training_step_end() but note that
|
| 138 |
+
when using gradient accumulation, while on_training_step_end() is only called when the optimizer is updated,
|
| 139 |
+
this function is called for every batch.
|
| 140 |
+
Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
|
| 141 |
+
for every batch, albeit with the same iteration number.
|
| 142 |
+
"""
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
def on_training_step_batch_start(
|
| 146 |
+
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 147 |
+
) -> None:
|
| 148 |
+
"""
|
| 149 |
+
Called before the training step, for each batch, similarly to on_training_step_start(). This function is paired with
|
| 150 |
+
on_training_step_batch_end(), and both functions are called for every batch even when using gradient accumulation.
|
| 151 |
+
Note that the iteration is only updated when the optimizer is updated, and therefore it may be the same for multiple invocations.
|
| 152 |
+
"""
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
def on_before_forward(self, iteration: int = 0) -> None:
|
| 156 |
+
pass
|
| 157 |
+
|
| 158 |
+
def on_after_forward(self, iteration: int = 0) -> None:
|
| 159 |
+
pass
|
| 160 |
+
|
| 161 |
+
def on_before_backward(
|
| 162 |
+
self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
|
| 163 |
+
) -> None:
|
| 164 |
+
pass
|
| 165 |
+
|
| 166 |
+
def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
def on_before_dataloading(self, iteration: int = 0) -> None:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
def on_after_dataloading(self, iteration: int = 0) -> None:
|
| 173 |
+
pass
|
| 174 |
+
|
| 175 |
+
def on_optimizer_init_start(self) -> None:
|
| 176 |
+
pass
|
| 177 |
+
|
| 178 |
+
def on_optimizer_init_end(self) -> None:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
def on_before_optimizer_step(
|
| 182 |
+
self,
|
| 183 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 184 |
+
optimizer: torch.optim.Optimizer,
|
| 185 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 186 |
+
grad_scaler: torch.amp.GradScaler,
|
| 187 |
+
iteration: int = 0,
|
| 188 |
+
) -> None:
|
| 189 |
+
pass
|
| 190 |
+
|
| 191 |
+
def on_before_zero_grad(
|
| 192 |
+
self,
|
| 193 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 194 |
+
optimizer: torch.optim.Optimizer,
|
| 195 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 196 |
+
iteration: int = 0,
|
| 197 |
+
) -> None:
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
def on_training_step_batch_end(
|
| 201 |
+
self,
|
| 202 |
+
model: ImaginaireModel,
|
| 203 |
+
data_batch: dict[str, torch.Tensor],
|
| 204 |
+
output_batch: dict[str, torch.Tensor],
|
| 205 |
+
loss: torch.Tensor,
|
| 206 |
+
iteration: int = 0,
|
| 207 |
+
) -> None:
|
| 208 |
+
"""
|
| 209 |
+
Called at the end of a training step for every batch even when using gradient accumulation.
|
| 210 |
+
This is paired with on_training_step_batch_start(). Note that the iteration is only updated when the optimizer is updated,
|
| 211 |
+
and therefore it may be the same for multiple batches.
|
| 212 |
+
"""
|
| 213 |
+
pass
|
| 214 |
+
|
| 215 |
+
def on_training_step_end(
|
| 216 |
+
self,
|
| 217 |
+
model: ImaginaireModel,
|
| 218 |
+
data_batch: dict[str, torch.Tensor],
|
| 219 |
+
output_batch: dict[str, torch.Tensor],
|
| 220 |
+
loss: torch.Tensor,
|
| 221 |
+
iteration: int = 0,
|
| 222 |
+
) -> None:
|
| 223 |
+
"""
|
| 224 |
+
Called at the end of a training step, but note that when using gradient accumulation, this is only called
|
| 225 |
+
when the optimizer is updated, and the iteration incremented, whereas on_training_step_start is called every time.
|
| 226 |
+
Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
|
| 227 |
+
for every batch.
|
| 228 |
+
"""
|
| 229 |
+
pass
|
| 230 |
+
|
| 231 |
+
def on_validation_start(
|
| 232 |
+
self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
|
| 233 |
+
) -> None:
|
| 234 |
+
pass
|
| 235 |
+
|
| 236 |
+
def on_validation_step_start(
|
| 237 |
+
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 238 |
+
) -> None:
|
| 239 |
+
pass
|
| 240 |
+
|
| 241 |
+
def on_validation_step_end(
|
| 242 |
+
self,
|
| 243 |
+
model: ImaginaireModel,
|
| 244 |
+
data_batch: dict[str, torch.Tensor],
|
| 245 |
+
output_batch: dict[str, torch.Tensor],
|
| 246 |
+
loss: torch.Tensor,
|
| 247 |
+
iteration: int = 0,
|
| 248 |
+
) -> None:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 252 |
+
pass
|
| 253 |
+
|
| 254 |
+
def on_load_checkpoint_start(self, model: ImaginaireModel) -> None:
|
| 255 |
+
pass
|
| 256 |
+
|
| 257 |
+
def on_load_checkpoint_end(
|
| 258 |
+
self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: str | None = None
|
| 259 |
+
) -> None:
|
| 260 |
+
pass
|
| 261 |
+
|
| 262 |
+
def on_load_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None:
|
| 263 |
+
pass
|
| 264 |
+
|
| 265 |
+
def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 266 |
+
"""
|
| 267 |
+
Called when checkpoint saving is about to start.
|
| 268 |
+
"""
|
| 269 |
+
pass
|
| 270 |
+
|
| 271 |
+
def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 272 |
+
"""
|
| 273 |
+
Called when the synchronous part of checkpointing is finished, this function can be used
|
| 274 |
+
along with on_save_checkpoint_start() to measure the exposed (synchronous) checkpoint time.
|
| 275 |
+
Note that for asynchronous checkpoint, the checkpoint may still be ongoing, so this function
|
| 276 |
+
does not mean the checkpoint is finished for the asynchronous case, use on_save_checkpoint_success()
|
| 277 |
+
for that.
|
| 278 |
+
"""
|
| 279 |
+
pass
|
| 280 |
+
|
| 281 |
+
def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None:
|
| 282 |
+
"""
|
| 283 |
+
Called when checkpoint saving is fully finished, and succeeded. Not called if checkpoint failed.
|
| 284 |
+
For synchronous checkpoint, it is called at the same time as on_save_checkpoint_end(), but for asynchronous
|
| 285 |
+
checkpoint, it is called after the asynchronous part has also finished. For checkpointers with out-of-process
|
| 286 |
+
checkpointing, this function is called as soon as the notification is received from the checkpointer process,
|
| 287 |
+
which may not be immediately after the checkpoint has completed but later on. Therefore, if you need to measure
|
| 288 |
+
the full checkpoint duration for the asynchronous part, use the elapsed_time parameter, do not measure it directly
|
| 289 |
+
as this would be a significant overestimate.
|
| 290 |
+
"""
|
| 291 |
+
pass
|
| 292 |
+
|
| 293 |
+
def on_save_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None:
|
| 294 |
+
pass
|
| 295 |
+
|
| 296 |
+
def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 297 |
+
pass
|
| 298 |
+
|
| 299 |
+
def on_app_end(self) -> None:
|
| 300 |
+
pass
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class EMAModelCallback(Callback):
|
| 304 |
+
"""The callback class for tracking EMA model weights."""
|
| 305 |
+
|
| 306 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 307 |
+
# Set up the EMA model weight tracker.
|
| 308 |
+
if model.config.ema.enabled:
|
| 309 |
+
assert hasattr(model, "ema"), "EMA should be initialized from ImaginaireModel"
|
| 310 |
+
# EMA model must be kept in FP32 precision.
|
| 311 |
+
model.ema = model.ema.to(dtype=torch.float32)
|
| 312 |
+
else:
|
| 313 |
+
assert not hasattr(model, "ema"), "There should be no EMA initialized."
|
| 314 |
+
|
| 315 |
+
def on_training_step_end(
|
| 316 |
+
self,
|
| 317 |
+
model: ImaginaireModel,
|
| 318 |
+
data_batch: dict[str, torch.Tensor],
|
| 319 |
+
output_batch: dict[str, torch.Tensor],
|
| 320 |
+
loss: torch.Tensor,
|
| 321 |
+
iteration: int = 0,
|
| 322 |
+
) -> None:
|
| 323 |
+
# Update the EMA model with the new regular weights.
|
| 324 |
+
if model.config.ema.enabled:
|
| 325 |
+
model.ema.update_average(model, iteration)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class ProgressBarCallback(Callback):
|
| 329 |
+
"""The callback class for visualizing the training/validation progress bar in the console."""
|
| 330 |
+
|
| 331 |
+
@distributed.rank0_only
|
| 332 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 333 |
+
self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
|
| 334 |
+
|
| 335 |
+
@distributed.rank0_only
|
| 336 |
+
def on_training_step_end(
|
| 337 |
+
self,
|
| 338 |
+
model: ImaginaireModel,
|
| 339 |
+
data_batch: dict[str, torch.Tensor],
|
| 340 |
+
output_batch: dict[str, torch.Tensor],
|
| 341 |
+
loss: torch.Tensor,
|
| 342 |
+
iteration: int = 0,
|
| 343 |
+
) -> None:
|
| 344 |
+
self.train_pbar.update()
|
| 345 |
+
|
| 346 |
+
@distributed.rank0_only
|
| 347 |
+
def on_validation_start(
|
| 348 |
+
self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
|
| 349 |
+
) -> None:
|
| 350 |
+
if self.config.trainer.max_val_iter is not None:
|
| 351 |
+
num_iter = self.config.trainer.max_val_iter
|
| 352 |
+
else:
|
| 353 |
+
num_iter = len(dataloader_val)
|
| 354 |
+
assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}"
|
| 355 |
+
self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False)
|
| 356 |
+
|
| 357 |
+
@distributed.rank0_only
|
| 358 |
+
def on_validation_step_end(
|
| 359 |
+
self,
|
| 360 |
+
model: ImaginaireModel,
|
| 361 |
+
data_batch: dict[str, torch.Tensor],
|
| 362 |
+
output_batch: dict[str, torch.Tensor],
|
| 363 |
+
loss: torch.Tensor,
|
| 364 |
+
iteration: int = 0,
|
| 365 |
+
) -> None:
|
| 366 |
+
self.val_pbar.update()
|
| 367 |
+
|
| 368 |
+
@distributed.rank0_only
|
| 369 |
+
def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 370 |
+
self.val_pbar.close()
|
| 371 |
+
|
| 372 |
+
@distributed.rank0_only
|
| 373 |
+
def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 374 |
+
self.trainer.checkpointer.finalize()
|
| 375 |
+
self.train_pbar.close()
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class IterationLoggerCallback(Callback):
|
| 379 |
+
"""The callback class for visualizing the training/validation progress bar in the console."""
|
| 380 |
+
|
| 381 |
+
@distributed.rank0_only
|
| 382 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 383 |
+
# self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
|
| 384 |
+
self.start_iteration_time = time.time()
|
| 385 |
+
self.elapsed_iteration_time = 0
|
| 386 |
+
|
| 387 |
+
@distributed.rank0_only
|
| 388 |
+
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 389 |
+
self.start_iteration_time = time.time()
|
| 390 |
+
|
| 391 |
+
@distributed.rank0_only
|
| 392 |
+
def on_training_step_end(
|
| 393 |
+
self,
|
| 394 |
+
model: ImaginaireModel,
|
| 395 |
+
data_batch: dict[str, torch.Tensor],
|
| 396 |
+
output_batch: dict[str, torch.Tensor],
|
| 397 |
+
loss: torch.Tensor,
|
| 398 |
+
iteration: int = 0,
|
| 399 |
+
) -> None:
|
| 400 |
+
self.elapsed_iteration_time += time.time() - self.start_iteration_time
|
| 401 |
+
|
| 402 |
+
if iteration % self.config.trainer.logging_iter == 0:
|
| 403 |
+
avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter
|
| 404 |
+
log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}")
|
| 405 |
+
|
| 406 |
+
self.elapsed_iteration_time = 0
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class LowPrecisionCallback(Callback):
|
| 410 |
+
"""The callback class handling low precision training
|
| 411 |
+
|
| 412 |
+
Config with non-primitive type makes it difficult to override the option.
|
| 413 |
+
The callback gets precision from model.precision instead.
|
| 414 |
+
It also auto disabled when using fp32.
|
| 415 |
+
"""
|
| 416 |
+
|
| 417 |
+
def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int):
|
| 418 |
+
self.update_iter = update_iter
|
| 419 |
+
|
| 420 |
+
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 421 |
+
assert model.precision in [
|
| 422 |
+
torch.bfloat16,
|
| 423 |
+
torch.float16,
|
| 424 |
+
torch.half,
|
| 425 |
+
], "LowPrecisionCallback must use a low precision dtype."
|
| 426 |
+
self.precision_type = model.precision
|
| 427 |
+
|
| 428 |
+
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 429 |
+
for k, v in data.items():
|
| 430 |
+
if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
|
| 431 |
+
data[k] = v.to(dtype=self.precision_type)
|
| 432 |
+
|
| 433 |
+
def on_validation_step_start(
|
| 434 |
+
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 435 |
+
) -> None:
|
| 436 |
+
for k, v in data.items():
|
| 437 |
+
if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
|
| 438 |
+
data[k] = v.to(dtype=self.precision_type)
|
| 439 |
+
|
| 440 |
+
def on_before_zero_grad(
|
| 441 |
+
self,
|
| 442 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 443 |
+
optimizer: torch.optim.Optimizer,
|
| 444 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 445 |
+
iteration: int = 0,
|
| 446 |
+
) -> None:
|
| 447 |
+
if iteration % self.update_iter == 0:
|
| 448 |
+
if getattr(optimizer, "master_weights", False):
|
| 449 |
+
params, master_params = [], []
|
| 450 |
+
for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master, strict=False):
|
| 451 |
+
for p, p_master in zip(group["params"], group_master["params"], strict=False):
|
| 452 |
+
params.append(get_local_tensor_if_DTensor(p.data))
|
| 453 |
+
master_params.append(p_master.data)
|
| 454 |
+
torch._foreach_copy_(params, master_params)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class NVTXCallback(Callback):
|
| 458 |
+
"""The callback for creating NVTX ranges"""
|
| 459 |
+
|
| 460 |
+
def __init__(
|
| 461 |
+
self,
|
| 462 |
+
synchronize: bool = False,
|
| 463 |
+
config: Config | None = None,
|
| 464 |
+
trainer: ImaginaireTrainer | None = None,
|
| 465 |
+
):
|
| 466 |
+
super().__init__(config, trainer)
|
| 467 |
+
self.synchronize = synchronize
|
| 468 |
+
|
| 469 |
+
def on_before_forward(self, iteration: int = 0) -> None:
|
| 470 |
+
if self.synchronize:
|
| 471 |
+
torch.cuda.synchronize()
|
| 472 |
+
torch.cuda.nvtx.range_push("forward")
|
| 473 |
+
|
| 474 |
+
def on_after_forward(self, iteration: int = 0) -> None:
|
| 475 |
+
if self.synchronize:
|
| 476 |
+
torch.cuda.synchronize()
|
| 477 |
+
torch.cuda.nvtx.range_pop()
|
| 478 |
+
|
| 479 |
+
def on_before_backward(
|
| 480 |
+
self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
|
| 481 |
+
) -> None:
|
| 482 |
+
if self.synchronize:
|
| 483 |
+
torch.cuda.synchronize()
|
| 484 |
+
torch.cuda.nvtx.range_push("backward")
|
| 485 |
+
|
| 486 |
+
def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
|
| 487 |
+
if self.synchronize:
|
| 488 |
+
torch.cuda.synchronize()
|
| 489 |
+
torch.cuda.nvtx.range_pop()
|
| 490 |
+
|
| 491 |
+
def on_before_optimizer_step(
|
| 492 |
+
self,
|
| 493 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 494 |
+
optimizer: torch.optim.Optimizer,
|
| 495 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 496 |
+
grad_scaler: torch.amp.GradScaler,
|
| 497 |
+
iteration: int = 0,
|
| 498 |
+
) -> None:
|
| 499 |
+
if self.synchronize:
|
| 500 |
+
torch.cuda.synchronize()
|
| 501 |
+
torch.cuda.nvtx.range_push("optimizer_step")
|
| 502 |
+
|
| 503 |
+
def on_before_zero_grad(
|
| 504 |
+
self,
|
| 505 |
+
model_ddp: distributed.DistributedDataParallel,
|
| 506 |
+
optimizer: torch.optim.Optimizer,
|
| 507 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 508 |
+
iteration: int = 0,
|
| 509 |
+
) -> None:
|
| 510 |
+
if self.synchronize:
|
| 511 |
+
torch.cuda.synchronize()
|
| 512 |
+
torch.cuda.nvtx.range_pop()
|
| 513 |
+
|
| 514 |
+
def on_before_dataloading(self, iteration: int = 0) -> None:
|
| 515 |
+
torch.cuda.nvtx.range_push("dataloading")
|
| 516 |
+
|
| 517 |
+
def on_after_dataloading(self, iteration: int = 0) -> None:
|
| 518 |
+
torch.cuda.nvtx.range_pop()
|
imaginaire/utils/checkpointer.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import threading
|
| 20 |
+
from typing import TYPE_CHECKING, NamedTuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
from imaginaire.model import ImaginaireModel
|
| 27 |
+
from imaginaire.utils import callback, distributed, log, misc
|
| 28 |
+
from imaginaire.utils.parallelism import ModelWrapper
|
| 29 |
+
|
| 30 |
+
if TYPE_CHECKING:
|
| 31 |
+
from imaginaire.config import CheckpointConfig, JobConfig
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Checkpointer:
|
| 35 |
+
"""The checkpointer class. Supports checkpoint saving/loading to local disk."""
|
| 36 |
+
|
| 37 |
+
def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
|
| 38 |
+
"""Constructor of the checkpointer.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
config_checkpoint (CheckpointConfig): The config object for the checkpointer.
|
| 42 |
+
"""
|
| 43 |
+
# Set the callback functions.
|
| 44 |
+
self.callbacks = callbacks
|
| 45 |
+
self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
|
| 46 |
+
self.strict_resume = config_checkpoint.strict_resume
|
| 47 |
+
self.load_path = config_checkpoint.load_path or None
|
| 48 |
+
self.load_training_state = config_checkpoint.load_training_state
|
| 49 |
+
self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
|
| 50 |
+
self.save_thread = None
|
| 51 |
+
|
| 52 |
+
def save(
|
| 53 |
+
self,
|
| 54 |
+
model: ImaginaireModel,
|
| 55 |
+
optimizer: torch.optim.Optimizer,
|
| 56 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 57 |
+
grad_scaler: torch.amp.GradScaler,
|
| 58 |
+
iteration: int,
|
| 59 |
+
) -> None:
|
| 60 |
+
"""Save network weights, optimizer parameters, scheduler parameters to a checkpoint.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
model (ImaginaireModel): The PyTorch model.
|
| 64 |
+
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 65 |
+
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 66 |
+
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
|
| 67 |
+
iteration (int): Current iteration number.
|
| 68 |
+
"""
|
| 69 |
+
self.callbacks.on_save_checkpoint_start(model, iteration)
|
| 70 |
+
|
| 71 |
+
checkpoint_file = f"iter_{iteration:09}.pt"
|
| 72 |
+
|
| 73 |
+
if distributed.get_rank() == 0:
|
| 74 |
+
state_dict = dict(
|
| 75 |
+
model=model.state_dict(),
|
| 76 |
+
optimizer=optimizer.state_dict(),
|
| 77 |
+
scheduler=scheduler.state_dict(),
|
| 78 |
+
grad_scaler=grad_scaler.state_dict(),
|
| 79 |
+
iteration=iteration,
|
| 80 |
+
)
|
| 81 |
+
state_dict = misc.to(state_dict, device="cpu")
|
| 82 |
+
self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
|
| 83 |
+
# Wait for previous saver thread to end.
|
| 84 |
+
if self.save_thread:
|
| 85 |
+
self.save_thread.join()
|
| 86 |
+
# Run the checkpoint saver in a separate thread.
|
| 87 |
+
self.save_thread = threading.Thread(
|
| 88 |
+
target=self._save_worker_local,
|
| 89 |
+
daemon=False,
|
| 90 |
+
args=(state_dict, checkpoint_file, distributed.get_rank()),
|
| 91 |
+
)
|
| 92 |
+
self.save_thread.start()
|
| 93 |
+
|
| 94 |
+
# Note: Checkpoints are saved on a separate thread and this callback is not accurate.
|
| 95 |
+
# Please check logs from on_save_checkpoint_success() for better accuracy
|
| 96 |
+
self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)
|
| 97 |
+
|
| 98 |
+
@misc.timer("checkpoint saving (local)")
|
| 99 |
+
def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None:
|
| 100 |
+
"""Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
|
| 104 |
+
checkpoint_file (str): The file name of the model checkpoint.
|
| 105 |
+
rank (int): GPU device (default: 0).
|
| 106 |
+
"""
|
| 107 |
+
checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
|
| 108 |
+
os.makedirs(self.checkpoint_dir_local, exist_ok=True)
|
| 109 |
+
try:
|
| 110 |
+
torch.save(state_dict, checkpoint_path)
|
| 111 |
+
if rank == 0:
|
| 112 |
+
self._write_latest_checkpoint_file(checkpoint_file)
|
| 113 |
+
log.success(f"Saved checkpoint (local): {checkpoint_path}")
|
| 114 |
+
iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
|
| 115 |
+
self.callbacks.on_save_checkpoint_success(iteration=iteration)
|
| 116 |
+
except Exception as e:
|
| 117 |
+
log.exception(f"Checkpoint failed to save (local): {e}")
|
| 118 |
+
|
| 119 |
+
@misc.timer("checkpoint loading")
|
| 120 |
+
def load(
|
| 121 |
+
self,
|
| 122 |
+
model: ImaginaireModel,
|
| 123 |
+
optimizer: torch.optim.Optimizer | None = None,
|
| 124 |
+
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
|
| 125 |
+
grad_scaler: torch.amp.GradScaler | None = None,
|
| 126 |
+
) -> int:
|
| 127 |
+
"""Load network weights and optimizer states from a checkpoint in a single process.
|
| 128 |
+
|
| 129 |
+
The priority of the checkpoint loading logic is:
|
| 130 |
+
1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
|
| 131 |
+
2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path.
|
| 132 |
+
- This is typically used for inference mode.
|
| 133 |
+
- If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states.
|
| 134 |
+
3. If none of the above, randomly initialize the model parameters and train from scratch.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
model (ImaginaireModel): The PyTorch model.
|
| 138 |
+
optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
|
| 139 |
+
scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
|
| 140 |
+
grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
iteration (int): the iteration number to start/resume from.
|
| 144 |
+
"""
|
| 145 |
+
self.callbacks.on_load_checkpoint_start(model)
|
| 146 |
+
|
| 147 |
+
latest_checkpoint_file = self._read_latest_checkpoint_file()
|
| 148 |
+
if latest_checkpoint_file is not None:
|
| 149 |
+
# 1. Resume training from latest_checkpoint.txt under the same name.
|
| 150 |
+
checkpoint_dir = self.checkpoint_dir_local
|
| 151 |
+
checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
|
| 152 |
+
resume = True
|
| 153 |
+
only_resume_scheduler = True
|
| 154 |
+
else:
|
| 155 |
+
if self.load_path:
|
| 156 |
+
# 2. Load the module weights specified by config_checkpoint.path.
|
| 157 |
+
checkpoint_path = self.load_path
|
| 158 |
+
resume = self.load_training_state
|
| 159 |
+
only_resume_scheduler = self.only_load_scheduler_state
|
| 160 |
+
else:
|
| 161 |
+
# 3. Randomly initialize the model parameters and train from scratch.
|
| 162 |
+
checkpoint_path = None
|
| 163 |
+
resume = False
|
| 164 |
+
only_resume_scheduler = False
|
| 165 |
+
# Load checkpoint.
|
| 166 |
+
if checkpoint_path is not None:
|
| 167 |
+
self._check_checkpoint_exists(checkpoint_path)
|
| 168 |
+
log.info(f"Loading checkpoint (local): {checkpoint_path}")
|
| 169 |
+
state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
|
| 170 |
+
log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
|
| 171 |
+
self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
|
| 172 |
+
# Load the state dicts.
|
| 173 |
+
log.info("- Loading the model...")
|
| 174 |
+
model.load_state_dict(state_dict["model"], strict=self.strict_resume)
|
| 175 |
+
if resume or only_resume_scheduler:
|
| 176 |
+
iteration = state_dict["iteration"]
|
| 177 |
+
assert scheduler
|
| 178 |
+
log.info("- Loading the scheduler...")
|
| 179 |
+
scheduler.load_state_dict(state_dict["scheduler"])
|
| 180 |
+
scheduler.last_epoch = iteration
|
| 181 |
+
else:
|
| 182 |
+
iteration = 0
|
| 183 |
+
if resume:
|
| 184 |
+
assert optimizer
|
| 185 |
+
log.info("- Loading the optimizer...")
|
| 186 |
+
optimizer.load_state_dict(state_dict["optimizer"])
|
| 187 |
+
log.info("- Loading the gradient scaler...")
|
| 188 |
+
grad_scaler.load_state_dict(state_dict["grad_scaler"])
|
| 189 |
+
log.success(f"Done with loading the checkpoint (iteration {iteration}).")
|
| 190 |
+
else:
|
| 191 |
+
log.success("Done with loading the checkpoint.")
|
| 192 |
+
else:
|
| 193 |
+
# Checkpoint not found and not specified. We will train everything from scratch.
|
| 194 |
+
iteration = 0
|
| 195 |
+
log.info("Training from scratch.")
|
| 196 |
+
torch.cuda.empty_cache()
|
| 197 |
+
|
| 198 |
+
self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)
|
| 199 |
+
|
| 200 |
+
return iteration
|
| 201 |
+
|
| 202 |
+
def _read_latest_checkpoint_file(self) -> str | None:
|
| 203 |
+
"""Get the file name of the latest saved checkpoint. If it doesn't exist, return None.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
checkpoint_file (str | None): file name of the latest saved checkpoint.
|
| 207 |
+
"""
|
| 208 |
+
checkpoint_file = None
|
| 209 |
+
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
|
| 210 |
+
if os.path.isfile(latest_path):
|
| 211 |
+
checkpoint_file = open(latest_path).read().strip()
|
| 212 |
+
return checkpoint_file
|
| 213 |
+
|
| 214 |
+
def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
|
| 215 |
+
"""Track the file name of the latest saved checkpoint.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
checkpoint_file (str): file name of the latest saved checkpoint.
|
| 219 |
+
"""
|
| 220 |
+
content = f"{checkpoint_file}\n"
|
| 221 |
+
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
|
| 222 |
+
with open(latest_path, "w") as file:
|
| 223 |
+
file.write(content)
|
| 224 |
+
|
| 225 |
+
def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
|
| 226 |
+
"""If the file checkpoint_path does not exist, raise an error.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
checkpoint_path (str): full path to the checkpoint.
|
| 230 |
+
"""
|
| 231 |
+
if not os.path.exists(checkpoint_path):
|
| 232 |
+
raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
|
| 233 |
+
|
| 234 |
+
def finalize(self) -> None:
|
| 235 |
+
"""Finalize the checkpointer."""
|
| 236 |
+
if self.save_thread:
|
| 237 |
+
self.save_thread.join()
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class _IncompatibleKeys(
|
| 241 |
+
NamedTuple(
|
| 242 |
+
"IncompatibleKeys",
|
| 243 |
+
[
|
| 244 |
+
("missing_keys", list[str]),
|
| 245 |
+
("unexpected_keys", list[str]),
|
| 246 |
+
("incorrect_shapes", list[tuple[str, tuple[int], tuple[int]]]),
|
| 247 |
+
],
|
| 248 |
+
)
|
| 249 |
+
):
|
| 250 |
+
pass
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def load_checkpoint(
|
| 254 |
+
model_parts: list[nn.Module],
|
| 255 |
+
ckpt_dir,
|
| 256 |
+
model_ckpt_key_map: dict[str, str] = {}, # noqa: B006
|
| 257 |
+
):
|
| 258 |
+
log.info(f"Loading checkpoint from {ckpt_dir}.")
|
| 259 |
+
|
| 260 |
+
_model_wrapper = ModelWrapper(model_parts)
|
| 261 |
+
state_dict = _model_wrapper.state_dict()
|
| 262 |
+
# remove _extra_state
|
| 263 |
+
state_dict = {k: v for k, v in state_dict.items() if not k.endswith("._extra_state")}
|
| 264 |
+
|
| 265 |
+
# remap keys if needed
|
| 266 |
+
if model_ckpt_key_map:
|
| 267 |
+
for model_key, checkpoint_key in model_ckpt_key_map.items():
|
| 268 |
+
state_dict[checkpoint_key] = state_dict.pop(model_key)
|
| 269 |
+
log.info(f"Re-mapping {model_key} to {checkpoint_key}")
|
| 270 |
+
|
| 271 |
+
fs_storage_reader = dist.checkpoint.FileSystemReader(ckpt_dir)
|
| 272 |
+
dist.checkpoint.load(state_dict=state_dict, storage_reader=fs_storage_reader)
|
| 273 |
+
|
| 274 |
+
# inverse the remapping if needed
|
| 275 |
+
if model_ckpt_key_map:
|
| 276 |
+
for model_key, checkpoint_key in model_ckpt_key_map.items():
|
| 277 |
+
state_dict[model_key] = state_dict.pop(checkpoint_key)
|
| 278 |
+
log.info(f"Inverse re-mapping {checkpoint_key} to {model_key}")
|
| 279 |
+
|
| 280 |
+
_model_wrapper.load_state_dict(state_dict)
|
| 281 |
+
|
| 282 |
+
log.info(f"Finished loading checkpoint from {ckpt_dir}.")
|
imaginaire/utils/config_helper.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import importlib
|
| 17 |
+
import os
|
| 18 |
+
import pkgutil
|
| 19 |
+
import sys
|
| 20 |
+
from dataclasses import fields as dataclass_fields
|
| 21 |
+
from dataclasses import is_dataclass
|
| 22 |
+
from typing import Any
|
| 23 |
+
|
| 24 |
+
import attr
|
| 25 |
+
import attrs
|
| 26 |
+
from hydra import compose, initialize
|
| 27 |
+
from hydra.core.config_store import ConfigStore
|
| 28 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 29 |
+
from omegaconf import DictConfig, OmegaConf
|
| 30 |
+
|
| 31 |
+
from imaginaire.config import Config
|
| 32 |
+
from imaginaire.utils import log
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def is_attrs_or_dataclass(obj) -> bool:
|
| 36 |
+
"""
|
| 37 |
+
Check if the object is an instance of an attrs class or a dataclass.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
obj: The object to check.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
|
| 44 |
+
"""
|
| 45 |
+
return is_dataclass(obj) or attr.has(type(obj))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_fields(obj):
|
| 49 |
+
"""
|
| 50 |
+
Get the fields of an attrs class or a dataclass.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
list: A list of field names.
|
| 57 |
+
|
| 58 |
+
Raises:
|
| 59 |
+
ValueError: If the object is neither an attrs class nor a dataclass.
|
| 60 |
+
"""
|
| 61 |
+
if is_dataclass(obj):
|
| 62 |
+
return [field.name for field in dataclass_fields(obj)]
|
| 63 |
+
elif attr.has(type(obj)):
|
| 64 |
+
return [field.name for field in attr.fields(type(obj))]
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError("The object is neither an attrs class nor a dataclass.")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def override(config: Config, overrides: list[str] | None = None) -> Config:
|
| 70 |
+
"""
|
| 71 |
+
:param config: the instance of class `Config` (usually from `make_config`)
|
| 72 |
+
:param overrides: list of overrides for config
|
| 73 |
+
:return: the composed instance of class `Config`
|
| 74 |
+
"""
|
| 75 |
+
# Store the class of the config for reconstruction after overriding.
|
| 76 |
+
# config_class = type(config)
|
| 77 |
+
|
| 78 |
+
# Convert Config object to a DictConfig object
|
| 79 |
+
config_dict = attrs.asdict(config)
|
| 80 |
+
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
|
| 81 |
+
# Enforce "--" separator between the script arguments and overriding configs.
|
| 82 |
+
if overrides:
|
| 83 |
+
if overrides[0] != "--":
|
| 84 |
+
raise ValueError('Hydra config overrides must be separated with a "--" token.')
|
| 85 |
+
overrides = overrides[1:]
|
| 86 |
+
# Use Hydra to handle overrides
|
| 87 |
+
cs = ConfigStore.instance()
|
| 88 |
+
cs.store(name="config", node=config_omegaconf)
|
| 89 |
+
if not GlobalHydra().is_initialized():
|
| 90 |
+
with initialize(version_base=None):
|
| 91 |
+
config_omegaconf = compose(config_name="config", overrides=overrides)
|
| 92 |
+
OmegaConf.resolve(config_omegaconf)
|
| 93 |
+
else:
|
| 94 |
+
config_omegaconf = compose(config_name="config", overrides=overrides)
|
| 95 |
+
OmegaConf.resolve(config_omegaconf)
|
| 96 |
+
|
| 97 |
+
def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
|
| 98 |
+
"""
|
| 99 |
+
Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
ref_instance: The reference instance to determine the type and fields when needed
|
| 103 |
+
kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data
|
| 107 |
+
|
| 108 |
+
Raises:
|
| 109 |
+
AssertionError: If the fields do not match or if extra keys are found.
|
| 110 |
+
Exception: If there is an error constructing the new instance.
|
| 111 |
+
"""
|
| 112 |
+
is_type = is_attrs_or_dataclass(ref_instance)
|
| 113 |
+
if not is_type:
|
| 114 |
+
return kwargs
|
| 115 |
+
else:
|
| 116 |
+
ref_fields = set(get_fields(ref_instance))
|
| 117 |
+
assert isinstance(kwargs, dict) or isinstance(kwargs, DictConfig), (
|
| 118 |
+
"kwargs must be a dictionary or a DictConfig"
|
| 119 |
+
)
|
| 120 |
+
keys = set(kwargs.keys())
|
| 121 |
+
|
| 122 |
+
# ref_fields must equal to or include all keys
|
| 123 |
+
extra_keys = keys - ref_fields
|
| 124 |
+
assert ref_fields == keys or keys.issubset(ref_fields), (
|
| 125 |
+
f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
resolved_kwargs: dict[str, Any] = {}
|
| 129 |
+
for f in keys:
|
| 130 |
+
resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
|
| 131 |
+
try:
|
| 132 |
+
new_instance = type(ref_instance)(**resolved_kwargs)
|
| 133 |
+
except Exception as e:
|
| 134 |
+
log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
|
| 135 |
+
log.error(e)
|
| 136 |
+
raise e
|
| 137 |
+
return new_instance
|
| 138 |
+
|
| 139 |
+
config = config_from_dict(config, config_omegaconf)
|
| 140 |
+
|
| 141 |
+
return config
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_config_module(config_file: str) -> str:
|
| 145 |
+
if not config_file.endswith(".py"):
|
| 146 |
+
log.error("Config file cannot be specified as module.")
|
| 147 |
+
log.error("Please provide the path to the Python config file (relative to the Imaginaire4 root).")
|
| 148 |
+
assert os.path.isfile(config_file), f"Imaginaire4 config file ({config_file}) not found."
|
| 149 |
+
# Convert to importable module format.
|
| 150 |
+
config_module = config_file.replace("/", ".").replace(".py", "")
|
| 151 |
+
return config_module
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
|
| 155 |
+
"""
|
| 156 |
+
Import all modules from the specified package path recursively.
|
| 157 |
+
|
| 158 |
+
This function is typically used in conjunction with Hydra to ensure that all modules
|
| 159 |
+
within a specified package are imported, which is necessary for registering configurations.
|
| 160 |
+
|
| 161 |
+
Example usage:
|
| 162 |
+
```python
|
| 163 |
+
import_all_modules_from_package("projects.cosmos.diffusion.v1.config.experiment", reload=True, skip_underscore=False)
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
package_path (str): The dotted path to the package from which to import all modules.
|
| 168 |
+
reload (bool): Flag to determine whether to reload modules if they're already imported.
|
| 169 |
+
skip_underscore (bool): If True, skips importing modules that start with an underscore.
|
| 170 |
+
"""
|
| 171 |
+
log.critical(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
|
| 172 |
+
package = importlib.import_module(package_path)
|
| 173 |
+
package_directory = package.__path__
|
| 174 |
+
|
| 175 |
+
def import_modules_recursively(directory: str, prefix: str) -> None:
|
| 176 |
+
"""
|
| 177 |
+
Recursively imports or reloads all modules in the given directory.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
directory (str): The file system path to the current package directory.
|
| 181 |
+
prefix (str): The module prefix (e.g., 'projects.cosmos.diffusion.v1.config').
|
| 182 |
+
"""
|
| 183 |
+
for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
|
| 184 |
+
if skip_underscore and module_name.startswith("_"):
|
| 185 |
+
log.debug(f"Skipping module {module_name} as it starts with an underscore")
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
full_module_name = f"{prefix}.{module_name}"
|
| 189 |
+
log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
|
| 190 |
+
|
| 191 |
+
if full_module_name in sys.modules and reload:
|
| 192 |
+
importlib.reload(sys.modules[full_module_name])
|
| 193 |
+
else:
|
| 194 |
+
importlib.import_module(full_module_name)
|
| 195 |
+
|
| 196 |
+
if is_pkg:
|
| 197 |
+
sub_package_directory = os.path.join(directory, module_name)
|
| 198 |
+
import_modules_recursively(sub_package_directory, full_module_name)
|
| 199 |
+
|
| 200 |
+
for directory in package_directory:
|
| 201 |
+
import_modules_recursively(directory, package_path)
|
imaginaire/utils/device.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import pynvml
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Device:
|
| 23 |
+
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore
|
| 24 |
+
|
| 25 |
+
def __init__(self, device_idx: int):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
|
| 28 |
+
|
| 29 |
+
def get_name(self) -> str:
|
| 30 |
+
return pynvml.nvmlDeviceGetName(self.handle)
|
| 31 |
+
|
| 32 |
+
def get_cpu_affinity(self) -> list[int]:
|
| 33 |
+
affinity_string = ""
|
| 34 |
+
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
|
| 35 |
+
# assume nvml returns list of 64 bit ints
|
| 36 |
+
affinity_string = f"{j:064b}" + affinity_string
|
| 37 |
+
affinity_list = [int(x) for x in affinity_string]
|
| 38 |
+
affinity_list.reverse() # so core 0 is in 0th element of list
|
| 39 |
+
return [i for i, e in enumerate(affinity_list) if e != 0]
|
imaginaire/utils/distributed.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import collections
|
| 19 |
+
import collections.abc
|
| 20 |
+
import ctypes
|
| 21 |
+
import functools
|
| 22 |
+
import os
|
| 23 |
+
from collections.abc import Callable, Container
|
| 24 |
+
from contextlib import contextmanager
|
| 25 |
+
from datetime import timedelta
|
| 26 |
+
from typing import TYPE_CHECKING, Any
|
| 27 |
+
|
| 28 |
+
import pynvml
|
| 29 |
+
import torch
|
| 30 |
+
import torch.distributed as dist
|
| 31 |
+
from torch.distributed import get_process_group_ranks
|
| 32 |
+
|
| 33 |
+
from imaginaire.utils.device import Device
|
| 34 |
+
|
| 35 |
+
if dist.is_available():
|
| 36 |
+
from torch.distributed.distributed_c10d import _get_default_group
|
| 37 |
+
from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes
|
| 38 |
+
|
| 39 |
+
from imaginaire.utils import log
|
| 40 |
+
|
| 41 |
+
if TYPE_CHECKING:
|
| 42 |
+
from imaginaire.config import DDPConfig
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
from megatron.core import parallel_state
|
| 46 |
+
except ImportError:
|
| 47 |
+
print("Megatron-core is not installed.")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def init() -> int | None:
|
| 51 |
+
"""Initialize distributed training."""
|
| 52 |
+
if dist.is_initialized():
|
| 53 |
+
return torch.cuda.current_device()
|
| 54 |
+
|
| 55 |
+
# Set GPU affinity.
|
| 56 |
+
pynvml.nvmlInit()
|
| 57 |
+
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
| 58 |
+
try:
|
| 59 |
+
device = Device(local_rank)
|
| 60 |
+
os.sched_setaffinity(0, device.get_cpu_affinity())
|
| 61 |
+
except (OSError, pynvml.NVMLError) as e:
|
| 62 |
+
log.warning(f"Failed to set device affinity: {e}")
|
| 63 |
+
# Set up NCCL communication.
|
| 64 |
+
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
|
| 65 |
+
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
| 66 |
+
if dist.is_available():
|
| 67 |
+
torch.cuda.set_device(local_rank)
|
| 68 |
+
# Get the timeout value from environment variable
|
| 69 |
+
timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
|
| 70 |
+
# Convert the timeout to an integer (if it isn't already) and then to a timedelta
|
| 71 |
+
timeout_timedelta = timedelta(seconds=int(timeout_seconds))
|
| 72 |
+
dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
|
| 73 |
+
log.info(
|
| 74 |
+
f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
|
| 75 |
+
rank0_only=False,
|
| 76 |
+
)
|
| 77 |
+
# Increase the L2 fetch granularity for faster speed.
|
| 78 |
+
_libcudart = ctypes.CDLL("libcudart.so")
|
| 79 |
+
# Set device limit on the current device.
|
| 80 |
+
p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
|
| 81 |
+
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
|
| 82 |
+
_libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
|
| 83 |
+
log.info(f"Training with {get_world_size()} GPUs.")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_rank(group: dist.ProcessGroup | None = None) -> int:
|
| 87 |
+
"""Get the rank (GPU device) of the worker.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
rank (int): The rank of the worker.
|
| 91 |
+
"""
|
| 92 |
+
rank = 0
|
| 93 |
+
if dist.is_available() and dist.is_initialized():
|
| 94 |
+
rank = dist.get_rank(group)
|
| 95 |
+
return rank
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_world_size(group: dist.ProcessGroup | None = None) -> int:
|
| 99 |
+
"""Get world size. How many GPUs are available in this job.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
world_size (int): The total number of GPUs available in this job.
|
| 103 |
+
"""
|
| 104 |
+
world_size = 1
|
| 105 |
+
if dist.is_available() and dist.is_initialized():
|
| 106 |
+
world_size = dist.get_world_size(group)
|
| 107 |
+
return world_size
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def is_rank0() -> bool:
|
| 111 |
+
"""Check if current process is the master GPU.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
(bool): True if this function is called from the master GPU, else False.
|
| 115 |
+
"""
|
| 116 |
+
return get_rank() == 0
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def is_local_rank0() -> bool:
|
| 120 |
+
"""Check if current process is the local master GPU in the current node.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
(bool): True if this function is called from the local master GPU, else False.
|
| 124 |
+
"""
|
| 125 |
+
return torch.cuda.current_device() == 0
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def rank0_only(func: Callable) -> Callable:
|
| 129 |
+
"""Apply this function only to the master GPU.
|
| 130 |
+
|
| 131 |
+
Example usage:
|
| 132 |
+
@rank0_only
|
| 133 |
+
def func(x):
|
| 134 |
+
return x + 3
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
func (Callable): a function.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
(Callable): A function wrapper executing the function only on the master GPU.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
@functools.wraps(func)
|
| 144 |
+
def wrapper(*args, **kwargs):
|
| 145 |
+
if is_rank0():
|
| 146 |
+
return func(*args, **kwargs)
|
| 147 |
+
else:
|
| 148 |
+
return None
|
| 149 |
+
|
| 150 |
+
return wrapper
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def barrier() -> None:
|
| 154 |
+
"""Barrier for all GPUs."""
|
| 155 |
+
if dist.is_available() and dist.is_initialized():
|
| 156 |
+
dist.barrier()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def rank0_first(func: Callable) -> Callable:
|
| 160 |
+
"""run the function on rank 0 first, then on other ranks."""
|
| 161 |
+
|
| 162 |
+
@functools.wraps(func)
|
| 163 |
+
def wrapper(*args, **kwargs):
|
| 164 |
+
if is_rank0():
|
| 165 |
+
result = func(*args, **kwargs)
|
| 166 |
+
barrier()
|
| 167 |
+
if not is_rank0():
|
| 168 |
+
result = func(*args, **kwargs)
|
| 169 |
+
return result
|
| 170 |
+
|
| 171 |
+
return wrapper
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
|
| 175 |
+
"""Wraps the model to enable data parallalism for training across multiple GPU devices.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
config_ddp (DDPConfig): The data parallel config.
|
| 179 |
+
model (torch.nn.Module): The PyTorch module.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
|
| 183 |
+
if distributed environment is available, otherwise return the original model.
|
| 184 |
+
"""
|
| 185 |
+
if dist.is_available() and dist.is_initialized():
|
| 186 |
+
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
| 187 |
+
try:
|
| 188 |
+
ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
|
| 189 |
+
except Exception as e:
|
| 190 |
+
log.info(e)
|
| 191 |
+
log.info("parallel_state not initialized, treating all GPUs equally for DDP")
|
| 192 |
+
ddp_group = None
|
| 193 |
+
|
| 194 |
+
model = DistributedDataParallel(
|
| 195 |
+
model,
|
| 196 |
+
device_ids=[local_rank],
|
| 197 |
+
output_device=local_rank,
|
| 198 |
+
find_unused_parameters=config_ddp.find_unused_parameters,
|
| 199 |
+
static_graph=config_ddp.static_graph,
|
| 200 |
+
broadcast_buffers=config_ddp.broadcast_buffers,
|
| 201 |
+
process_group=ddp_group,
|
| 202 |
+
)
|
| 203 |
+
return model
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
|
| 207 |
+
"""This extends torch.nn.parallel.DistributedDataParallel with .training_step().
|
| 208 |
+
|
| 209 |
+
This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an ImaginaireModel such that
|
| 210 |
+
model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
|
| 211 |
+
model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
|
| 212 |
+
training_step), allowing us to preserve the function names and signatures.
|
| 213 |
+
"""
|
| 214 |
+
|
| 215 |
+
def __init__(self, model: torch.nn.Module, *args, **kwargs):
|
| 216 |
+
super().__init__(model, *args, **kwargs)
|
| 217 |
+
self.show_sync_grad_static_graph_warning = True
|
| 218 |
+
|
| 219 |
+
def training_step(self, *args, **kwargs) -> Any:
|
| 220 |
+
# Cache the original model.forward() method.
|
| 221 |
+
original_forward = self.module.forward
|
| 222 |
+
|
| 223 |
+
def wrapped_training_step(*_args, **_kwargs):
|
| 224 |
+
# Unpatch immediately before calling training_step() because itself may want to call the real forward.
|
| 225 |
+
self.module.forward = original_forward
|
| 226 |
+
# The actual .training_step().
|
| 227 |
+
return self.module.training_step(*_args, **_kwargs)
|
| 228 |
+
|
| 229 |
+
# Patch the original_module's forward so we can redirect the arguments back to the real method.
|
| 230 |
+
self.module.forward = wrapped_training_step
|
| 231 |
+
# Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
|
| 232 |
+
# Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed.
|
| 233 |
+
return self(*args, **kwargs)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
@contextmanager
|
| 237 |
+
def ddp_sync_grad(model, enabled):
|
| 238 |
+
r"""
|
| 239 |
+
Context manager to enable/disable gradient synchronizations across DDP processes for DDP model.
|
| 240 |
+
Modified from:
|
| 241 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
|
| 242 |
+
Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True.
|
| 243 |
+
|
| 244 |
+
Within this context, gradients will be accumulated on module
|
| 245 |
+
variables, which will later be synchronized in the first
|
| 246 |
+
forward-backward pass exiting the context.
|
| 247 |
+
|
| 248 |
+
.. warning::
|
| 249 |
+
The forward pass should be included inside the context manager, or
|
| 250 |
+
else gradients will still be synchronized.
|
| 251 |
+
"""
|
| 252 |
+
assert isinstance(model, torch.nn.Module)
|
| 253 |
+
if isinstance(model, DistributedDataParallel):
|
| 254 |
+
old_require_backward_grad_sync = model.require_backward_grad_sync
|
| 255 |
+
if model.static_graph and model.require_backward_grad_sync != enabled:
|
| 256 |
+
if model.show_sync_grad_static_graph_warning:
|
| 257 |
+
log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
|
| 258 |
+
model.show_sync_grad_static_graph_warning = False
|
| 259 |
+
else:
|
| 260 |
+
model.require_backward_grad_sync = enabled
|
| 261 |
+
try:
|
| 262 |
+
yield
|
| 263 |
+
finally:
|
| 264 |
+
if isinstance(model, DistributedDataParallel):
|
| 265 |
+
model.require_backward_grad_sync = old_require_backward_grad_sync
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
|
| 269 |
+
"""Aggregate the list of data batches from all devices and process the results.
|
| 270 |
+
|
| 271 |
+
This is used for gathering validation data batches with imaginaire.utils.dataloader.DistributedEvalSampler.
|
| 272 |
+
It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
|
| 273 |
+
in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
|
| 274 |
+
created before calling dis.all_gather().
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
|
| 278 |
+
leaf entries are tensors.
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
|
| 282 |
+
leaf entries are concatenated tensors.
|
| 283 |
+
"""
|
| 284 |
+
if isinstance(data_batches[0], torch.Tensor):
|
| 285 |
+
# Concatenate the local data batches.
|
| 286 |
+
data_concat = torch.cat(data_batches, dim=0) # type: ignore
|
| 287 |
+
# Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
|
| 288 |
+
max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
|
| 289 |
+
dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
|
| 290 |
+
if len(data_concat) < max_num_local_samples:
|
| 291 |
+
assert len(data_concat) + 1 == max_num_local_samples
|
| 292 |
+
dummy = torch.empty_like(data_concat[:1])
|
| 293 |
+
data_concat = torch.cat([data_concat, dummy], dim=0)
|
| 294 |
+
dummy_count = torch.tensor(1, device="cuda")
|
| 295 |
+
else:
|
| 296 |
+
dummy_count = torch.tensor(0, device="cuda")
|
| 297 |
+
# Get all concatenated batches from all ranks and concatenate again.
|
| 298 |
+
dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
|
| 299 |
+
data_concat = all_gather_tensor(data_concat.contiguous())
|
| 300 |
+
data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
|
| 301 |
+
# Remove the dummy samples.
|
| 302 |
+
if dummy_count > 0:
|
| 303 |
+
data_collate = data_collate[:-dummy_count]
|
| 304 |
+
elif isinstance(data_batches[0], collections.abc.Mapping):
|
| 305 |
+
data_collate = dict()
|
| 306 |
+
for key in data_batches[0].keys():
|
| 307 |
+
data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore
|
| 308 |
+
else:
|
| 309 |
+
raise TypeError
|
| 310 |
+
return data_collate
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@torch.no_grad()
|
| 314 |
+
def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
|
| 315 |
+
"""Gather the corresponding tensor from all GPU devices to a list.
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
tensor (torch.Tensor): Pytorch tensor.
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices.
|
| 322 |
+
"""
|
| 323 |
+
tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
|
| 324 |
+
dist.all_gather(tensor_list, tensor)
|
| 325 |
+
return tensor_list
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def broadcast(tensor, src, group=None, async_op=False):
|
| 329 |
+
world_size = get_world_size()
|
| 330 |
+
if world_size < 2:
|
| 331 |
+
return tensor
|
| 332 |
+
dist.broadcast(tensor, src=src, group=group, async_op=async_op)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def dist_reduce_tensor(tensor, rank=0, reduce="mean"):
|
| 336 |
+
r"""Reduce to rank 0"""
|
| 337 |
+
world_size = get_world_size()
|
| 338 |
+
if world_size < 2:
|
| 339 |
+
return tensor
|
| 340 |
+
with torch.no_grad():
|
| 341 |
+
dist.reduce(tensor, dst=rank)
|
| 342 |
+
if get_rank() == rank:
|
| 343 |
+
if reduce == "mean":
|
| 344 |
+
tensor /= world_size
|
| 345 |
+
elif reduce == "sum":
|
| 346 |
+
pass
|
| 347 |
+
else:
|
| 348 |
+
raise NotImplementedError
|
| 349 |
+
return tensor
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def sync_model_states(
|
| 353 |
+
model: torch.nn.Module,
|
| 354 |
+
process_group: dist.ProcessGroup | None = None,
|
| 355 |
+
src: int = 0,
|
| 356 |
+
params_and_buffers_to_ignore: Container[str] | None = None,
|
| 357 |
+
broadcast_buffers: bool = True,
|
| 358 |
+
):
|
| 359 |
+
"""
|
| 360 |
+
Modify based on DDP source code
|
| 361 |
+
Synchronizes the parameters and buffers of a model across different processes in a distributed setting.
|
| 362 |
+
|
| 363 |
+
This function ensures that all processes in the specified process group have the same initial parameters and
|
| 364 |
+
buffers from the source rank, typically rank 0. It is useful when different processes start with different model
|
| 365 |
+
states and a synchronization is required to ensure consistency across all ranks.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
model (nn.Module): The model whose parameters and buffers are to be synchronized.
|
| 369 |
+
process_group (dist.ProcessGroup, optional): The process group for communication. If None,
|
| 370 |
+
the default group is used. Defaults to None.
|
| 371 |
+
src (int, optional): The source rank from which parameters and buffers will be broadcasted.
|
| 372 |
+
Defaults to 0.
|
| 373 |
+
params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer
|
| 374 |
+
names to exclude from synchronization. Defaults to None, which means all parameters and buffers are
|
| 375 |
+
included.
|
| 376 |
+
broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True.
|
| 377 |
+
|
| 378 |
+
Side Effects:
|
| 379 |
+
This function modifies the state of the model in-place to synchronize it with the source rank's model state.
|
| 380 |
+
|
| 381 |
+
Raises:
|
| 382 |
+
RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised.
|
| 383 |
+
|
| 384 |
+
Examples:
|
| 385 |
+
>>> # downloading duplicated model weights from s3 in each rank and save network bandwidth
|
| 386 |
+
>>> # useful and save our time when model weights are huge
|
| 387 |
+
>>> if dist.get_rank == 0:
|
| 388 |
+
>>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path))
|
| 389 |
+
>>> dist.barrir()
|
| 390 |
+
>>> sync_model_states(model) # sync rank0 weights to other ranks
|
| 391 |
+
"""
|
| 392 |
+
if not dist.is_available() or not dist.is_initialized():
|
| 393 |
+
return
|
| 394 |
+
if process_group is None:
|
| 395 |
+
process_group = _get_default_group()
|
| 396 |
+
if not params_and_buffers_to_ignore:
|
| 397 |
+
params_and_buffers_to_ignore = set()
|
| 398 |
+
|
| 399 |
+
log.info(
|
| 400 |
+
f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}."
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# Build tuple of (module, parameter) for all parameters that require grads.
|
| 404 |
+
modules_and_parameters = [
|
| 405 |
+
(module, parameter)
|
| 406 |
+
for module_name, module in model.named_modules()
|
| 407 |
+
for parameter in [
|
| 408 |
+
param
|
| 409 |
+
# Note that we access module.named_parameters instead of
|
| 410 |
+
# parameters(module). parameters(module) is only needed in the
|
| 411 |
+
# single-process multi device case, where it accesses replicated
|
| 412 |
+
# parameters through _former_parameters.
|
| 413 |
+
for param_name, param in module.named_parameters(recurse=False)
|
| 414 |
+
if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
|
| 415 |
+
# if param.requires_grad
|
| 416 |
+
# and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
|
| 417 |
+
]
|
| 418 |
+
]
|
| 419 |
+
|
| 420 |
+
# Deduplicate any parameters that might be shared across child modules.
|
| 421 |
+
memo = set()
|
| 422 |
+
modules_and_parameters = [
|
| 423 |
+
# "p not in memo" is the deduplication check.
|
| 424 |
+
# "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
|
| 425 |
+
(m, p)
|
| 426 |
+
for m, p in modules_and_parameters
|
| 427 |
+
if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
|
| 428 |
+
]
|
| 429 |
+
|
| 430 |
+
# Build list of parameters.
|
| 431 |
+
parameters = [parameter for _, parameter in modules_and_parameters]
|
| 432 |
+
if len(parameters) == 0:
|
| 433 |
+
return
|
| 434 |
+
|
| 435 |
+
_verify_param_shape_across_processes(process_group, parameters)
|
| 436 |
+
|
| 437 |
+
_sync_module_states(
|
| 438 |
+
module=model,
|
| 439 |
+
process_group=process_group,
|
| 440 |
+
broadcast_bucket_size=(250 * 1024 * 1024),
|
| 441 |
+
src=src,
|
| 442 |
+
params_and_buffers_to_ignore=params_and_buffers_to_ignore,
|
| 443 |
+
broadcast_buffers=broadcast_buffers,
|
| 444 |
+
)
|
imaginaire/utils/easy_io/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
imaginaire/utils/easy_io/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (144 Bytes). View file
|
|
|
imaginaire/utils/easy_io/__pycache__/easy_io.cpython-310.pyc
ADDED
|
Binary file (28.6 kB). View file
|
|
|
imaginaire/utils/easy_io/__pycache__/file_client.cpython-310.pyc
ADDED
|
Binary file (14.9 kB). View file
|
|
|
imaginaire/utils/easy_io/backends/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
|
| 17 |
+
from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
|
| 18 |
+
from imaginaire.utils.easy_io.backends.local_backend import LocalBackend
|
| 19 |
+
from imaginaire.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
"BaseStorageBackend",
|
| 23 |
+
"HTTPBackend",
|
| 24 |
+
"LocalBackend",
|
| 25 |
+
"backends",
|
| 26 |
+
"prefix_to_backends",
|
| 27 |
+
"register_backend",
|
| 28 |
+
]
|
imaginaire/utils/easy_io/backends/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (596 Bytes). View file
|
|
|
imaginaire/utils/easy_io/backends/__pycache__/base_backend.cpython-310.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
imaginaire/utils/easy_io/backends/__pycache__/http_backend.cpython-310.pyc
ADDED
|
Binary file (2.9 kB). View file
|
|
|
imaginaire/utils/easy_io/backends/__pycache__/local_backend.cpython-310.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
imaginaire/utils/easy_io/backends/__pycache__/registry_utils.cpython-310.pyc
ADDED
|
Binary file (3.57 kB). View file
|
|
|