yujiwang0606 committed
Commit 230e1e3 · 1 Parent(s): 1b3e11b
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +6 -5
  2. app.py +131 -4
  3. imaginaire/.DS_Store +0 -0
  4. imaginaire/__init__.py +14 -0
  5. imaginaire/__pycache__/__init__.cpython-310.pyc +0 -0
  6. imaginaire/__pycache__/__init__.cpython-39.pyc +0 -0
  7. imaginaire/callbacks/__init__.py +14 -0
  8. imaginaire/callbacks/every_n.py +84 -0
  9. imaginaire/callbacks/manual_gc.py +49 -0
  10. imaginaire/config.py +410 -0
  11. imaginaire/lazy_config/__init__.py +73 -0
  12. imaginaire/lazy_config/__pycache__/__init__.cpython-310.pyc +0 -0
  13. imaginaire/lazy_config/__pycache__/file_io.cpython-310.pyc +0 -0
  14. imaginaire/lazy_config/__pycache__/instantiate.cpython-310.pyc +0 -0
  15. imaginaire/lazy_config/__pycache__/lazy.cpython-310.pyc +0 -0
  16. imaginaire/lazy_config/__pycache__/omegaconf_patch.cpython-310.pyc +0 -0
  17. imaginaire/lazy_config/__pycache__/registry.cpython-310.pyc +0 -0
  18. imaginaire/lazy_config/file_io.py +24 -0
  19. imaginaire/lazy_config/instantiate.py +119 -0
  20. imaginaire/lazy_config/lazy.py +442 -0
  21. imaginaire/lazy_config/omegaconf_patch.py +65 -0
  22. imaginaire/lazy_config/registry.py +74 -0
  23. imaginaire/model.py +137 -0
  24. imaginaire/trainer.py +322 -0
  25. imaginaire/utils/.DS_Store +0 -0
  26. imaginaire/utils/__init__.py +14 -0
  27. imaginaire/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  28. imaginaire/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  29. imaginaire/utils/__pycache__/device.cpython-310.pyc +0 -0
  30. imaginaire/utils/__pycache__/distributed.cpython-310.pyc +0 -0
  31. imaginaire/utils/__pycache__/io.cpython-310.pyc +0 -0
  32. imaginaire/utils/__pycache__/io.cpython-39.pyc +0 -0
  33. imaginaire/utils/__pycache__/log.cpython-310.pyc +0 -0
  34. imaginaire/utils/__pycache__/log.cpython-39.pyc +0 -0
  35. imaginaire/utils/__pycache__/misc.cpython-310.pyc +0 -0
  36. imaginaire/utils/callback.py +518 -0
  37. imaginaire/utils/checkpointer.py +282 -0
  38. imaginaire/utils/config_helper.py +201 -0
  39. imaginaire/utils/device.py +39 -0
  40. imaginaire/utils/distributed.py +444 -0
  41. imaginaire/utils/easy_io/__init__.py +14 -0
  42. imaginaire/utils/easy_io/__pycache__/__init__.cpython-310.pyc +0 -0
  43. imaginaire/utils/easy_io/__pycache__/easy_io.cpython-310.pyc +0 -0
  44. imaginaire/utils/easy_io/__pycache__/file_client.cpython-310.pyc +0 -0
  45. imaginaire/utils/easy_io/backends/__init__.py +28 -0
  46. imaginaire/utils/easy_io/backends/__pycache__/__init__.cpython-310.pyc +0 -0
  47. imaginaire/utils/easy_io/backends/__pycache__/base_backend.cpython-310.pyc +0 -0
  48. imaginaire/utils/easy_io/backends/__pycache__/http_backend.cpython-310.pyc +0 -0
  49. imaginaire/utils/easy_io/backends/__pycache__/local_backend.cpython-310.pyc +0 -0
  50. imaginaire/utils/easy_io/backends/__pycache__/registry_utils.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,14 @@
  ---
- title: RCM Wan
- emoji: 🔥
- colorFrom: indigo
- colorTo: blue
+ title: rCM-Wan
+ emoji: 🦀
+ colorFrom: pink
+ colorTo: red
  sdk: gradio
  sdk_version: 5.49.1
  app_file: app.py
  pinned: false
+ license: apache-2.0
  short_description: rCM model for Wan2.1
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ This demo uses the unofficial rCM models for Wan from [worstcoder/rcm-Wan](https://huggingface.co/worstcoder/rcm-Wan).
app.py CHANGED
@@ -1,7 +1,134 @@
+ import spaces
  import gradio as gr
+ import time
+ import os
+ import requests
+ from wan2pt1_t2v_rcm_infer import inference, prepare_models
+ from huggingface_hub import hf_hub_download
+ import random
+ from types import SimpleNamespace
+ import gc
+ import torch

- def greet(name):
-     return "Hello " + name + "!!"
+ import flash_attn
+ print("flash_attn version: ", flash_attn.__version__)

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ dit_path_1p3B = hf_hub_download(
+     repo_id="worstcoder/rcm-Wan",
+     filename="rCM_Wan2.1_T2V_1.3B_480p.pt",
+ )
+
+ dit_path_14B = hf_hub_download(
+     repo_id="worstcoder/rcm-Wan",
+     filename="rCM_Wan2.1_T2V_14B_480p.pt",
+ )
+
+ vae_path = hf_hub_download(
+     repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+     filename="Wan2.1_VAE.pth"
+ )
+
+ text_encoder_path = hf_hub_download(
+     repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+     filename="models_t5_umt5-xxl-enc-bf16.pth"
+ )
+
+ net_1p3B, net_14B, tokenizer, t5_encoder = prepare_models(dit_path_1p3B, dit_path_14B, vae_path, text_encoder_path)
+ print("Loaded models")
+ gc.collect()
+
+ def random_seed():
+     return random.randint(0, 2**32 - 1)
+
+ @spaces.GPU(duration=120)
+ def generate_videos(prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed):
+     if seed is None:
+         seed = random.randint(0, 2**32 - 1)
+
+     args = SimpleNamespace(
+         prompt=prompt,
+         model_size=model_size,
+         num_steps=num_steps,
+         num_samples=num_samples,
+         sigma_max=sigma_max,
+         num_frames=77,
+         resolution="480p",
+         aspect_ratio=aspect_ratio,
+         seed=seed,
+     )
+
+     with torch.no_grad():
+         video_list = inference(args, net_1p3B, net_14B, tokenizer, t5_encoder)
+
+     if aspect_ratio == "16:9":
+         return video_list, None
+     else:
+         return None, video_list
+
+ def update_num_samples(model_choice):
+     if model_choice == "rCM-Wan2.1-T2V-1.3B-480p":
+         options = [1, 2, 3, 4]
+     else:
+         options = [1, 2, 3]
+     return gr.Dropdown(choices=options, value=options[0], label="num_samples")
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## rCM model for Wan")
+
+     examples = [
+         ["A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."],
+         ["A close-up shot captures a steaming hot pot brimming with vegetables and dumplings, set on a rustic wooden table. The camera focuses on the bubbling broth as a woman, dressed in a light, patterned blouse, reaches in with chopsticks to lift a tender leaf of cabbage from the simmering mixture. Steam rises around her as she leans back slightly, her warm smile reflecting satisfaction and joy. Her movements are smooth and deliberate, showcasing her comfort and familiarity with the dining process. The background includes a small bowl of dipping sauce and a clay pot, adding to the cozy, communal dining atmosphere."],
+         ["A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond."]
+     ]
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             with gr.Row():
+                 prompt = gr.Textbox(label="Text prompt", placeholder="Text prompt for videos")
+                 model_size = gr.Radio(
+                     ["rCM-Wan2.1-T2V-1.3B-480p", "rCM-Wan2.1-T2V-14B-480p"],
+                     value="rCM-Wan2.1-T2V-1.3B-480p",
+                     label="Model"
+                 )
+
+             with gr.Row():
+                 num_samples = gr.Dropdown([1, 2, 3, 4], value=1, label="num_samples")
+                 aspect_ratio = gr.Radio(["16:9", "9:16"], value="16:9", label="aspect_ratio")
+                 sigma_max = gr.Dropdown([40, 80, 120, 200, 400, 800, 1600], value=80, label="sigma_max")
+
+             with gr.Row():
+                 num_steps = gr.Slider(1, 4, value=4, step=1, label="num_steps")
+                 seed = gr.Number(label="seed", value=random_seed(), interactive=True)
+
+             with gr.Row():
+                 regenerate_btn = gr.Button("New Seed")
+                 run_btn = gr.Button("Generate Videos")
+
+             with gr.Row():
+                 gr.Examples(
+                     examples,
+                     inputs=[prompt],
+                     label="Example prompts"
+                 )
+
+         with gr.Column(scale=1):
+             video_16_9 = gr.Video(label="Videos 16:9", width=832)
+             video_9_16 = gr.Video(label="Videos 9:16", width=480, visible=False)
+
+     def show_video(aspect):
+         if aspect == "16:9":
+             return gr.update(visible=True), gr.update(visible=False, value=None)
+         else:
+             return gr.update(visible=False, value=None), gr.update(visible=True)
+
+     model_size.change(fn=update_num_samples, inputs=model_size, outputs=num_samples)
+     aspect_ratio.change(show_video, inputs=aspect_ratio, outputs=[video_16_9, video_9_16])
+     regenerate_btn.click(fn=random_seed, outputs=seed)
+
+     run_btn.click(
+         fn=generate_videos,
+         inputs=[prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed],
+         outputs=[video_16_9, video_9_16],
+     )
+
+ demo.launch()
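The `generate_videos` endpoint above can also be driven without the browser UI. The following is a minimal sketch (not part of the commit) using `gradio_client`; the Space id and `api_name` are assumptions, so check the Space's "Use via API" panel for the exact endpoint name and argument order.

```python
# Hedged sketch: calling the Space remotely via gradio_client.
from gradio_client import Client

client = Client("yujiwang0606/rCM-Wan")  # hypothetical Space id
result = client.predict(
    "A corgi surfing a wave at sunset",  # prompt
    "rCM-Wan2.1-T2V-1.3B-480p",          # model_size
    1,                                   # num_samples
    "16:9",                              # aspect_ratio
    80,                                  # sigma_max
    4,                                   # num_steps
    1234,                                # seed
    api_name="/generate_videos",         # assumed endpoint name
)
print(result)  # paths to the generated video(s)
```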
imaginaire/.DS_Store ADDED
Binary file (6.15 kB)
 
imaginaire/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
imaginaire/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (130 Bytes)
 
imaginaire/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (145 Bytes)
 
imaginaire/callbacks/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
imaginaire/callbacks/every_n.py ADDED
@@ -0,0 +1,84 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import abstractmethod
+
+ import torch
+
+ from imaginaire.model import ImaginaireModel
+ from imaginaire.trainer import ImaginaireTrainer
+ from imaginaire.utils import distributed, log
+ from imaginaire.utils.callback import Callback
+
+
+ class EveryN(Callback):
+     def __init__(
+         self,
+         every_n: int | None = None,
+         step_size: int = 1,
+         barrier_after_run: bool = True,
+         run_at_start: bool = False,
+     ) -> None:
+         """Constructor for `EveryN`.
+
+         Args:
+             every_n (int): Frequency with which callback is run during training.
+             step_size (int): Size of iteration step count. Default 1.
+             barrier_after_run (bool): Whether to have a distributed barrier after each execution. Default True, to avoid timeouts.
+             run_at_start (bool): Whether to run at the beginning of training. Default False.
+         """
+         self.every_n = every_n
+         if self.every_n == 0:
+             log.warning(
+                 f"every_n is set to 0. Callback {self.__class__.__name__} will be invoked only once in the beginning of the training. Calls happens on_training_step_end will be skipped."
+             )
+
+         self.step_size = step_size
+         self.barrier_after_run = barrier_after_run
+         self.run_at_start = run_at_start
+
+     def on_training_step_end(
+         self,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int = 0,
+     ) -> None:
+         # every_n = 0 is a special case which means every_n_impl will be called only once in the beginning of the training
+         if self.every_n != 0:
+             trainer = self.trainer
+             global_step = iteration // self.step_size
+             should_run = (iteration == 1 and self.run_at_start) or (
+                 global_step % self.every_n == 0
+             )  # (self.every_n - 1)
+             if should_run:
+                 log.debug(f"Callback {self.__class__.__name__} fired on train_batch_end step {global_step}")
+                 self.every_n_impl(trainer, model, data_batch, output_batch, loss, iteration)
+                 log.debug(f"Callback {self.__class__.__name__} finished on train_batch_end step {global_step}")
+                 # add necessary barrier to avoid timeout
+                 if self.barrier_after_run:
+                     distributed.barrier()
+
+     @abstractmethod
+     def every_n_impl(
+         self,
+         trainer: ImaginaireTrainer,
+         model: ImaginaireModel,
+         data_batch: dict[str, torch.Tensor],
+         output_batch: dict[str, torch.Tensor],
+         loss: torch.Tensor,
+         iteration: int,
+     ) -> None: ...
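Concrete callbacks are expected to subclass `EveryN` and implement `every_n_impl`. The following is a hedged sketch (not part of the commit) of what such a subclass could look like; the class name and the logged message are made up for illustration, and `log.info` is assumed to exist alongside the `log.debug`/`log.warning` calls used above.

```python
# Hedged sketch: a hypothetical EveryN subclass that logs the loss periodically.
from imaginaire.callbacks.every_n import EveryN
from imaginaire.utils import log


class LogLossEveryN(EveryN):
    """Hypothetical callback that logs the training loss every `every_n` global steps."""

    def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
        # Only the loss and iteration are needed here; the rest of the signature is fixed by EveryN.
        del trainer, model, data_batch, output_batch
        log.info(f"iteration {iteration}: loss={loss.item():.4f}")  # assumes log.info exists
```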
imaginaire/callbacks/manual_gc.py ADDED
@@ -0,0 +1,49 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import gc
+
+ from imaginaire.callbacks.every_n import EveryN
+ from imaginaire.utils import log
+
+
+ class ManualGarbageCollection(EveryN):
+     """
+     Disable auto gc and manually trigger garbage collection every N iterations.
+     It is super useful for large scale training to reduce gpu sync time!
+     Can reach 50% speedup.
+
+     It is important to note that this callback only disables gc in the main process and keeps auto gc enabled in subprocesses.
+
+     We only disable gc after warm_up iterations to avoid disabling gc in subprocesses, such as the dataloader, which can cause OOM.
+     """
+
+     def __init__(self, *args, warm_up: int = 5, **kwargs):
+         kwargs["barrier_after_run"] = False
+         super().__init__(*args, **kwargs)
+
+         self.counter = 0
+         self.warm = warm_up
+
+     def every_n_impl(self, trainer, model, data_batch, output_batch, loss, iteration):
+         del trainer, model, data_batch, output_batch, loss
+         self.counter += 1
+         if self.counter < self.warm:
+             return
+         if self.counter == self.warm:
+             gc.disable()
+             log.critical("Garbage collection disabled")
+
+         gc.collect(1)
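Like the default callbacks in `imaginaire/config.py`, this one would be registered lazily with `LazyCall`. A minimal sketch follows (not part of the commit); the `every_n` value and the dict key are arbitrary examples, and how the dict is merged into `TrainerConfig.callbacks` is assumed.

```python
# Hedged sketch: registering ManualGarbageCollection the same way the default
# callbacks are declared in TrainerConfig (via LazyCall / LazyDict).
from imaginaire.callbacks.manual_gc import ManualGarbageCollection
from imaginaire.lazy_config import LazyCall as L

extra_callbacks = dict(
    manual_gc=L(ManualGarbageCollection)(every_n=200),  # trigger gc.collect(1) every 200 steps
)
```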
imaginaire/config.py ADDED
@@ -0,0 +1,410 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Training config system for Imaginaire4"""
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any, TypeVar
+
+ import attrs
+ import torch
+ import torch.utils.data
+
+ from imaginaire.model import ImaginaireModel
+
+ try:
+     from megatron.core import ModelParallelConfig
+
+     USE_MEGATRON = True
+ except ImportError:
+     USE_MEGATRON = False
+     print("Megatron-core is not installed.")
+
+ import builtins
+
+ from imaginaire.lazy_config import LazyCall as L
+ from imaginaire.lazy_config import LazyDict
+ from imaginaire.utils import callback, distributed
+ from imaginaire.utils.misc import Color
+
+ T = TypeVar("T")
+
+
+ def _is_attrs_instance(obj: object) -> bool:
+     """
+     Helper function to check if an object is an instance of an attrs-defined class.
+
+     Args:
+         obj: The object to check.
+
+     Returns:
+         bool: True if the object is an instance of an attrs-defined class, False otherwise.
+     """
+     return hasattr(obj, "__attrs_attrs__")
+
+
+ def make_freezable(cls: T) -> T:
+     """
+     A decorator that adds the capability to freeze instances of an attrs-defined class.
+
+     NOTE: This requires the wrapped attrs class to be defined with attrs.define(slots=False) because we need
+     to hack on a "_is_frozen" attribute.
+
+     This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
+     Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
+     any attrs-defined objects that are attributes of the class.
+
+     Usage:
+         @make_freezable
+         @attrs.define(slots=False)
+         class MyClass:
+             attribute1: int
+             attribute2: str
+
+         obj = MyClass(1, 'a')
+         obj.freeze()  # Freeze the instance
+         obj.attribute1 = 2  # Raises AttributeError
+
+     Args:
+         cls: The class to be decorated.
+
+     Returns:
+         The decorated class with added freezing capability.
+     """
+
+     if not hasattr(cls, "__dict__"):
+         raise TypeError(
+             "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
+             "class was defined with `@attrs.define(slots=False)`"
+         )
+
+     original_setattr = cls.__setattr__
+
+     def setattr_override(self, key, value) -> None:
+         """
+         Override __setattr__ to allow modifications during initialization
+         and prevent modifications once the instance is frozen.
+         """
+         if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
+             raise AttributeError("Cannot modify frozen instance")
+         original_setattr(self, key, value)  # type: ignore
+
+     cls.__setattr__ = setattr_override  # type: ignore
+
+     def freeze(self: object) -> None:
+         """
+         Freeze the instance and all its attrs-defined attributes.
+         """
+         for _, value in attrs.asdict(self, recurse=False).items():
+             if _is_attrs_instance(value) and hasattr(value, "freeze"):
+                 value.freeze()
+         self._is_frozen = True  # type: ignore
+
+     cls.freeze = freeze  # type: ignore
+
+     return cls
+
+
+ def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
+     """
+     Recursively pretty prints attrs objects with color.
+     """
+
+     assert attrs.has(obj.__class__)
+
+     lines: list[str] = []
+     for attribute in attrs.fields(obj.__class__):
+         value = getattr(obj, attribute.name)
+         if attrs.has(value.__class__):
+             if use_color:
+                 lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
+             else:
+                 lines.append(" " * indent + "* " + attribute.name + ":")
+             lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
+         else:
+             if use_color:
+                 lines.append(
+                     " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
+                 )
+             else:
+                 lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
+     return "\n".join(lines)
+
+
+ def pretty_print_overrides(overrides: list[str] | None = None, use_color: bool = False) -> str:
+     """
+     Pretty prints overrides.
+     """
+
+     lines: list[str] = []
+     lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
+     for override in overrides:
+         if override == "--":
+             continue
+         if override.startswith("~"):
+             attribute_name = override[1:]
+             attribute_value = None
+         else:
+             attribute_name, attribute_value = override.split("=")
+         if use_color:
+             lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
+         else:
+             lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))
+
+     return "\n".join(lines)
+
+
+ @make_freezable
+ @attrs.define(slots=False)  # slots=False is required for make_freezable. See the make_freezable notes for more info.
+ class ObjectStoreConfig:
+     # Whether the file I/O is from object store instead of local disk.
+     enabled: bool = False
+     # Path to the object store credentials file.
+     credentials: str = ""
+     # Object store bucket to read from / write to the objects.
+     bucket: str = ""
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class JobConfig:
+     # Project name.
+     project: str = ""
+     # Experiment name.
+     group: str = ""
+     # Run/job name.
+     name: str = ""
+
+     @property
+     def path(self) -> str:
+         return f"{self.project}/{self.group}/{self.name}"
+
+     @property
+     def path_local(self) -> str:
+         local_root = os.environ.get("IMAGINAIRE_OUTPUT_ROOT", "checkpoints")
+         return f"{local_root}/{self.path}"
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class EMAConfig:
+     # Enable tracking a set of exponential moving average (EMA) weights.
+     enabled: bool = False
+     # EMA decay rate.
+     beta: float = 0.9999
+     # Enable removing "_orig_mod-" from buffer names that is added by torch.compile
+     torch_compile_buffer_renaming: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class PowerEMAConfig:
+     # Enable tracking a set of exponential moving average (EMA) weights.
+     enabled: bool = False
+     # EDM2 paper EMA decay rate.
+     s: float = 0.1
+     # Enable removing "_orig_mod-" from buffer names that is added by torch.compile
+     torch_compile_buffer_renaming: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class DDPConfig:
+     # Traverse the computation graph to find parameters that don't receive gradients.
+     find_unused_parameters: bool = False
+     # Set to True if the computation graph does not change during the whole training loop.
+     static_graph: bool = True
+     # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere.
+     broadcast_buffers: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class CuDNNConfig:
+     # Set to True for better reproducibility of the results (only using deterministic cudnn functions).
+     deterministic: bool = False
+     # If set to True, cudnn will benchmark several algorithms and pick the fastest one.
+     benchmark: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class JITConfig:
+     # Enable exporting a JIT compiled model.
+     enabled: bool = False
+     # Input tensor shape, for example input.
+     input_shape: list[int] | None = None
+     # Device to compile onto.
+     device: str = "cuda"
+     # Data type to compile onto.
+     dtype: str = "bfloat16"
+     # Strict mode for PyTorch JIT.
+     strict: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class CheckpointConfig:
+     # possible checkpoint class
+     type: dict | None = None
+     # for dcp, whether to use async mode
+     dcp_async_mode_enabled: bool = False
+     # Save the checkpoint every N iterations.
+     save_iter: int = 999999999
+     # Path of model weights to resume the checkpoint from.
+     load_path: str = ""
+     # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
+     load_training_state: bool = False
+     # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored.
+     only_load_scheduler_state: bool = False
+     # Load state_dict to the models in strict mode.
+     strict_resume: bool = True
+     # Configs for JIT compiling EMA model.
+     jit: JITConfig = attrs.field(factory=JITConfig)
+     # Print detailed information during checkpoint saving/loading.
+     verbose: bool = True
+     # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"]
+     keys_not_to_resume: list[str] = []  # noqa: RUF008
+     # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer).
+     broadcast_via_filesystem: bool = False
+     load_ema_to_reg: bool = False
+     # In dcp planner, skip the weight shape check, load weights into the model even if the weight shape is different
+     dcp_allow_mismatched_size: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class NVTXConfig:
+     """Config for NVTX ranges used in the main training loop.
+
+     See tutorials/nanogpt for more details on how to integrate profiling into your model."""
+
+     # Enable the NVTX ranges.
+     enabled: bool = False
+     # Synchronize everything in each NVTX range.
+     cuda_synchronize: bool = False
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class Profiling:
+     enable_profiling: bool = False
+     enable_memory_snapshot: bool = False
+     profile_freq: int = 1
+     first_n_rank: int = 8  # -1 means all ranks, n means the first n ranks dump profiling info
+     record_shape: bool = True
+     profile_memory: bool = True
+     with_stack: bool = True
+     with_modules: bool = True
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class TrainerConfig:
+     from imaginaire.trainer import ImaginaireTrainer
+
+     type: builtins.type[ImaginaireTrainer] = ImaginaireTrainer
+     # Set the callback class.
+     # Defaults to the callbacks below.
+     callbacks: LazyDict[dict[str, callback.Callback]] = LazyDict(  # noqa: RUF009
+         dict(
+             ema=L(callback.EMAModelCallback)(),
+             progress_bar=L(callback.ProgressBarCallback)(),
+         )
+     )
+     # distributed parallelism strategy
+     distributed_parallelism: str = "ddp"
+     # Distributed data parallel configs.
+     ddp: DDPConfig = attrs.field(factory=DDPConfig)
+     # cuDNN configs.
+     cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
+     # Set the random seed.
+     seed: int = 0
+     # Gradient scaler arguments (for torch.amp.GradScaler).
+     grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
+     # Maximum number of iterations to train the model.
+     max_iter: int = 999999999
+     # Maximum number of iterations to validate the model. If None, validate on the entire dataset.
+     max_val_iter: int | None = None
+     # How often we log the training stats.
+     logging_iter: int = 100
+     # Whether we want to run the validation routines.
+     run_validation: bool = True
+     # How often we evaluate on the validation set.
+     validation_iter: int = 999999999
+     # Kill the process after N seconds since the last iteration (usually means dead job).
+     timeout_period: int = 999999999
+     # Tensor memory organization format.
+     memory_format: torch.memory_format = torch.preserve_format
+     # Gradient accumulation (update step every N iterations).
+     grad_accum_iter: int = 1
+     # Profiling config
+     profiling: Profiling = attrs.field(factory=Profiling)
+
+
+ @make_freezable
+ @attrs.define(slots=False)
+ class Config:
+     """Config for an imaginaire4 job.
+
+     See /README.md/Configuration System for more info.
+     """
+
+     # Model configs.
+     model: LazyDict[ImaginaireModel]
+     # Optimizer configs.
+     optimizer: LazyDict[torch.optim.Optimizer]
+     # Scheduler configs.
+     scheduler: LazyDict[torch.optim.lr_scheduler.LRScheduler]
+     # Training data configs.
+     dataloader_train: LazyDict[torch.utils.data.DataLoader]
+     # Validation data configs.
+     dataloader_val: LazyDict[torch.utils.data.DataLoader]
+
+     # Training job configs.
+     job: JobConfig = attrs.field(factory=JobConfig)
+
+     # Trainer configs.
+     trainer: TrainerConfig = attrs.field(factory=TrainerConfig)
+
+     if USE_MEGATRON:
+         # Megatron-Core configs
+         model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
+     else:
+         model_parallel: None = None
+
+     # Checkpointer configs.
+     checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)
+
+     def pretty_print(self, use_color: bool = False) -> str:
+         return _pretty_print_attrs_instance(self, 0, use_color)
+
+     def to_dict(self) -> dict[str, Any]:
+         return attrs.asdict(self)
+
+     def validate(self) -> None:
+         """Validate that the config has all required fields."""
+
+         # broadcast job.name across all ranks to make sure it is consistent
+         # otherwise, unaligned job names lead to unaligned paths to save checkpoints
+         job_name_tensor = torch.ByteTensor(bytearray(self.job.name, "utf-8")).cuda()
+         distributed.broadcast(job_name_tensor, 0)
+         self.job.name = job_name_tensor.cpu().numpy().tobytes().decode("utf-8")
+
+         assert self.job.project != ""
+         assert self.job.group != ""
+         assert self.job.name != ""
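The freezing behavior that `make_freezable` adds to every config class above can be demonstrated in isolation. Below is a minimal sketch (not part of the commit) using a throwaway attrs class instead of the real `Config`; only `make_freezable` itself comes from the file above.

```python
# Hedged sketch: what make_freezable + freeze() do at runtime.
import attrs

from imaginaire.config import make_freezable


@make_freezable
@attrs.define(slots=False)  # slots=False is required, as noted in the docstring above
class DemoConfig:
    lr: float = 1e-4
    name: str = "debug"


cfg = DemoConfig()
cfg.lr = 3e-4        # allowed: the instance is not frozen yet
cfg.freeze()
try:
    cfg.lr = 1e-3    # now rejected by setattr_override
except AttributeError as err:
    print(err)       # "Cannot modify frozen instance"
```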
imaginaire/lazy_config/__init__.py ADDED
@@ -0,0 +1,73 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ from omegaconf import OmegaConf
+
+ from imaginaire.lazy_config.instantiate import instantiate
+ from imaginaire.lazy_config.lazy import LazyCall, LazyConfig, LazyDict
+ from imaginaire.lazy_config.omegaconf_patch import to_object
+
+ OmegaConf.to_object = to_object
+
+ PLACEHOLDER = None
+
+ __all__ = ["PLACEHOLDER", "LazyCall", "LazyConfig", "LazyDict", "instantiate"]
+
+
+ DOC_BUILDING = os.getenv("_DOC_BUILDING", False)  # set in docs/conf.py
+
+
+ def fixup_module_metadata(module_name, namespace, keys=None):
+     """
+     Fix the __qualname__ of module members to be their exported api name, so
+     when they are referenced in docs, sphinx can find them. Reference:
+     https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
+     """
+     if not DOC_BUILDING:
+         return
+     seen_ids = set()
+
+     def fix_one(qualname, name, obj):
+         # avoid infinite recursion (relevant when using
+         # typing.Generic, for example)
+         if id(obj) in seen_ids:
+             return
+         seen_ids.add(id(obj))
+
+         mod = getattr(obj, "__module__", None)
+         if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
+             obj.__module__ = module_name
+             # Modules, unlike everything else in Python, put fully-qualified
+             # names into their __name__ attribute. We check for "." to avoid
+             # rewriting these.
+             if hasattr(obj, "__name__") and "." not in obj.__name__:
+                 obj.__name__ = name
+                 obj.__qualname__ = qualname
+             if isinstance(obj, type):
+                 for attr_name, attr_value in obj.__dict__.items():
+                     fix_one(objname + "." + attr_name, attr_name, attr_value)
+
+     if keys is None:
+         keys = namespace.keys()
+     for objname in keys:
+         if not objname.startswith("_"):
+             obj = namespace[objname]
+             fix_one(objname, objname, obj)
+
+
+ fixup_module_metadata(__name__, globals(), __all__)
+ del fixup_module_metadata
imaginaire/lazy_config/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.67 kB)
 
imaginaire/lazy_config/__pycache__/file_io.cpython-310.pyc ADDED
Binary file (387 Bytes)
 
imaginaire/lazy_config/__pycache__/instantiate.cpython-310.pyc ADDED
Binary file (3.22 kB)
 
imaginaire/lazy_config/__pycache__/lazy.cpython-310.pyc ADDED
Binary file (15 kB)
 
imaginaire/lazy_config/__pycache__/omegaconf_patch.cpython-310.pyc ADDED
Binary file (2.13 kB)
 
imaginaire/lazy_config/__pycache__/registry.cpython-310.pyc ADDED
Binary file (1.33 kB)
 
imaginaire/lazy_config/file_io.py ADDED
@@ -0,0 +1,24 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
+ from iopath.common.file_io import PathManager as PathManagerBase
+
+ __all__ = ["PathHandler", "PathManager"]
+
+
+ PathManager = PathManagerBase()
+ PathManager.register_handler(HTTPURLHandler())
+ PathManager.register_handler(OneDrivePathHandler())
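A minimal usage sketch (not part of the commit): with the handlers registered above, the shared `PathManager` instance can open local paths and, via `HTTPURLHandler`, plain `https://` URLs; the file path below is only an illustrative placeholder.

```python
# Hedged sketch: reading a file through the iopath PathManager configured above.
from imaginaire.lazy_config.file_io import PathManager

with PathManager.open("configs/example_config.py", "r") as f:  # hypothetical local file
    text = f.read()
```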
imaginaire/lazy_config/instantiate.py ADDED
@@ -0,0 +1,119 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections.abc as abc
+ import dataclasses
+ from typing import Any
+
+ import attrs
+
+ from imaginaire.lazy_config.registry import _convert_target_to_string, locate
+ from imaginaire.utils import log
+
+ __all__ = ["dump_dataclass", "instantiate"]
+
+
+ def is_dataclass_or_attrs(target):
+     return dataclasses.is_dataclass(target) or attrs.has(target)
+
+
+ def dump_dataclass(obj: Any):
+     """
+     Dump a dataclass recursively into a dict that can be later instantiated.
+
+     Args:
+         obj: a dataclass object
+
+     Returns:
+         dict
+     """
+     assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), (
+         "dump_dataclass() requires an instance of a dataclass."
+     )
+     ret = {"_target_": _convert_target_to_string(type(obj))}
+     for f in dataclasses.fields(obj):
+         v = getattr(obj, f.name)
+         if dataclasses.is_dataclass(v):
+             v = dump_dataclass(v)
+         if isinstance(v, (list, tuple)):
+             v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
+         ret[f.name] = v
+     return ret
+
+
+ def instantiate(cfg, *args, **kwargs):
+     """
+     Recursively instantiate objects defined in dictionaries by
+     "_target_" and arguments.
+
+     Args:
+         cfg: a dict-like object with "_target_" that defines the caller, and
+             other keys that define the arguments
+         args: Optional positional parameters pass-through.
+         kwargs: Optional named parameters pass-through.
+
+     Returns:
+         object instantiated by cfg
+     """
+     from omegaconf import DictConfig, ListConfig, OmegaConf
+
+     if isinstance(cfg, ListConfig):
+         lst = [instantiate(x) for x in cfg]
+         return ListConfig(lst, flags={"allow_objects": True})
+     if isinstance(cfg, list):
+         # Specialize for list, because many classes take
+         # list[objects] as arguments, such as ResNet, DatasetMapper
+         return [instantiate(x) for x in cfg]
+
+     # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
+     # instantiate it to the actual dataclass.
+     if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
+         return OmegaConf.to_object(cfg)
+
+     if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
+         # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
+         # but faster: https://github.com/facebookresearch/hydra/issues/1200
+         is_recursive = getattr(cfg, "_recursive_", True)
+         if is_recursive:
+             cfg = {k: instantiate(v) for k, v in cfg.items()}
+         else:
+             cfg = {k: v for k, v in cfg.items()}
+         # pop the _recursive_ key to avoid passing it as a parameter
+         if "_recursive_" in cfg:
+             cfg.pop("_recursive_")
+         cls = cfg.pop("_target_")
+         cls = instantiate(cls)
+
+         if isinstance(cls, str):
+             cls_name = cls
+             cls = locate(cls_name)
+             assert cls is not None, cls_name
+         else:
+             try:
+                 cls_name = cls.__module__ + "." + cls.__qualname__
+             except Exception:
+                 # target could be anything, so the above could fail
+                 cls_name = str(cls)
+         assert callable(cls), f"_target_ {cls} does not define a callable object"
+         try:
+             # override config with kwargs
+             instantiate_kwargs = {}
+             instantiate_kwargs.update(cfg)
+             instantiate_kwargs.update(kwargs)
+             return cls(*args, **instantiate_kwargs)
+         except TypeError:
+             log.error(f"Error when instantiating {cls_name}!")
+             raise
+     return cfg  # return as-is if don't know what to do
+
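The `instantiate` function above is the other half of the `LazyCall` pattern: a `LazyCall(...)` produces a `DictConfig` with a `_target_` key, and `instantiate` turns that description back into a live object. A minimal sketch (not part of the commit), using `torch.nn.Linear` purely as an illustrative target:

```python
# Hedged sketch: LazyCall -> edit -> instantiate round trip.
import torch.nn as nn

from imaginaire.lazy_config import LazyCall as L, instantiate

layer_cfg = L(nn.Linear)(in_features=16, out_features=32)  # a DictConfig with "_target_"
layer_cfg.out_features = 64                                # still just config; nothing built yet
layer = instantiate(layer_cfg)                             # now an actual nn.Linear(16, 64)
print(layer)
```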
imaginaire/lazy_config/lazy.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ast
17
+ import builtins
18
+ import collections.abc as abc
19
+ import importlib
20
+ import inspect
21
+ import logging
22
+ import os
23
+ import pickle
24
+ import uuid
25
+ from collections import OrderedDict
26
+ from contextlib import contextmanager
27
+ from copy import deepcopy
28
+ from dataclasses import is_dataclass
29
+ from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, cast
30
+
31
+ import attrs
32
+ import yaml
33
+ from omegaconf import DictConfig, ListConfig, OmegaConf
34
+
35
+ from imaginaire.utils import log
36
+
37
+ try:
38
+ import dill as dill_pickle
39
+ except ImportError:
40
+ dill_pickle = None
41
+
42
+ try:
43
+ import cloudpickle
44
+ except ImportError:
45
+ cloudpickle = None
46
+
47
+ from imaginaire.lazy_config.file_io import PathManager
48
+ from imaginaire.lazy_config.registry import _convert_target_to_string
49
+
50
+ __all__ = ["LazyCall", "LazyConfig", "LazyDict"]
51
+
52
+ T = TypeVar("T")
53
+
54
+
55
+ def sort_dict(d: dict[str, Any]) -> OrderedDict[str, Any]:
56
+ return OrderedDict(sorted(d.items(), key=lambda x: x[0]))
57
+
58
+
59
+ def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
60
+ return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
61
+
62
+
63
+ def sort_recursive(obj: dict[str, Any] | list[Any] | Any) -> OrderedDict[str, Any] | list[Any] | Any:
64
+ if isinstance(obj, dict):
65
+ return sort_dict({k: sort_recursive(v) for k, v in obj.items()})
66
+ elif isinstance(obj, list):
67
+ return [sort_recursive(item) for item in obj]
68
+ return obj
69
+
70
+
71
+ yaml.add_representer(OrderedDict, dict_representer)
72
+
73
+ OmegaConf.register_new_resolver("add", lambda *vals: sum(vals))
74
+ OmegaConf.register_new_resolver("subtract", lambda *vals: vals[0] - sum(vals[1:]))
75
+
76
+
77
+ def get_default_params(cls_or_func):
78
+ if callable(cls_or_func):
79
+ # inspect signature for function
80
+ signature = inspect.signature(cls_or_func)
81
+ else:
82
+ # inspect signature for class
83
+ signature = inspect.signature(cls_or_func.__init__)
84
+ params = signature.parameters
85
+ default_params = {
86
+ name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty
87
+ }
88
+ return default_params
89
+
90
+
91
+ if TYPE_CHECKING:
92
+ # Have `LazyDict[T]` behave as `T`, so that attribute access works. Ideally, it
93
+ # would be a subclass of `T`, but this doesn't seem to be possible in the type
94
+ # system yet.
95
+ LazyDict: TypeAlias = T
96
+ else:
97
+ LazyDict = DictConfig
98
+
99
+
100
+ class LazyCall(Generic[T]):
101
+ """
102
+ Wrap a callable so that when it's called, the call will not be executed,
103
+ but returns a dict that describes the call.
104
+
105
+ LazyCall object has to be called with only keyword arguments. Positional
106
+ arguments are not yet supported.
107
+
108
+ Examples:
109
+ ::
110
+ from detectron2.config import instantiate, LazyCall
111
+
112
+ layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
113
+ layer_cfg.out_channels = 64 # can edit it afterwards
114
+ layer = instantiate(layer_cfg)
115
+ """
116
+
117
+ def __init__(self, target: type[T]):
118
+ if not (callable(target) or isinstance(target, (str, abc.Mapping))):
119
+ raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
120
+ self._target = target
121
+
122
+ def __call__(self, **kwargs) -> LazyDict[T]:
123
+ if is_dataclass(self._target) or attrs.has(self._target):
124
+ # omegaconf object cannot hold dataclass type
125
+ # https://github.com/omry/omegaconf/issues/784
126
+ target = _convert_target_to_string(self._target)
127
+ else:
128
+ target = self._target
129
+ kwargs["_target_"] = target
130
+
131
+ _final_params = get_default_params(self._target)
132
+ _final_params.update(kwargs)
133
+
134
+ return cast(LazyDict[T], DictConfig(content=_final_params, flags={"allow_objects": True}))
135
+
136
+
137
+ def _visit_dict_config(cfg, func):
138
+ """
139
+ Apply func recursively to all DictConfig in cfg.
140
+ """
141
+ if isinstance(cfg, DictConfig):
142
+ func(cfg)
143
+ for v in cfg.values():
144
+ _visit_dict_config(v, func)
145
+ elif isinstance(cfg, ListConfig):
146
+ for v in cfg:
147
+ _visit_dict_config(v, func)
148
+
149
+
150
+ def _validate_py_syntax(filename):
151
+ # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
152
+ with PathManager.open(filename, "r") as f:
153
+ content = f.read()
154
+ try:
155
+ ast.parse(content)
156
+ except SyntaxError as e:
157
+ raise SyntaxError(f"Config file {filename} has syntax error!") from e
158
+
159
+
160
+ def _cast_to_config(obj):
161
+ # if given a dict, return DictConfig instead
162
+ if isinstance(obj, dict):
163
+ return DictConfig(obj, flags={"allow_objects": True})
164
+ return obj
165
+
166
+
167
+ _CFG_PACKAGE_NAME = "detectron2._cfg_loader"
168
+ """
169
+ A namespace to put all imported config into.
170
+ """
171
+
172
+
173
+ def _random_package_name(filename):
174
+ # generate a random package name when loading config files
175
+ return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
176
+
177
+
178
+ @contextmanager
179
+ def _patch_import():
180
+ """
181
+ Enhance relative import statements in config files, so that they:
182
+ 1. locate files purely based on relative location, regardless of packages.
183
+ e.g. you can import file without having __init__
184
+ 2. do not cache modules globally; modifications of module states has no side effect
185
+ 3. support other storage system through PathManager, so config files can be in the cloud
186
+ 4. imported dict are turned into omegaconf.DictConfig automatically
187
+ """
188
+ old_import = builtins.__import__
189
+
190
+ def find_relative_file(original_file, relative_import_path, level):
191
+ # NOTE: "from . import x" is not handled. Because then it's unclear
192
+ # if such import should produce `x` as a python module or DictConfig.
193
+ # This can be discussed further if needed.
194
+ relative_import_err = """
195
+ Relative import of directories is not allowed within config files.
196
+ Within a config file, relative import can only import other config files.
197
+ """.replace("\n", " ")
198
+ if not len(relative_import_path):
199
+ raise ImportError(relative_import_err)
200
+
201
+ cur_file = os.path.dirname(original_file)
202
+ for _ in range(level - 1):
203
+ cur_file = os.path.dirname(cur_file)
204
+ cur_name = relative_import_path.lstrip(".")
205
+ for part in cur_name.split("."):
206
+ cur_file = os.path.join(cur_file, part)
207
+ if not cur_file.endswith(".py"):
208
+ cur_file += ".py"
209
+ if not PathManager.isfile(cur_file):
210
+ cur_file_no_suffix = cur_file[: -len(".py")]
211
+ if PathManager.isdir(cur_file_no_suffix):
212
+ raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
213
+ else:
214
+ raise ImportError(
215
+ f"Cannot import name {relative_import_path} from {original_file}: {cur_file} does not exist."
216
+ )
217
+ return cur_file
218
+
219
+ def new_import(name, globals=None, locals=None, fromlist=(), level=0):
220
+ if (
221
+ # Only deal with relative imports inside config files
222
+ level != 0 and globals is not None and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
223
+ ):
224
+ cur_file = find_relative_file(globals["__file__"], name, level)
225
+ _validate_py_syntax(cur_file)
226
+ spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
227
+ module = importlib.util.module_from_spec(spec)
228
+ module.__file__ = cur_file
229
+ with PathManager.open(cur_file) as f:
230
+ content = f.read()
231
+ exec(compile(content, cur_file, "exec"), module.__dict__)
232
+ for name in fromlist: # turn imported dict into DictConfig automatically
233
+ val = _cast_to_config(module.__dict__[name])
234
+ module.__dict__[name] = val
235
+ return module
236
+ return old_import(name, globals, locals, fromlist=fromlist, level=level)
237
+
238
+ builtins.__import__ = new_import
239
+ yield new_import
240
+ builtins.__import__ = old_import
241
+
242
+
243
+ class LazyConfig:
244
+ """
245
+ Provide methods to save, load, and overrides an omegaconf config object
246
+ which may contain definition of lazily-constructed objects.
247
+ """
248
+
249
+ @staticmethod
250
+ def load_rel(filename: str, keys: None | str | tuple[str, ...] = None):
251
+ """
252
+ Similar to :meth:`load()`, but load path relative to the caller's
253
+ source file.
254
+
255
+ This has the same functionality as a relative import, except that this method
256
+ accepts filename as a string, so more characters are allowed in the filename.
257
+ """
258
+ caller_frame = inspect.stack()[1]
259
+ caller_fname = caller_frame[0].f_code.co_filename
260
+ assert caller_fname != "<string>", "load_rel Unable to find caller"
261
+ caller_dir = os.path.dirname(caller_fname)
262
+ filename = os.path.join(caller_dir, filename)
263
+ return LazyConfig.load(filename, keys)
264
+
265
+ @staticmethod
266
+ def load(filename: str, keys: None | str | tuple[str, ...] = None):
267
+ """
268
+ Load a config file.
269
+
270
+ Args:
271
+ filename: absolute path or relative path w.r.t. the current working directory
272
+ keys: keys to load and return. If not given, return all keys
273
+ (whose values are config objects) in a dict.
274
+ """
275
+ has_keys = keys is not None
276
+ filename = filename.replace("/./", "/") # redundant
277
+ if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
278
+ raise ValueError(f"Config file {filename} has to be a python or yaml file.")
279
+ if filename.endswith(".py"):
280
+ _validate_py_syntax(filename)
281
+
282
+ with _patch_import():
283
+ # Record the filename
284
+ module_namespace = {
285
+ "__file__": filename,
286
+ "__package__": _random_package_name(filename),
287
+ }
288
+ with PathManager.open(filename) as f:
289
+ content = f.read()
290
+ # Compile first with filename to:
291
+ # 1. make filename appears in stacktrace
292
+ # 2. make load_rel able to find its parent's (possibly remote) location
293
+ exec(compile(content, filename, "exec"), module_namespace)
294
+
295
+ ret = module_namespace
296
+ else:
297
+ with PathManager.open(filename) as f:
298
+ obj = yaml.unsafe_load(f)
299
+ ret = OmegaConf.create(obj, flags={"allow_objects": True})
300
+
301
+ if has_keys:
302
+ if isinstance(keys, str):
303
+ return _cast_to_config(ret[keys])
304
+ else:
305
+ return tuple(_cast_to_config(ret[a]) for a in keys)
306
+ else:
307
+ if filename.endswith(".py"):
308
+ # when not specified, only load those that are config objects
309
+ ret = DictConfig(
310
+ {
311
+ name: _cast_to_config(value)
312
+ for name, value in ret.items()
313
+ if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_")
314
+ },
315
+ flags={"allow_objects": True},
316
+ )
317
+ return ret
318
+
319
+ @staticmethod
320
+ def save_pkl(cfg, filename: str) -> str:
321
+ """
322
+ Saves a Config object to a file using pickle serialization. This method is typically used
323
+ when the configuration object contains complex objects, such as lambdas, that are not supported by
324
+ simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration
325
+ object before serialization to ensure that the original object remains unmodified.
326
+
327
+ Args:
328
+ cfg: A Config object to be serialized and saved.
329
+ filename: The path and name of the file where the configuration should be saved. The function
330
+ assumes the file extension indicates a pickle format (e.g., .pkl).
331
+
332
+ Returns:
333
+ str: The filename to which the configuration was saved. This can be used to verify the file location
334
+ or log the outcome.
335
+
336
+ Notes:
337
+ - The function logs a warning if the configuration is successfully saved using pickle.
338
+ - If saving fails, an error is logged with the exception details.
339
+ """
340
+ try:
341
+ cfg = deepcopy(cfg)
342
+ except Exception:
343
+ pass
344
+
345
+ try:
346
+ with PathManager.open(filename, "wb") as f:
347
+ pickle.dump(cfg, f)
348
+ log.warning(f"Config is saved using pickle at {filename}.")
349
+ except Exception as e:
350
+ log.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead")
351
+ if dill_pickle:
352
+ try:
353
+ with PathManager.open(filename, "wb") as f:
354
+ pickle.dump(dill_pickle.dumps(cfg, recurse=True), f)
355
+ log.warning(f"Config is saved using dill at {filename}.")
356
+ except Exception as e:
357
+ log.error(f"Failed to save config to {filename}: {e}.")
358
+ if cloudpickle:
359
+ try:
360
+ with PathManager.open(filename, "wb") as f:
361
+ pickle.dump(cloudpickle.dumps(cfg), f)
362
+ log.warning(f"Config is saved using cloudpickle at {filename}.")
363
+ except Exception as e:
364
+ log.error(f"Failed to save config to {filename}: {e}.")
365
+ else:
366
+ log.error("cloudpickle is not available. Cannot save the config.")
367
+ raise e
368
+
369
+ return filename
370
+
371
+ @staticmethod
372
+ def save_yaml(cfg, filename: str) -> str:
373
+ """
374
+ Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization.
375
+
376
+ Args:
377
+ cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types.
378
+ filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'.
379
+
380
+ Returns:
381
+ str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome.
382
+
383
+ Notes:
384
+ - The function logs a warning if the configuration is successfully saved using YAML.
385
+ - If saving fails, an error is logged with the exception details.
386
+ """
387
+ logger = logging.getLogger(__name__)
388
+ try:
389
+ cfg = deepcopy(cfg)
390
+ except Exception:
391
+ pass
392
+
393
+ # Define a function to check if an item is serializable to YAML
394
+ def is_serializable(item):
395
+ try:
396
+ OmegaConf.to_yaml(item)
397
+ return True
398
+ except Exception as e:
399
+ return False
400
+
401
+ # Function to convert unserializable items to strings
402
+ def serialize_config(config):
403
+ if isinstance(config, DictConfig):
404
+ for key, value in config.items():
405
+ if isinstance(value, (DictConfig, ListConfig)):
406
+ try:
407
+ if "_target_" in value:
408
+ default_params = get_default_params(value["_target_"])
409
+ for default_key, default_v in default_params.items():
410
+ if default_key not in value:
411
+ value[default_key] = default_v
412
+ except Exception as e:
413
+ log.error(f"Failed to add default argument values: {e}")
414
+
415
+ serialize_config(value)
416
+ else:
417
+ if not is_serializable(value) and value is not None:
418
+ config[key] = str(value)
419
+ elif isinstance(config, ListConfig):
420
+ for i, item in enumerate(config):
421
+ if isinstance(item, (DictConfig, ListConfig)):
422
+ serialize_config(item)
423
+ else:
424
+ if not is_serializable(item) and item is not None:
425
+ config[i] = str(item)
426
+ else:
427
+ raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
428
+ return config
429
+
430
+ # Convert Config object to a DictConfig object.
431
+ config_dict = attrs.asdict(cfg)
432
+ config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
433
+
434
+ # Serialize the DictConfig object by converting non-serializable objects to strings.
435
+ config_omegaconf = serialize_config(config_omegaconf)
436
+
437
+ config_dict: dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)
438
+ sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict)
439
+ with open(filename, "w") as f:
440
+ yaml.dump(sorted_config, f, default_flow_style=False)
441
+ log.warning(f"Config is saved using omegaconf at {filename}.")
442
+ return filename
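As a quick orientation (not part of the diff itself), the two save helpers above are typically used together; a minimal sketch, assuming only that `config` is a loaded Config object and `job_dir` is a writable directory:

from imaginaire.lazy_config import LazyConfig

def snapshot_config(config, job_dir: str) -> None:
    # Pickle preserves complex values such as lambdas, at the cost of readability.
    LazyConfig.save_pkl(config, f"{job_dir}/config.pkl")
    # YAML is human-readable; unserializable values are converted to strings.
    LazyConfig.save_yaml(config, f"{job_dir}/config.yaml")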
imaginaire/lazy_config/omegaconf_patch.py ADDED
@@ -0,0 +1,65 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any
17
+
18
+ from omegaconf import OmegaConf
19
+ from omegaconf.base import DictKeyType, SCMode
20
+ from omegaconf.dictconfig import DictConfig # pragma: no cover
21
+
22
+
23
+ def to_object(cfg: Any) -> dict[DictKeyType, Any] | list[Any] | None | str | Any:
24
+ """
25
+ Converts an OmegaConf configuration object to a native Python container (dict or list), unless
26
+ the configuration is specifically created by LazyCall, in which case the original configuration
27
+ is returned directly.
28
+
29
+ This function serves as a modification of the original `to_object` method from OmegaConf,
30
+ preventing DictConfig objects created by LazyCall from being automatically converted to Python
31
+ dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
32
+ structure and behavior.
33
+
34
+ Differences from OmegaConf's original `to_object`:
35
+ - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.
36
+
37
+ Reference:
38
+ - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595
39
+
40
+ Args:
41
+ cfg (Any): The OmegaConf configuration object to convert.
42
+
43
+ Returns:
44
+ Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
45
+ `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.
46
+
47
+ Examples:
48
+ >>> cfg = DictConfig({"key": "value", "_target_": "Model"})
49
+ >>> to_object(cfg)
50
+ DictConfig({"key": "value", "_target_": "Model"})
51
+
52
+ >>> cfg = DictConfig({"list": [1, 2, 3]})
53
+ >>> to_object(cfg)
54
+ {'list': [1, 2, 3]}
55
+ """
56
+ if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
57
+ return cfg
58
+
59
+ return OmegaConf.to_container(
60
+ cfg=cfg,
61
+ resolve=True,
62
+ throw_on_missing=True,
63
+ enum_to_str=False,
64
+ structured_config_mode=SCMode.INSTANTIATE,
65
+ )
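A brief usage sketch of the behavior described above (a hedged illustration, assuming the module path `imaginaire.lazy_config.omegaconf_patch`): configs carrying `_target_` pass through untouched, everything else is converted to native Python containers.

from omegaconf import DictConfig

from imaginaire.lazy_config.omegaconf_patch import to_object

lazy_cfg = DictConfig({"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 4})
plain_cfg = DictConfig({"lr": 0.1, "betas": [0.9, 0.999]})

assert to_object(lazy_cfg) is lazy_cfg  # kept as a DictConfig for lazy instantiation
assert to_object(plain_cfg) == {"lr": 0.1, "betas": [0.9, 0.999]}  # converted to plain dict/list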
imaginaire/lazy_config/registry.py ADDED
@@ -0,0 +1,74 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import pydoc
17
+ from typing import Any
18
+
19
+ from fvcore.common.registry import Registry # for backward compatibility.
20
+
21
+ """
22
+ ``Registry`` and `locate` provide ways to map a string (typically found
23
+ in config files) to callable objects.
24
+ """
25
+
26
+ __all__ = ["Registry", "locate"]
27
+
28
+
29
+ def _convert_target_to_string(t: Any) -> str:
30
+ """
31
+ Inverse of ``locate()``.
32
+
33
+ Args:
34
+ t: any object with ``__module__`` and ``__qualname__``
35
+ """
36
+ module, qualname = t.__module__, t.__qualname__
37
+
38
+ # Compress the path to this object, e.g. ``module.submodule._impl.class``
39
+ # may become ``module.submodule.class``, if the later also resolves to the same
40
+ # object. This simplifies the string, and also is less affected by moving the
41
+ # class implementation.
42
+ module_parts = module.split(".")
43
+ for k in range(1, len(module_parts)):
44
+ prefix = ".".join(module_parts[:k])
45
+ candidate = f"{prefix}.{qualname}"
46
+ try:
47
+ if locate(candidate) is t:
48
+ return candidate
49
+ except ImportError:
50
+ pass
51
+ return f"{module}.{qualname}"
52
+
53
+
54
+ def locate(name: str) -> Any:
55
+ """
56
+ Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
57
+ such as "module.submodule.class_name".
58
+
59
+ Raise Exception if it cannot be found.
60
+ """
61
+ obj = pydoc.locate(name)
62
+
63
+ # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
64
+ # by pydoc.locate. Try a private function from hydra.
65
+ if obj is None:
66
+ try:
67
+ # from hydra.utils import get_method - will print many errors
68
+ from hydra.utils import _locate
69
+ except ImportError as e:
70
+ raise ImportError(f"Cannot dynamically locate object {name}!") from e
71
+ else:
72
+ obj = _locate(name) # it raises if fails
73
+
74
+ return obj
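For illustration only (assuming torch is installed), `locate` turns the dotted strings found in configs back into importable objects:

import torch

from imaginaire.lazy_config.registry import locate

optimizer_cls = locate("torch.optim.AdamW")  # resolve the dotted path to the class object
assert optimizer_cls is torch.optim.AdamW

layer = torch.nn.Linear(8, 4)
optimizer = optimizer_cls(layer.parameters(), lr=1e-4)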
imaginaire/model.py ADDED
@@ -0,0 +1,137 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any
17
+
18
+ import torch
19
+
20
+ from imaginaire.lazy_config import LazyDict, instantiate
21
+
22
+
23
+ class ImaginaireModel(torch.nn.Module):
24
+ """The base model class of Imaginaire. It is inherited from torch.nn.Module.
25
+
26
+ All models in Imaginaire should inherit ImaginaireModel. It should include the implementations for all the
27
+ computation graphs. All inheriting child classes should implement the following methods:
28
+ - training_step(): The training step of the model, including the loss computation.
29
+ - validation_step(): The validation step of the model, including the loss computation.
30
+ - forward(): The computation graph for model inference.
31
+ The following methods have default implementations in ImaginaireModel:
32
+ - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
33
+ """
34
+
35
+ def __init__(self) -> None:
36
+ super().__init__()
37
+
38
+ def init_optimizer_scheduler(
39
+ self,
40
+ optimizer_config: LazyDict[torch.optim.Optimizer],
41
+ scheduler_config: LazyDict[torch.optim.lr_scheduler.LRScheduler],
42
+ ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
43
+ """Creates the optimizer and scheduler for the model.
44
+
45
+ Args:
46
+ optimizer_config (LazyDict): Config for building the optimizer.
+ scheduler_config (LazyDict): Config for building the learning-rate scheduler.
47
+
48
+ Returns:
49
+ optimizer (torch.optim.Optimizer): The model optimizer.
50
+ scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
51
+ """
52
+ optimizer_config.params = self.parameters()
53
+ optimizer = instantiate(optimizer_config)
54
+ scheduler_config.optimizer = optimizer
55
+ scheduler = instantiate(scheduler_config)
56
+ return optimizer, scheduler
57
+
58
+ def training_step(
59
+ self, data_batch: dict[str, torch.Tensor], iteration: int
60
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
61
+ """The training step of the model, including the loss computation.
62
+
63
+ Args:
64
+ data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
65
+ iteration (int): Current iteration number.
66
+
67
+ Returns:
68
+ output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
69
+ loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
70
+ """
71
+ raise NotImplementedError
72
+
73
+ @torch.no_grad()
74
+ def validation_step(
75
+ self, data_batch: dict[str, torch.Tensor], iteration: int
76
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
77
+ """The validation step of the model, including the loss computation.
78
+
79
+ Args:
80
+ data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
81
+ iteration (int): Current iteration number.
82
+
83
+ Returns:
84
+ output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
85
+ loss (torch.Tensor): The total loss (weighted sum of various losses).
86
+ """
87
+ raise NotImplementedError
88
+
89
+ @torch.inference_mode()
90
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
91
+ """The computation graph for model inference.
92
+
93
+ Args:
94
+ *args: Whatever you decide to pass into the forward method.
95
+ **kwargs: Keyword arguments are also possible.
96
+
97
+ Return:
98
+ Your model's output.
99
+ """
100
+ raise NotImplementedError
101
+
102
+ def on_model_init_start(self, set_barrier=False) -> None:
103
+ return
104
+
105
+ def on_model_init_end(self, set_barrier=False) -> None:
106
+ return
107
+
108
+ def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
109
+ """The model preparation before the training is launched
110
+
111
+ Args:
112
+ memory_format (torch.memory_format): Memory format of the model.
113
+ """
114
+ pass
115
+
116
+ def on_before_zero_grad(
117
+ self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
118
+ ) -> None:
119
+ """Hook before zero_grad() is called.
120
+
121
+ Args:
122
+ optimizer (torch.optim.Optimizer): The model optimizer.
123
+ scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
124
+ iteration (int): Current iteration number.
125
+ """
126
+ pass
127
+
128
+ def on_after_backward(self, iteration: int = 0) -> None:
129
+ """Hook after loss.backward() is called.
130
+
131
+ This method is called immediately after the backward pass, allowing for custom operations
132
+ or modifications to be performed on the gradients before the optimizer step.
133
+
134
+ Args:
135
+ iteration (int): Current iteration number.
136
+ """
137
+ pass
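To make the contract above concrete, here is a minimal toy subclass (illustrative only; the "input"/"target" batch keys are assumptions, not part of the codebase):

import torch
import torch.nn.functional as F

from imaginaire.model import ImaginaireModel

class ToyRegressionModel(ImaginaireModel):
    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(16, 1)

    def training_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        pred = self.net(data_batch["input"])
        loss = F.mse_loss(pred, data_batch["target"])
        return {"pred": pred}, loss  # (auxiliary outputs, total loss) as required by the base class

    @torch.no_grad()
    def validation_step(self, data_batch: dict[str, torch.Tensor], iteration: int):
        return self.training_step(data_batch, iteration)

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)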
imaginaire/trainer.py ADDED
@@ -0,0 +1,322 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import inspect
18
+ import os
19
+ import signal
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ import torch.utils.data
24
+
25
+ from imaginaire.utils.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
26
+
27
+ try:
28
+ from megatron.core import parallel_state
29
+
30
+ USE_MEGATRON = True
31
+ except ImportError:
32
+ USE_MEGATRON = False
33
+ print("Megatron-core is not installed.")
34
+
35
+
36
+ from imaginaire.lazy_config import LazyConfig, instantiate
37
+ from imaginaire.model import ImaginaireModel
38
+ from imaginaire.utils import callback, distributed, log, misc
39
+ from imaginaire.utils.checkpointer import Checkpointer
40
+
41
+
42
+ class ImaginaireTrainer:
43
+ """The base trainer class of Imaginaire.
44
+
45
+ All trainers in Imaginaire should inherit ImaginaireTrainer. It contains the basic functionality for model training
46
+ (particularly suited for large-scale training), including data parallelism (DDP/FSDP), model weight averaging (EMA),
47
+ mixed-precision training (fp16/bf16).
48
+
49
+ Attributes:
50
+ checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
51
+ training_timer (misc.Timer): Timer object to time code blocks and functions.
52
+ """
53
+
54
+ def __init__(self, config):
55
+ """Constructor of the trainer.
56
+
57
+ Args:
58
+ config (Config): The config object for the Imaginaire codebase.
59
+ """
60
+ super().__init__()
61
+ self.config = config
62
+ # Set up the distributed computing environment.
63
+ with misc.timer("init_distributed"):
64
+ distributed.init()
65
+ # Set up parallel states.
66
+ if hasattr(config.model, "context_parallel_size"):
67
+ if config.model_parallel.context_parallel_size > 1:
68
+ raise ValueError(
69
+ "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
70
+ "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
71
+ )
72
+ else:
73
+ log.critical(
74
+ "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
75
+ )
76
+ config.model_parallel.context_parallel_size = config.model.context_parallel_size
77
+ if USE_MEGATRON:
78
+ if (
79
+ "create_gloo_process_groups"
80
+ in inspect.signature(parallel_state.initialize_model_parallel).parameters
81
+ ):
82
+ parallel_state.initialize_model_parallel(
83
+ pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
84
+ tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
85
+ context_parallel_size=config.model_parallel.context_parallel_size,
86
+ create_gloo_process_groups=False,
87
+ )
88
+ else:
89
+ parallel_state.initialize_model_parallel(
90
+ pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
91
+ tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
92
+ context_parallel_size=config.model_parallel.context_parallel_size,
93
+ )
94
+ # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
95
+ # It is not part of the original `parallel_state` API, so we need to set it manually.
96
+ parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
97
+ if parallel_state.sequence_parallel:
98
+ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
99
+
100
+ # Create the local job directory, save the config file, and pipe to a local log.
101
+ if distributed.is_rank0():
102
+ os.makedirs(config.job.path_local, exist_ok=True)
103
+ # Save the config as .pkl for reproducibility.
104
+ LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
105
+ # Save the config as .yaml for reading or parsing experiment hyperparameters.
106
+ LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
107
+ dist.barrier()
108
+ log.init_loguru_file(f"{config.job.path_local}/stdout.log")
109
+ if distributed.is_rank0():
110
+ # Print important environment variables and the effective config.
111
+ log.info("Config:\n" + config.pretty_print(use_color=True))
112
+ misc.print_environ_variables(["TORCH_HOME", "IMAGINAIRE_OUTPUT_ROOT"])
113
+ # Set the random seed. If multi-GPU, different ranks are set with different seeds.
114
+ misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
115
+ # Initialize cuDNN.
116
+ torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
117
+ torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
118
+ # Floating-point precision settings.
119
+ torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
120
+ # Initialize the callback functions.
121
+ self.callbacks = callback.CallBackGroup(config=config, trainer=self)
122
+ # Initialize the model checkpointer.
123
+ if config.checkpoint.type is None:
124
+ self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
125
+ else:
126
+ self.checkpointer: Checkpointer = instantiate(
127
+ config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
128
+ )
129
+ # Initialize the timer for speed benchmarking.
130
+ self.training_timer = misc.TrainingTimer()
131
+ # Send a TimeoutError if a training step takes over timeout_period seconds.
132
+ signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore
133
+
134
+ def train(
135
+ self,
136
+ model: ImaginaireModel,
137
+ dataloader_train: torch.utils.data.DataLoader,
138
+ dataloader_val: torch.utils.data.DataLoader,
139
+ ) -> None:
140
+ """The training function.
141
+
142
+ Args:
143
+ model (ImaginaireModel): The PyTorch model.
144
+ dataloader_train (torch.utils.data.DataLoader): The training data loader.
145
+ dataloader_val (torch.utils.data.DataLoader): The validation data loader.
146
+ """
147
+ # Leaving this for backward compatibility for now, but we can think about moving this to model.on_train_start for all models.
148
+ model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore
149
+ model.on_train_start(self.config.trainer.memory_format)
150
+
151
+ # Initialize the optimizer, scheduler, and grad_scaler.
152
+ self.callbacks.on_optimizer_init_start()
153
+ optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
154
+ grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
155
+ self.callbacks.on_optimizer_init_end()
156
+ # Load the model checkpoint and get the starting iteration number.
157
+ iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
158
+ grad_accum_iter = 0
159
+ log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
160
+ if self.config.trainer.distributed_parallelism == "ddp":
161
+ # Create a DDP model wrapper.
162
+ model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
163
+ elif self.config.trainer.distributed_parallelism == "fsdp":
164
+ model_ddp = model
165
+ else:
166
+ raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
167
+ log.info("Starting training...")
168
+ self.callbacks.on_train_start(model, iteration=iteration)
169
+ # Initial validation.
170
+ if self.config.trainer.run_validation and iteration == 0:
171
+ self.validate(model, dataloader_val, iteration=iteration)
172
+ log.info("Initial validation done.")
173
+ _end_training = False
174
+ with (
175
+ maybe_enable_profiling(self.config, global_step=iteration) as torch_profiler,
176
+ maybe_enable_memory_snapshot(self.config, global_step=iteration) as memory_profiler,
177
+ ):
178
+ while True:
179
+ dataloader_train_iter = iter(dataloader_train)
180
+ while True:
181
+ self.callbacks.on_before_dataloading(iteration)
182
+ try:
183
+ with self.training_timer("dataloader_train"):
184
+ data_batch = next(dataloader_train_iter)
185
+ except StopIteration:
186
+ break
187
+ finally:
188
+ self.callbacks.on_after_dataloading(iteration)
189
+ # If max_iter is reached, exit the training loop.
190
+ if iteration >= self.config.trainer.max_iter:
191
+ _end_training = True
192
+ break
193
+ # Move all tensors in the data batch to GPU device.
194
+ data_batch = misc.to(data_batch, device="cuda")
195
+ # The actual training step.
196
+ self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
197
+ self.callbacks.on_training_step_batch_start(model, data_batch, iteration=iteration)
198
+ if not model.training:
199
+ model_ddp.train()
200
+ assert model_ddp.training, "model_ddp is not in training mode."
201
+ assert model.training, "model is not in training mode."
202
+ output_batch, loss, grad_accum_iter = self.training_step(
203
+ model_ddp,
204
+ optimizer,
205
+ scheduler,
206
+ grad_scaler,
207
+ data_batch,
208
+ iteration=iteration,
209
+ grad_accum_iter=grad_accum_iter,
210
+ )
211
+ self.callbacks.on_training_step_batch_end(
212
+ model, data_batch, output_batch, loss, iteration=iteration
213
+ )
214
+ # If the gradients are still being accumulated, continue to load the next training batch.
215
+ if grad_accum_iter != 0:
216
+ continue
217
+ # Do the following when an actual optimizer (update) step has been made.
218
+ iteration += 1
219
+ # Save checkpoint.
220
+ if iteration % self.config.checkpoint.save_iter == 0:
221
+ self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
222
+ self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
223
+ # Validation.
224
+ if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
225
+ self.validate(model, dataloader_val, iteration=iteration)
226
+ # This iteration is successful; reset the timeout signal.
227
+ signal.alarm(self.config.trainer.timeout_period)
228
+ if torch_profiler:
229
+ torch_profiler.step()
230
+ if memory_profiler:
231
+ memory_profiler.step()
232
+ if _end_training:
233
+ break
234
+ log.success("Done with training.")
235
+ if iteration % self.config.checkpoint.save_iter != 0:
236
+ self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
237
+ self.callbacks.on_train_end(model, iteration=iteration)
238
+ self.checkpointer.finalize()
239
+ distributed.barrier()
240
+ self.callbacks.on_app_end()
241
+
242
+ def training_step(
243
+ self,
244
+ model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
245
+ optimizer: torch.optim.Optimizer,
246
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
247
+ grad_scaler: torch.amp.GradScaler,
248
+ data: dict[str, torch.Tensor],
249
+ iteration: int = 0,
250
+ grad_accum_iter: int = 0,
251
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
252
+ """The training step.
253
+
254
+ Args:
255
+ model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or the bare
256
+ module, depending on whether distributed training is enabled or not.
257
+ optimizer (torch.optim.Optimizer): The model optimizer.
258
+ scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
259
+ grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
260
+ data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
261
+ iteration (int): Current iteration number.
262
+ grad_accum_iter (int): Number of gradient accumulation iterations.
263
+
264
+ Returns:
265
+ output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
266
+ loss (torch.Tensor): The total loss of the training data batch.
+ grad_accum_iter (int): The updated gradient accumulation counter (reset to 0 after an optimizer step).
267
+ """
268
+ # Only let DDP sync gradient at the last iteration of the gradient accumulation window
269
+ with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
270
+ self.callbacks.on_before_forward(iteration=iteration)
271
+ with self.training_timer("forward"):
272
+ output_batch, loss = model_ddp.training_step(data, iteration)
273
+ self.callbacks.on_after_forward(iteration=iteration)
274
+ self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
275
+ with self.training_timer("backward"):
276
+ loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
277
+ loss_scaled.backward()
278
+ if self.config.trainer.distributed_parallelism == "ddp":
279
+ model_ddp.module.on_after_backward()
280
+ else:
281
+ model_ddp.on_after_backward()
282
+ self.callbacks.on_after_backward(model_ddp, iteration=iteration)
283
+ grad_accum_iter += 1
284
+ if grad_accum_iter == self.config.trainer.grad_accum_iter:
285
+ with self.training_timer("optimizer_step"):
286
+ self.callbacks.on_before_optimizer_step(
287
+ model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
288
+ )
289
+ grad_scaler.step(optimizer)
290
+ grad_scaler.update()
291
+ scheduler.step()
292
+ self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
293
+ if self.config.trainer.distributed_parallelism == "ddp":
294
+ model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
295
+ else:
296
+ model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
297
+ optimizer.zero_grad(set_to_none=True)
298
+ grad_accum_iter = 0
299
+ return output_batch, loss, grad_accum_iter
300
+
301
+ @torch.no_grad()
302
+ def validate(self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
303
+ """Validate on the full validation dataset.
304
+
305
+ Args:
306
+ model (ImaginaireModel): The PyTorch model.
307
+ dataloader_val (torch.utils.data.DataLoader): The validation data loader.
308
+ iteration (int): Current iteration number.
309
+ """
310
+ log.info(f"Validating at iteration {iteration}...")
311
+ self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
312
+ model.eval()
313
+ # Evaluate on the full validation set.
314
+ with model.pipe.ema_scope(context="Validation", is_cpu=False):
315
+ for val_iter, data_batch in enumerate(dataloader_val):
316
+ if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
317
+ break
318
+ data_batch = misc.to(data_batch, device="cuda")
319
+ self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
320
+ output_batch, loss = model.validation_step(data_batch, iteration)
321
+ self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
322
+ self.callbacks.on_validation_end(model, iteration=iteration)
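The intended entry-point wiring for this trainer is roughly the following sketch; `make_config`, `build_model`, and `build_dataloader` are hypothetical project-specific helpers, not functions defined in this diff:

from imaginaire.trainer import ImaginaireTrainer

def main() -> None:
    config = make_config()                # hypothetical: assembles the experiment Config object
    trainer = ImaginaireTrainer(config)   # sets up distributed state, callbacks, checkpointer
    model = build_model(config)           # hypothetical: returns an ImaginaireModel subclass
    dataloader_train = build_dataloader(config, split="train")  # hypothetical
    dataloader_val = build_dataloader(config, split="val")      # hypothetical
    trainer.train(model, dataloader_train, dataloader_val)

if __name__ == "__main__":
    main()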
imaginaire/utils/.DS_Store ADDED
Binary file (6.15 kB).
 
imaginaire/utils/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
imaginaire/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (136 Bytes).
 
imaginaire/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (151 Bytes).
 
imaginaire/utils/__pycache__/device.cpython-310.pyc ADDED
Binary file (1.44 kB).
 
imaginaire/utils/__pycache__/distributed.cpython-310.pyc ADDED
Binary file (15 kB).
 
imaginaire/utils/__pycache__/io.cpython-310.pyc ADDED
Binary file (4.93 kB).
 
imaginaire/utils/__pycache__/io.cpython-39.pyc ADDED
Binary file (4.9 kB).
 
imaginaire/utils/__pycache__/log.cpython-310.pyc ADDED
Binary file (4.01 kB).
 
imaginaire/utils/__pycache__/log.cpython-39.pyc ADDED
Binary file (4.25 kB).
 
imaginaire/utils/__pycache__/misc.cpython-310.pyc ADDED
Binary file (18.3 kB).
 
imaginaire/utils/callback.py ADDED
@@ -0,0 +1,518 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import time
19
+ import warnings
20
+ from collections.abc import Callable
21
+ from typing import TYPE_CHECKING, Any
22
+
23
+ import omegaconf
24
+ import torch
25
+ import torch.utils.data
26
+ import tqdm
27
+
28
+ from imaginaire.lazy_config import instantiate
29
+ from imaginaire.utils import distributed, log
30
+ from imaginaire.utils.misc import get_local_tensor_if_DTensor
31
+
32
+ try:
33
+ from megatron.core import parallel_state
34
+ except ImportError:
35
+ parallel_state = None
36
+ print("Megatron-core is not installed.")
37
+
38
+
39
+ if TYPE_CHECKING:
40
+ from imaginaire.config import Config
41
+ from imaginaire.model import ImaginaireModel
42
+ from imaginaire.trainer import ImaginaireTrainer
43
+
44
+
45
+ class CallBackGroup:
46
+ """A class for hosting a collection of callback objects.
47
+
48
+ It is used to execute callback functions of multiple callback objects with the same method name.
49
+ When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
50
+ self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.
51
+
52
+ Attributes:
53
+ _callbacks (list[Callback]): List of callback objects.
54
+ """
55
+
56
+ def __init__(self, config: Config, trainer: ImaginaireTrainer) -> None:
57
+ """Initializes the list of callback objects.
58
+
59
+ Args:
60
+ config (Config): The config object for the Imaginaire codebase.
61
+ trainer (ImaginaireTrainer): The main trainer.
62
+ """
63
+ self._callbacks = []
64
+ callback_configs = config.trainer.callbacks
65
+ if callback_configs:
66
+ if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig):
67
+ warnings.warn(
68
+ "The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
69
+ "Please update your code",
70
+ DeprecationWarning,
71
+ stacklevel=2,
72
+ )
73
+ callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
74
+ for callback_name, current_callback_cfg in callback_configs.items():
75
+ if "_target_" not in current_callback_cfg:
76
+ log.critical(
77
+ f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}"
78
+ )
79
+ continue
80
+ log.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}")
81
+ _callback = instantiate(current_callback_cfg)
82
+ assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
83
+ _callback.config = config
84
+ _callback.trainer = trainer
85
+ self._callbacks.append(_callback)
86
+
87
+ def __getattr__(self, method_name: str) -> Callable:
88
+ """Loops through the callback objects to call the corresponding callback function.
89
+
90
+ Args:
91
+ method_name (str): Callback method name.
92
+ """
93
+
94
+ def multi_callback_wrapper(*args, **kwargs) -> None:
95
+ for callback in self._callbacks:
96
+ assert hasattr(callback, method_name)
97
+ method = getattr(callback, method_name)
98
+ assert callable(method)
99
+ _ = method(*args, **kwargs)
100
+
101
+ return multi_callback_wrapper
102
+
103
+
104
+ class Callback:
105
+ """The base class for all callbacks.
106
+
107
+ All callbacks should inherit from this class and adhere to the established method names and signatures.
108
+ """
109
+
110
+ def __init__(self, config: Config | None = None, trainer: ImaginaireTrainer | None = None):
111
+ """Initializes a Callback object.
112
+
113
+ Args:
114
+ config (Optional[Config]): The configuration object for the Imaginaire codebase, if available.
115
+ trainer (Optional[ImaginaireTrainer]): The main trainer handling the training loop, if available.
116
+
117
+ Notes:
118
+ The config and trainer parameters are optional to maintain backward compatibility.
119
+ In future releases, these parameters will be removed. Upon using these parameters, a deprecation
120
+ warning will be issued.
121
+
122
+ """
123
+ if config is not None or trainer is not None:
124
+ warnings.warn(
125
+ "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. "
126
+ "Please update your code to create Callback instances without these parameters.",
127
+ DeprecationWarning,
128
+ stacklevel=2,
129
+ )
130
+ del config, trainer
131
+
132
+ def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
133
+ pass
134
+
135
+ def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
136
+ """
137
+ Called before the training step, for each batch. This is paired with on_training_step_end() but note that
138
+ when using gradient accumulation, while on_training_step_end() is only called when the optimizer is updated,
139
+ this function is called for every batch.
140
+ Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
141
+ for every batch, albeit with the same iteration number.
142
+ """
143
+ pass
144
+
145
+ def on_training_step_batch_start(
146
+ self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
147
+ ) -> None:
148
+ """
149
+ Called before the training step, for each batch, similarly to on_training_step_start(). This function is paired with
150
+ on_training_step_batch_end(), and both functions are called for every batch even when using gradient accumulation.
151
+ Note that the iteration is only updated when the optimizer is updated, and therefore it may be the same for multiple invocations.
152
+ """
153
+ pass
154
+
155
+ def on_before_forward(self, iteration: int = 0) -> None:
156
+ pass
157
+
158
+ def on_after_forward(self, iteration: int = 0) -> None:
159
+ pass
160
+
161
+ def on_before_backward(
162
+ self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
163
+ ) -> None:
164
+ pass
165
+
166
+ def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
167
+ pass
168
+
169
+ def on_before_dataloading(self, iteration: int = 0) -> None:
170
+ pass
171
+
172
+ def on_after_dataloading(self, iteration: int = 0) -> None:
173
+ pass
174
+
175
+ def on_optimizer_init_start(self) -> None:
176
+ pass
177
+
178
+ def on_optimizer_init_end(self) -> None:
179
+ pass
180
+
181
+ def on_before_optimizer_step(
182
+ self,
183
+ model_ddp: distributed.DistributedDataParallel,
184
+ optimizer: torch.optim.Optimizer,
185
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
186
+ grad_scaler: torch.amp.GradScaler,
187
+ iteration: int = 0,
188
+ ) -> None:
189
+ pass
190
+
191
+ def on_before_zero_grad(
192
+ self,
193
+ model_ddp: distributed.DistributedDataParallel,
194
+ optimizer: torch.optim.Optimizer,
195
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
196
+ iteration: int = 0,
197
+ ) -> None:
198
+ pass
199
+
200
+ def on_training_step_batch_end(
201
+ self,
202
+ model: ImaginaireModel,
203
+ data_batch: dict[str, torch.Tensor],
204
+ output_batch: dict[str, torch.Tensor],
205
+ loss: torch.Tensor,
206
+ iteration: int = 0,
207
+ ) -> None:
208
+ """
209
+ Called at the end of a training step for every batch even when using gradient accumulation.
210
+ This is paired with on_training_step_batch_start(). Note that the iteration is only updated when the optimizer is updated,
211
+ and therefore it may be the same for multiple batches.
212
+ """
213
+ pass
214
+
215
+ def on_training_step_end(
216
+ self,
217
+ model: ImaginaireModel,
218
+ data_batch: dict[str, torch.Tensor],
219
+ output_batch: dict[str, torch.Tensor],
220
+ loss: torch.Tensor,
221
+ iteration: int = 0,
222
+ ) -> None:
223
+ """
224
+ Called at the end of a training step, but note that when using gradient accumulation, this is only called
225
+ when the optimizer is updated, and the iteration incremented, whereas on_training_step_start is called every time.
226
+ Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
227
+ for every batch.
228
+ """
229
+ pass
230
+
231
+ def on_validation_start(
232
+ self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
233
+ ) -> None:
234
+ pass
235
+
236
+ def on_validation_step_start(
237
+ self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
238
+ ) -> None:
239
+ pass
240
+
241
+ def on_validation_step_end(
242
+ self,
243
+ model: ImaginaireModel,
244
+ data_batch: dict[str, torch.Tensor],
245
+ output_batch: dict[str, torch.Tensor],
246
+ loss: torch.Tensor,
247
+ iteration: int = 0,
248
+ ) -> None:
249
+ pass
250
+
251
+ def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
252
+ pass
253
+
254
+ def on_load_checkpoint_start(self, model: ImaginaireModel) -> None:
255
+ pass
256
+
257
+ def on_load_checkpoint_end(
258
+ self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: str | None = None
259
+ ) -> None:
260
+ pass
261
+
262
+ def on_load_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None:
263
+ pass
264
+
265
+ def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
266
+ """
267
+ Called when checkpoint saving is about to start.
268
+ """
269
+ pass
270
+
271
+ def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
272
+ """
273
+ Called when the synchronous part of checkpointing is finished. This function can be used
274
+ along with on_save_checkpoint_start() to measure the exposed (synchronous) checkpoint time.
275
+ Note that for asynchronous checkpointing, the checkpoint may still be ongoing, so this function
276
+ does not mean the checkpoint is finished in the asynchronous case; use on_save_checkpoint_success()
277
+ for that.
278
+ """
279
+ pass
280
+
281
+ def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None:
282
+ """
283
+ Called when checkpoint saving is fully finished, and succeeded. Not called if checkpoint failed.
284
+ For synchronous checkpoint, it is called at the same time as on_save_checkpoint_end(), but for asynchronous
285
+ checkpoint, it is called after the asynchronous part has also finished. For checkpointers with out-of-process
286
+ checkpointing, this function is called as soon as the notification is received from the checkpointer process,
287
+ which may not be immediately after the checkpoint has completed but later on. Therefore, if you need to measure
288
+ the full checkpoint duration for the asynchronous part, use the elapsed_time parameter; do not measure it directly,
289
+ as this would be a significant overestimate.
290
+ """
291
+ pass
292
+
293
+ def on_save_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None:
294
+ pass
295
+
296
+ def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
297
+ pass
298
+
299
+ def on_app_end(self) -> None:
300
+ pass
301
+
302
+
303
+ class EMAModelCallback(Callback):
304
+ """The callback class for tracking EMA model weights."""
305
+
306
+ def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
307
+ # Set up the EMA model weight tracker.
308
+ if model.config.ema.enabled:
309
+ assert hasattr(model, "ema"), "EMA should be initialized from ImaginaireModel"
310
+ # EMA model must be kept in FP32 precision.
311
+ model.ema = model.ema.to(dtype=torch.float32)
312
+ else:
313
+ assert not hasattr(model, "ema"), "There should be no EMA initialized."
314
+
315
+ def on_training_step_end(
316
+ self,
317
+ model: ImaginaireModel,
318
+ data_batch: dict[str, torch.Tensor],
319
+ output_batch: dict[str, torch.Tensor],
320
+ loss: torch.Tensor,
321
+ iteration: int = 0,
322
+ ) -> None:
323
+ # Update the EMA model with the new regular weights.
324
+ if model.config.ema.enabled:
325
+ model.ema.update_average(model, iteration)
326
+
327
+
328
+ class ProgressBarCallback(Callback):
329
+ """The callback class for visualizing the training/validation progress bar in the console."""
330
+
331
+ @distributed.rank0_only
332
+ def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
333
+ self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
334
+
335
+ @distributed.rank0_only
336
+ def on_training_step_end(
337
+ self,
338
+ model: ImaginaireModel,
339
+ data_batch: dict[str, torch.Tensor],
340
+ output_batch: dict[str, torch.Tensor],
341
+ loss: torch.Tensor,
342
+ iteration: int = 0,
343
+ ) -> None:
344
+ self.train_pbar.update()
345
+
346
+ @distributed.rank0_only
347
+ def on_validation_start(
348
+ self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
349
+ ) -> None:
350
+ if self.config.trainer.max_val_iter is not None:
351
+ num_iter = self.config.trainer.max_val_iter
352
+ else:
353
+ num_iter = len(dataloader_val)
354
+ assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}"
355
+ self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False)
356
+
357
+ @distributed.rank0_only
358
+ def on_validation_step_end(
359
+ self,
360
+ model: ImaginaireModel,
361
+ data_batch: dict[str, torch.Tensor],
362
+ output_batch: dict[str, torch.Tensor],
363
+ loss: torch.Tensor,
364
+ iteration: int = 0,
365
+ ) -> None:
366
+ self.val_pbar.update()
367
+
368
+ @distributed.rank0_only
369
+ def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
370
+ self.val_pbar.close()
371
+
372
+ @distributed.rank0_only
373
+ def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
374
+ self.trainer.checkpointer.finalize()
375
+ self.train_pbar.close()
376
+
377
+
378
+ class IterationLoggerCallback(Callback):
379
+ """The callback class for visualizing the training/validation progress bar in the console."""
380
+
381
+ @distributed.rank0_only
382
+ def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
383
+ # self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
384
+ self.start_iteration_time = time.time()
385
+ self.elapsed_iteration_time = 0
386
+
387
+ @distributed.rank0_only
388
+ def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
389
+ self.start_iteration_time = time.time()
390
+
391
+ @distributed.rank0_only
392
+ def on_training_step_end(
393
+ self,
394
+ model: ImaginaireModel,
395
+ data_batch: dict[str, torch.Tensor],
396
+ output_batch: dict[str, torch.Tensor],
397
+ loss: torch.Tensor,
398
+ iteration: int = 0,
399
+ ) -> None:
400
+ self.elapsed_iteration_time += time.time() - self.start_iteration_time
401
+
402
+ if iteration % self.config.trainer.logging_iter == 0:
403
+ avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter
404
+ log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}")
405
+
406
+ self.elapsed_iteration_time = 0
407
+
408
+
409
+ class LowPrecisionCallback(Callback):
410
+ """The callback class handling low precision training
411
+
412
+ A config field with a non-primitive type is difficult to override.
413
+ The callback gets precision from model.precision instead.
414
+ It is also automatically disabled when using fp32.
415
+ """
416
+
417
+ def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int):
418
+ self.update_iter = update_iter
419
+
420
+ def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
421
+ assert model.precision in [
422
+ torch.bfloat16,
423
+ torch.float16,
424
+ torch.half,
425
+ ], "LowPrecisionCallback must use a low precision dtype."
426
+ self.precision_type = model.precision
427
+
428
+ def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
429
+ for k, v in data.items():
430
+ if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
431
+ data[k] = v.to(dtype=self.precision_type)
432
+
433
+ def on_validation_step_start(
434
+ self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
435
+ ) -> None:
436
+ for k, v in data.items():
437
+ if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
438
+ data[k] = v.to(dtype=self.precision_type)
439
+
440
+ def on_before_zero_grad(
441
+ self,
442
+ model_ddp: distributed.DistributedDataParallel,
443
+ optimizer: torch.optim.Optimizer,
444
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
445
+ iteration: int = 0,
446
+ ) -> None:
447
+ if iteration % self.update_iter == 0:
448
+ if getattr(optimizer, "master_weights", False):
449
+ params, master_params = [], []
450
+ for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master, strict=False):
451
+ for p, p_master in zip(group["params"], group_master["params"], strict=False):
452
+ params.append(get_local_tensor_if_DTensor(p.data))
453
+ master_params.append(p_master.data)
454
+ torch._foreach_copy_(params, master_params)
455
+
456
+
457
+ class NVTXCallback(Callback):
458
+ """The callback for creating NVTX ranges"""
459
+
460
+ def __init__(
461
+ self,
462
+ synchronize: bool = False,
463
+ config: Config | None = None,
464
+ trainer: ImaginaireTrainer | None = None,
465
+ ):
466
+ super().__init__(config, trainer)
467
+ self.synchronize = synchronize
468
+
469
+ def on_before_forward(self, iteration: int = 0) -> None:
470
+ if self.synchronize:
471
+ torch.cuda.synchronize()
472
+ torch.cuda.nvtx.range_push("forward")
473
+
474
+ def on_after_forward(self, iteration: int = 0) -> None:
475
+ if self.synchronize:
476
+ torch.cuda.synchronize()
477
+ torch.cuda.nvtx.range_pop()
478
+
479
+ def on_before_backward(
480
+ self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
481
+ ) -> None:
482
+ if self.synchronize:
483
+ torch.cuda.synchronize()
484
+ torch.cuda.nvtx.range_push("backward")
485
+
486
+ def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
487
+ if self.synchronize:
488
+ torch.cuda.synchronize()
489
+ torch.cuda.nvtx.range_pop()
490
+
491
+ def on_before_optimizer_step(
492
+ self,
493
+ model_ddp: distributed.DistributedDataParallel,
494
+ optimizer: torch.optim.Optimizer,
495
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
496
+ grad_scaler: torch.amp.GradScaler,
497
+ iteration: int = 0,
498
+ ) -> None:
499
+ if self.synchronize:
500
+ torch.cuda.synchronize()
501
+ torch.cuda.nvtx.range_push("optimizer_step")
502
+
503
+ def on_before_zero_grad(
504
+ self,
505
+ model_ddp: distributed.DistributedDataParallel,
506
+ optimizer: torch.optim.Optimizer,
507
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
508
+ iteration: int = 0,
509
+ ) -> None:
510
+ if self.synchronize:
511
+ torch.cuda.synchronize()
512
+ torch.cuda.nvtx.range_pop()
513
+
514
+ def on_before_dataloading(self, iteration: int = 0) -> None:
515
+ torch.cuda.nvtx.range_push("dataloading")
516
+
517
+ def on_after_dataloading(self, iteration: int = 0) -> None:
518
+ torch.cuda.nvtx.range_pop()
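As a closing illustration, a custom callback only needs to subclass `Callback` and override the hooks it cares about. The sketch below is hypothetical (the class, threshold, and config key are not part of this diff); it would be registered under `config.trainer.callbacks` with a `_target_` entry, as the instantiation loop above expects:

import torch

from imaginaire.utils import log
from imaginaire.utils.callback import Callback

class LossSpikeCallback(Callback):
    """Warns whenever the training loss exceeds a fixed threshold (illustrative only)."""

    def __init__(self, threshold: float = 10.0):
        super().__init__()
        self.threshold = threshold

    def on_training_step_end(self, model, data_batch, output_batch, loss, iteration: int = 0) -> None:
        if loss.item() > self.threshold:
            log.warning(f"Loss spike at iteration {iteration}: {loss.item():.4f}")

# Registration sketch (e.g. in a config module):
# config.trainer.callbacks["loss_spike"] = {
#     "_target_": "my_project.callbacks.LossSpikeCallback",  # hypothetical module path
#     "threshold": 10.0,
# }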
imaginaire/utils/checkpointer.py ADDED
@@ -0,0 +1,282 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import os
19
+ import threading
20
+ from typing import TYPE_CHECKING, NamedTuple
21
+
22
+ import torch
23
+ import torch.distributed as dist
24
+ from torch import nn
25
+
26
+ from imaginaire.model import ImaginaireModel
27
+ from imaginaire.utils import callback, distributed, log, misc
28
+ from imaginaire.utils.parallelism import ModelWrapper
29
+
30
+ if TYPE_CHECKING:
31
+ from imaginaire.config import CheckpointConfig, JobConfig
32
+
33
+
34
+ class Checkpointer:
35
+ """The checkpointer class. Supports checkpoint saving/loading to local disk."""
36
+
37
+ def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
38
+ """Constructor of the checkpointer.
39
+
40
+ Args:
41
+ config_checkpoint (CheckpointConfig): The config object for the checkpointer.
42
+ """
43
+ # Set the callback functions.
44
+ self.callbacks = callbacks
45
+ self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
46
+ self.strict_resume = config_checkpoint.strict_resume
47
+ self.load_path = config_checkpoint.load_path or None
48
+ self.load_training_state = config_checkpoint.load_training_state
49
+ self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
50
+ self.save_thread = None
51
+
52
+ def save(
53
+ self,
54
+ model: ImaginaireModel,
55
+ optimizer: torch.optim.Optimizer,
56
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
57
+ grad_scaler: torch.amp.GradScaler,
58
+ iteration: int,
59
+ ) -> None:
60
+ """Save network weights, optimizer parameters, scheduler parameters to a checkpoint.
61
+
62
+ Args:
63
+ model (ImaginaireModel): The PyTorch model.
64
+ optimizer (torch.optim.Optimizer): The model optimizer.
65
+ scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
66
+ grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
67
+ iteration (int): Current iteration number.
68
+ """
69
+ self.callbacks.on_save_checkpoint_start(model, iteration)
70
+
71
+ checkpoint_file = f"iter_{iteration:09}.pt"
72
+
73
+ if distributed.get_rank() == 0:
74
+ state_dict = dict(
75
+ model=model.state_dict(),
76
+ optimizer=optimizer.state_dict(),
77
+ scheduler=scheduler.state_dict(),
78
+ grad_scaler=grad_scaler.state_dict(),
79
+ iteration=iteration,
80
+ )
81
+ state_dict = misc.to(state_dict, device="cpu")
82
+ self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
83
+ # Wait for previous saver thread to end.
84
+ if self.save_thread:
85
+ self.save_thread.join()
86
+ # Run the checkpoint saver in a separate thread.
87
+ self.save_thread = threading.Thread(
88
+ target=self._save_worker_local,
89
+ daemon=False,
90
+ args=(state_dict, checkpoint_file, distributed.get_rank()),
91
+ )
92
+ self.save_thread.start()
93
+
94
+ # Note: checkpoints are saved on a separate thread, so this callback fires before the save completes.
95
+ # Check the logs from on_save_checkpoint_success() for the actual completion status.
96
+ self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)
97
+
98
+ @misc.timer("checkpoint saving (local)")
99
+ def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None:
100
+ """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).
101
+
102
+ Args:
103
+ state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
104
+ checkpoint_file (str): The file name of the model checkpoint.
105
+ rank (int): GPU device (default: 0).
106
+ """
107
+ checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
108
+ os.makedirs(self.checkpoint_dir_local, exist_ok=True)
109
+ try:
110
+ torch.save(state_dict, checkpoint_path)
111
+ if rank == 0:
112
+ self._write_latest_checkpoint_file(checkpoint_file)
113
+ log.success(f"Saved checkpoint (local): {checkpoint_path}")
114
+ iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
115
+ self.callbacks.on_save_checkpoint_success(iteration=iteration)
116
+ except Exception as e:
117
+ log.exception(f"Checkpoint failed to save (local): {e}")
118
+
119
+ @misc.timer("checkpoint loading")
120
+ def load(
121
+ self,
122
+ model: ImaginaireModel,
123
+ optimizer: torch.optim.Optimizer | None = None,
124
+ scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
125
+ grad_scaler: torch.amp.GradScaler | None = None,
126
+ ) -> int:
127
+ """Load network weights and optimizer states from a checkpoint in a single process.
128
+
129
+ The priority of the checkpoint loading logic is:
130
+ 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
131
+ 2. If no latest checkpoint is found, load the model weights specified by config_checkpoint.load_path.
132
+ - This is typically used for inference mode.
133
+ - If config_checkpoint.load_training_state is True, also load the optimizer, scheduler, and grad scaler states.
134
+ 3. If none of the above, randomly initialize the model parameters and train from scratch.
135
+
136
+ Args:
137
+ model (ImaginaireModel): The PyTorch model.
138
+ optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
139
+ scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
140
+ grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).
141
+
142
+ Returns:
143
+ iteration (int): the iteration number to start/resume from.
144
+ """
145
+ self.callbacks.on_load_checkpoint_start(model)
146
+
147
+ latest_checkpoint_file = self._read_latest_checkpoint_file()
148
+ if latest_checkpoint_file is not None:
149
+ # 1. Resume training from latest_checkpoint.txt under the same name.
150
+ checkpoint_dir = self.checkpoint_dir_local
151
+ checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
152
+ resume = True
153
+ only_resume_scheduler = True
154
+ else:
155
+ if self.load_path:
156
+ # 2. Load the module weights specified by config_checkpoint.path.
157
+ checkpoint_path = self.load_path
158
+ resume = self.load_training_state
159
+ only_resume_scheduler = self.only_load_scheduler_state
160
+ else:
161
+ # 3. Randomly initialize the model parameters and train from scratch.
162
+ checkpoint_path = None
163
+ resume = False
164
+ only_resume_scheduler = False
165
+ # Load checkpoint.
166
+ if checkpoint_path is not None:
167
+ self._check_checkpoint_exists(checkpoint_path)
168
+ log.info(f"Loading checkpoint (local): {checkpoint_path}")
169
+ state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
170
+ log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
171
+ self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
172
+ # Load the state dicts.
173
+ log.info("- Loading the model...")
174
+ model.load_state_dict(state_dict["model"], strict=self.strict_resume)
175
+ if resume or only_resume_scheduler:
176
+ iteration = state_dict["iteration"]
177
+ assert scheduler
178
+ log.info("- Loading the scheduler...")
179
+ scheduler.load_state_dict(state_dict["scheduler"])
180
+ scheduler.last_epoch = iteration
181
+ else:
182
+ iteration = 0
183
+ if resume:
184
+ assert optimizer
185
+ log.info("- Loading the optimizer...")
186
+ optimizer.load_state_dict(state_dict["optimizer"])
187
+ log.info("- Loading the gradient scaler...")
188
+ grad_scaler.load_state_dict(state_dict["grad_scaler"])
189
+ log.success(f"Done with loading the checkpoint (iteration {iteration}).")
190
+ else:
191
+ log.success("Done with loading the checkpoint.")
192
+ else:
193
+ # Checkpoint not found and not specified. We will train everything from scratch.
194
+ iteration = 0
195
+ log.info("Training from scratch.")
196
+ torch.cuda.empty_cache()
197
+
198
+ self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)
199
+
200
+ return iteration
201
+
202
+ def _read_latest_checkpoint_file(self) -> str | None:
203
+ """Get the file name of the latest saved checkpoint. If it doesn't exist, return None.
204
+
205
+ Returns:
206
+ checkpoint_file (str | None): file name of the latest saved checkpoint.
207
+ """
208
+ checkpoint_file = None
209
+ latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
210
+ if os.path.isfile(latest_path):
211
+ checkpoint_file = open(latest_path).read().strip()
212
+ return checkpoint_file
213
+
214
+ def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
215
+ """Track the file name of the latest saved checkpoint.
216
+
217
+ Args:
218
+ checkpoint_file (str): file name of the latest saved checkpoint.
219
+ """
220
+ content = f"{checkpoint_file}\n"
221
+ latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
222
+ with open(latest_path, "w") as file:
223
+ file.write(content)
224
+
225
+ def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
226
+ """If the file checkpoint_path does not exist, raise an error.
227
+
228
+ Args:
229
+ checkpoint_path (str): full path to the checkpoint.
230
+ """
231
+ if not os.path.exists(checkpoint_path):
232
+ raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
233
+
234
+ def finalize(self) -> None:
235
+ """Finalize the checkpointer."""
236
+ if self.save_thread:
237
+ self.save_thread.join()
238
+
239
+
240
+ class _IncompatibleKeys(
241
+ NamedTuple(
242
+ "IncompatibleKeys",
243
+ [
244
+ ("missing_keys", list[str]),
245
+ ("unexpected_keys", list[str]),
246
+ ("incorrect_shapes", list[tuple[str, tuple[int], tuple[int]]]),
247
+ ],
248
+ )
249
+ ):
250
+ pass
251
+
252
+
253
+ def load_checkpoint(
254
+ model_parts: list[nn.Module],
255
+ ckpt_dir,
256
+ model_ckpt_key_map: dict[str, str] = {}, # noqa: B006
257
+ ):
258
+ log.info(f"Loading checkpoint from {ckpt_dir}.")
259
+
260
+ _model_wrapper = ModelWrapper(model_parts)
261
+ state_dict = _model_wrapper.state_dict()
262
+ # remove _extra_state
263
+ state_dict = {k: v for k, v in state_dict.items() if not k.endswith("._extra_state")}
264
+
265
+ # remap keys if needed
266
+ if model_ckpt_key_map:
267
+ for model_key, checkpoint_key in model_ckpt_key_map.items():
268
+ state_dict[checkpoint_key] = state_dict.pop(model_key)
269
+ log.info(f"Re-mapping {model_key} to {checkpoint_key}")
270
+
271
+ fs_storage_reader = dist.checkpoint.FileSystemReader(ckpt_dir)
272
+ dist.checkpoint.load(state_dict=state_dict, storage_reader=fs_storage_reader)
273
+
274
+ # inverse the remapping if needed
275
+ if model_ckpt_key_map:
276
+ for model_key, checkpoint_key in model_ckpt_key_map.items():
277
+ state_dict[model_key] = state_dict.pop(checkpoint_key)
278
+ log.info(f"Inverse re-mapping {checkpoint_key} to {model_key}")
279
+
280
+ _model_wrapper.load_state_dict(state_dict)
281
+
282
+ log.info(f"Finished loading checkpoint from {ckpt_dir}.")
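The resume priority implemented in `Checkpointer.load()` above boils down to a small lookup. The sketch below restates it as standalone code for clarity; it is not part of the commit, and `checkpoint_dir` / `load_path` stand in for the values derived from the job and checkpoint configs.

```python
# Illustrative sketch of the resume-path resolution used by Checkpointer.load().
import os


def resolve_resume_path(checkpoint_dir: str, load_path: str | None = None) -> str | None:
    """Priority: latest_checkpoint.txt in checkpoint_dir, then an explicit load_path, else None (train from scratch)."""
    latest = os.path.join(checkpoint_dir, "latest_checkpoint.txt")
    if os.path.isfile(latest):
        with open(latest) as f:
            return os.path.join(checkpoint_dir, f.read().strip())
    return load_path


# e.g. resolve_resume_path("outputs/my_job/checkpoints") might return
# "outputs/my_job/checkpoints/iter_000010000.pt", matching the iter_{:09}.pt naming above.
```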
imaginaire/utils/config_helper.py ADDED
@@ -0,0 +1,201 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import os
18
+ import pkgutil
19
+ import sys
20
+ from dataclasses import fields as dataclass_fields
21
+ from dataclasses import is_dataclass
22
+ from typing import Any
23
+
24
+ import attr
25
+ import attrs
26
+ from hydra import compose, initialize
27
+ from hydra.core.config_store import ConfigStore
28
+ from hydra.core.global_hydra import GlobalHydra
29
+ from omegaconf import DictConfig, OmegaConf
30
+
31
+ from imaginaire.config import Config
32
+ from imaginaire.utils import log
33
+
34
+
35
+ def is_attrs_or_dataclass(obj) -> bool:
36
+ """
37
+ Check if the object is an instance of an attrs class or a dataclass.
38
+
39
+ Args:
40
+ obj: The object to check.
41
+
42
+ Returns:
43
+ bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
44
+ """
45
+ return is_dataclass(obj) or attr.has(type(obj))
46
+
47
+
48
+ def get_fields(obj):
49
+ """
50
+ Get the fields of an attrs class or a dataclass.
51
+
52
+ Args:
53
+ obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
54
+
55
+ Returns:
56
+ list: A list of field names.
57
+
58
+ Raises:
59
+ ValueError: If the object is neither an attrs class nor a dataclass.
60
+ """
61
+ if is_dataclass(obj):
62
+ return [field.name for field in dataclass_fields(obj)]
63
+ elif attr.has(type(obj)):
64
+ return [field.name for field in attr.fields(type(obj))]
65
+ else:
66
+ raise ValueError("The object is neither an attrs class nor a dataclass.")
67
+
68
+
69
+ def override(config: Config, overrides: list[str] | None = None) -> Config:
70
+ """
71
+ :param config: the instance of class `Config` (usually from `make_config`)
72
+ :param overrides: list of overrides for config
73
+ :return: the composed instance of class `Config`
74
+ """
75
+ # Store the class of the config for reconstruction after overriding.
76
+ # config_class = type(config)
77
+
78
+ # Convert Config object to a DictConfig object
79
+ config_dict = attrs.asdict(config)
80
+ config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
81
+ # Enforce "--" separator between the script arguments and overriding configs.
82
+ if overrides:
83
+ if overrides[0] != "--":
84
+ raise ValueError('Hydra config overrides must be separated with a "--" token.')
85
+ overrides = overrides[1:]
86
+ # Use Hydra to handle overrides
87
+ cs = ConfigStore.instance()
88
+ cs.store(name="config", node=config_omegaconf)
89
+ if not GlobalHydra().is_initialized():
90
+ with initialize(version_base=None):
91
+ config_omegaconf = compose(config_name="config", overrides=overrides)
92
+ OmegaConf.resolve(config_omegaconf)
93
+ else:
94
+ config_omegaconf = compose(config_name="config", overrides=overrides)
95
+ OmegaConf.resolve(config_omegaconf)
96
+
97
+ def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
98
+ """
99
+ Construct an instance of the same type as ref_instance from the provided dictionary; primitive or unstructured data is returned as-is.
100
+
101
+ Args:
102
+ ref_instance: The reference instance to determine the type and fields when needed
103
+ kwargs: A dictionary of keyword arguments for constructing the new instance, or primitive/unstructured data to pass through unchanged.
104
+
105
+ Returns:
106
+ Any: A new instance of the same type as ref_instance constructed from kwargs, or the primitive/unstructured data returned unchanged.
107
+
108
+ Raises:
109
+ AssertionError: If the fields do not match or if extra keys are found.
110
+ Exception: If there is an error constructing the new instance.
111
+ """
112
+ is_type = is_attrs_or_dataclass(ref_instance)
113
+ if not is_type:
114
+ return kwargs
115
+ else:
116
+ ref_fields = set(get_fields(ref_instance))
117
+ assert isinstance(kwargs, dict) or isinstance(kwargs, DictConfig), (
118
+ "kwargs must be a dictionary or a DictConfig"
119
+ )
120
+ keys = set(kwargs.keys())
121
+
122
+ # ref_fields must equal to or include all keys
123
+ extra_keys = keys - ref_fields
124
+ assert ref_fields == keys or keys.issubset(ref_fields), (
125
+ f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
126
+ )
127
+
128
+ resolved_kwargs: dict[str, Any] = {}
129
+ for f in keys:
130
+ resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
131
+ try:
132
+ new_instance = type(ref_instance)(**resolved_kwargs)
133
+ except Exception as e:
134
+ log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
135
+ log.error(e)
136
+ raise e
137
+ return new_instance
138
+
139
+ config = config_from_dict(config, config_omegaconf)
140
+
141
+ return config
142
+
143
+
144
+ def get_config_module(config_file: str) -> str:
145
+ if not config_file.endswith(".py"):
146
+ log.error("Config file cannot be specified as module.")
147
+ log.error("Please provide the path to the Python config file (relative to the Imaginaire4 root).")
148
+ assert os.path.isfile(config_file), f"Imaginaire4 config file ({config_file}) not found."
149
+ # Convert to importable module format.
150
+ config_module = config_file.replace("/", ".").replace(".py", "")
151
+ return config_module
152
+
153
+
154
+ def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
155
+ """
156
+ Import all modules from the specified package path recursively.
157
+
158
+ This function is typically used in conjunction with Hydra to ensure that all modules
159
+ within a specified package are imported, which is necessary for registering configurations.
160
+
161
+ Example usage:
162
+ ```python
163
+ import_all_modules_from_package("projects.cosmos.diffusion.v1.config.experiment", reload=True, skip_underscore=False)
164
+ ```
165
+
166
+ Args:
167
+ package_path (str): The dotted path to the package from which to import all modules.
168
+ reload (bool): Flag to determine whether to reload modules if they're already imported.
169
+ skip_underscore (bool): If True, skips importing modules that start with an underscore.
170
+ """
171
+ log.critical(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
172
+ package = importlib.import_module(package_path)
173
+ package_directory = package.__path__
174
+
175
+ def import_modules_recursively(directory: str, prefix: str) -> None:
176
+ """
177
+ Recursively imports or reloads all modules in the given directory.
178
+
179
+ Args:
180
+ directory (str): The file system path to the current package directory.
181
+ prefix (str): The module prefix (e.g., 'projects.cosmos.diffusion.v1.config').
182
+ """
183
+ for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
184
+ if skip_underscore and module_name.startswith("_"):
185
+ log.debug(f"Skipping module {module_name} as it starts with an underscore")
186
+ continue
187
+
188
+ full_module_name = f"{prefix}.{module_name}"
189
+ log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
190
+
191
+ if full_module_name in sys.modules and reload:
192
+ importlib.reload(sys.modules[full_module_name])
193
+ else:
194
+ importlib.import_module(full_module_name)
195
+
196
+ if is_pkg:
197
+ sub_package_directory = os.path.join(directory, module_name)
198
+ import_modules_recursively(sub_package_directory, full_module_name)
199
+
200
+ for directory in package_directory:
201
+ import_modules_recursively(directory, package_path)
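The `override()` helper above round-trips an attrs config through OmegaConf so that Hydra-style dotted overrides can be applied and the typed config rebuilt. The snippet below sketches that round trip with made-up config classes and without Hydra's `compose()` step; none of these names come from this repository.

```python
# Minimal sketch of the attrs -> DictConfig -> override -> attrs round trip.
import attrs
from omegaconf import OmegaConf


@attrs.define
class OptimizerConfig:
    lr: float = 1e-4
    weight_decay: float = 0.0


@attrs.define
class TrainConfig:
    max_iter: int = 1000
    optimizer: OptimizerConfig = attrs.field(factory=OptimizerConfig)


cfg = TrainConfig()
# Flatten the typed config into an OmegaConf container, apply a dotted override,
# then rebuild the attrs objects field by field (what config_from_dict() does recursively).
flat = OmegaConf.create(attrs.asdict(cfg))
OmegaConf.update(flat, "optimizer.lr", 3e-4)
cfg = TrainConfig(max_iter=flat.max_iter, optimizer=OptimizerConfig(**flat.optimizer))
print(cfg.optimizer.lr)  # 0.0003
```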
imaginaire/utils/device.py ADDED
@@ -0,0 +1,39 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ import os
18
+
19
+ import pynvml
20
+
21
+
22
+ class Device:
23
+ _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore
24
+
25
+ def __init__(self, device_idx: int):
26
+ super().__init__()
27
+ self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
28
+
29
+ def get_name(self) -> str:
30
+ return pynvml.nvmlDeviceGetName(self.handle)
31
+
32
+ def get_cpu_affinity(self) -> list[int]:
33
+ affinity_string = ""
34
+ for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
35
+ # assume nvml returns list of 64 bit ints
36
+ affinity_string = f"{j:064b}" + affinity_string
37
+ affinity_list = [int(x) for x in affinity_string]
38
+ affinity_list.reverse() # so core 0 is in 0th element of list
39
+ return [i for i, e in enumerate(affinity_list) if e != 0]
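`get_cpu_affinity()` above unpacks the 64-bit affinity words returned by NVML into a list of core indices, which `os.sched_setaffinity()` then consumes in `distributed.init()`. The standalone sketch below reproduces just the bit manipulation with a made-up mask; it needs no GPU and is not part of the commit.

```python
# Illustrative decoding of NVML-style packed affinity masks into core indices.
def decode_affinity(words: list[int]) -> list[int]:
    bits = ""
    for word in words:
        bits = f"{word:064b}" + bits  # later words hold higher core indices
    bits = bits[::-1]  # reverse so string index 0 corresponds to core 0
    return [core for core, flag in enumerate(bits) if flag == "1"]


print(decode_affinity([0b1111]))  # [0, 1, 2, 3]: cores 0-3 are preferred for this GPU
```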
imaginaire/utils/distributed.py ADDED
@@ -0,0 +1,444 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import collections
19
+ import collections.abc
20
+ import ctypes
21
+ import functools
22
+ import os
23
+ from collections.abc import Callable, Container
24
+ from contextlib import contextmanager
25
+ from datetime import timedelta
26
+ from typing import TYPE_CHECKING, Any
27
+
28
+ import pynvml
29
+ import torch
30
+ import torch.distributed as dist
31
+ from torch.distributed import get_process_group_ranks
32
+
33
+ from imaginaire.utils.device import Device
34
+
35
+ if dist.is_available():
36
+ from torch.distributed.distributed_c10d import _get_default_group
37
+ from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes
38
+
39
+ from imaginaire.utils import log
40
+
41
+ if TYPE_CHECKING:
42
+ from imaginaire.config import DDPConfig
43
+
44
+ try:
45
+ from megatron.core import parallel_state
46
+ except ImportError:
47
+ print("Megatron-core is not installed.")
48
+
49
+
50
+ def init() -> int | None:
51
+ """Initialize distributed training."""
52
+ if dist.is_initialized():
53
+ return torch.cuda.current_device()
54
+
55
+ # Set GPU affinity.
56
+ pynvml.nvmlInit()
57
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
58
+ try:
59
+ device = Device(local_rank)
60
+ os.sched_setaffinity(0, device.get_cpu_affinity())
61
+ except (OSError, pynvml.NVMLError) as e:
62
+ log.warning(f"Failed to set device affinity: {e}")
63
+ # Set up NCCL communication.
64
+ os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
65
+ os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
66
+ if dist.is_available():
67
+ torch.cuda.set_device(local_rank)
68
+ # Get the timeout value from environment variable
69
+ timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
70
+ # Convert the timeout to an integer (if it isn't already) and then to a timedelta
71
+ timeout_timedelta = timedelta(seconds=int(timeout_seconds))
72
+ dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
73
+ log.info(
74
+ f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
75
+ rank0_only=False,
76
+ )
77
+ # Increase the L2 fetch granularity for faster speed.
78
+ _libcudart = ctypes.CDLL("libcudart.so")
79
+ # Set device limit on the current device.
80
+ p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
81
+ _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
82
+ _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
83
+ log.info(f"Training with {get_world_size()} GPUs.")
84
+
85
+
86
+ def get_rank(group: dist.ProcessGroup | None = None) -> int:
87
+ """Get the rank (GPU device) of the worker.
88
+
89
+ Returns:
90
+ rank (int): The rank of the worker.
91
+ """
92
+ rank = 0
93
+ if dist.is_available() and dist.is_initialized():
94
+ rank = dist.get_rank(group)
95
+ return rank
96
+
97
+
98
+ def get_world_size(group: dist.ProcessGroup | None = None) -> int:
99
+ """Get world size. How many GPUs are available in this job.
100
+
101
+ Returns:
102
+ world_size (int): The total number of GPUs available in this job.
103
+ """
104
+ world_size = 1
105
+ if dist.is_available() and dist.is_initialized():
106
+ world_size = dist.get_world_size(group)
107
+ return world_size
108
+
109
+
110
+ def is_rank0() -> bool:
111
+ """Check if current process is the master GPU.
112
+
113
+ Returns:
114
+ (bool): True if this function is called from the master GPU, else False.
115
+ """
116
+ return get_rank() == 0
117
+
118
+
119
+ def is_local_rank0() -> bool:
120
+ """Check if current process is the local master GPU in the current node.
121
+
122
+ Returns:
123
+ (bool): True if this function is called from the local master GPU, else False.
124
+ """
125
+ return torch.cuda.current_device() == 0
126
+
127
+
128
+ def rank0_only(func: Callable) -> Callable:
129
+ """Apply this function only to the master GPU.
130
+
131
+ Example usage:
132
+ @rank0_only
133
+ def func(x):
134
+ return x + 3
135
+
136
+ Args:
137
+ func (Callable): a function.
138
+
139
+ Returns:
140
+ (Callable): A function wrapper executing the function only on the master GPU.
141
+ """
142
+
143
+ @functools.wraps(func)
144
+ def wrapper(*args, **kwargs):
145
+ if is_rank0():
146
+ return func(*args, **kwargs)
147
+ else:
148
+ return None
149
+
150
+ return wrapper
151
+
152
+
153
+ def barrier() -> None:
154
+ """Barrier for all GPUs."""
155
+ if dist.is_available() and dist.is_initialized():
156
+ dist.barrier()
157
+
158
+
159
+ def rank0_first(func: Callable) -> Callable:
160
+ """run the function on rank 0 first, then on other ranks."""
161
+
162
+ @functools.wraps(func)
163
+ def wrapper(*args, **kwargs):
164
+ if is_rank0():
165
+ result = func(*args, **kwargs)
166
+ barrier()
167
+ if not is_rank0():
168
+ result = func(*args, **kwargs)
169
+ return result
170
+
171
+ return wrapper
172
+
173
+
174
+ def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
175
+ """Wraps the model to enable data parallelism for training across multiple GPU devices.
176
+
177
+ Args:
178
+ config_ddp (DDPConfig): The data parallel config.
179
+ model (torch.nn.Module): The PyTorch module.
180
+
181
+ Returns:
182
+ model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
183
+ if distributed environment is available, otherwise return the original model.
184
+ """
185
+ if dist.is_available() and dist.is_initialized():
186
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
187
+ try:
188
+ ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
189
+ except Exception as e:
190
+ log.info(e)
191
+ log.info("parallel_state not initialized, treating all GPUs equally for DDP")
192
+ ddp_group = None
193
+
194
+ model = DistributedDataParallel(
195
+ model,
196
+ device_ids=[local_rank],
197
+ output_device=local_rank,
198
+ find_unused_parameters=config_ddp.find_unused_parameters,
199
+ static_graph=config_ddp.static_graph,
200
+ broadcast_buffers=config_ddp.broadcast_buffers,
201
+ process_group=ddp_group,
202
+ )
203
+ return model
204
+
205
+
206
+ class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
207
+ """This extends torch.nn.parallel.DistributedDataParallel with .training_step().
208
+
209
+ This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an ImaginaireModel such that
210
+ model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
211
+ model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
212
+ training_step), allowing us to preserve the function names and signatures.
213
+ """
214
+
215
+ def __init__(self, model: torch.nn.Module, *args, **kwargs):
216
+ super().__init__(model, *args, **kwargs)
217
+ self.show_sync_grad_static_graph_warning = True
218
+
219
+ def training_step(self, *args, **kwargs) -> Any:
220
+ # Cache the original model.forward() method.
221
+ original_forward = self.module.forward
222
+
223
+ def wrapped_training_step(*_args, **_kwargs):
224
+ # Unpatch immediately before calling training_step() because itself may want to call the real forward.
225
+ self.module.forward = original_forward
226
+ # The actual .training_step().
227
+ return self.module.training_step(*_args, **_kwargs)
228
+
229
+ # Patch the original_module's forward so we can redirect the arguments back to the real method.
230
+ self.module.forward = wrapped_training_step
231
+ # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
232
+ # Without calling self.forward() or model.forward() explicitly, implicit hooks are also executed.
233
+ return self(*args, **kwargs)
234
+
235
+
236
+ @contextmanager
237
+ def ddp_sync_grad(model, enabled):
238
+ r"""
239
+ Context manager to enable/disable gradient synchronizations across DDP processes for DDP model.
240
+ Modified from:
241
+ https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
242
+ Note that this is incompatible with static_graph=True and will be a no-op if static_graph=True.
243
+
244
+ Within this context, gradients will be accumulated on module
245
+ variables, which will later be synchronized in the first
246
+ forward-backward pass exiting the context.
247
+
248
+ .. warning::
249
+ The forward pass should be included inside the context manager, or
250
+ else gradients will still be synchronized.
251
+ """
252
+ assert isinstance(model, torch.nn.Module)
253
+ if isinstance(model, DistributedDataParallel):
254
+ old_require_backward_grad_sync = model.require_backward_grad_sync
255
+ if model.static_graph and model.require_backward_grad_sync != enabled:
256
+ if model.show_sync_grad_static_graph_warning:
257
+ log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
258
+ model.show_sync_grad_static_graph_warning = False
259
+ else:
260
+ model.require_backward_grad_sync = enabled
261
+ try:
262
+ yield
263
+ finally:
264
+ if isinstance(model, DistributedDataParallel):
265
+ model.require_backward_grad_sync = old_require_backward_grad_sync
266
+
267
+
268
+ def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
269
+ """Aggregate the list of data batches from all devices and process the results.
270
+
271
+ This is used for gathering validation data batches with imaginaire.utils.dataloader.DistributedEvalSampler.
272
+ It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
273
+ in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
274
+ created before calling dist.all_gather().
275
+
276
+ Args:
277
+ data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
278
+ leaf entries are tensors.
279
+
280
+ Returns:
281
+ data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
282
+ leaf entries are concatenated tensors.
283
+ """
284
+ if isinstance(data_batches[0], torch.Tensor):
285
+ # Concatenate the local data batches.
286
+ data_concat = torch.cat(data_batches, dim=0) # type: ignore
287
+ # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
288
+ max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
289
+ dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
290
+ if len(data_concat) < max_num_local_samples:
291
+ assert len(data_concat) + 1 == max_num_local_samples
292
+ dummy = torch.empty_like(data_concat[:1])
293
+ data_concat = torch.cat([data_concat, dummy], dim=0)
294
+ dummy_count = torch.tensor(1, device="cuda")
295
+ else:
296
+ dummy_count = torch.tensor(0, device="cuda")
297
+ # Get all concatenated batches from all ranks and concatenate again.
298
+ dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
299
+ data_concat = all_gather_tensor(data_concat.contiguous())
300
+ data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
301
+ # Remove the dummy samples.
302
+ if dummy_count > 0:
303
+ data_collate = data_collate[:-dummy_count]
304
+ elif isinstance(data_batches[0], collections.abc.Mapping):
305
+ data_collate = dict()
306
+ for key in data_batches[0].keys():
307
+ data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore
308
+ else:
309
+ raise TypeError
310
+ return data_collate
311
+
312
+
313
+ @torch.no_grad()
314
+ def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
315
+ """Gather the corresponding tensor from all GPU devices to a list.
316
+
317
+ Args:
318
+ tensor (torch.Tensor): Pytorch tensor.
319
+
320
+ Returns:
321
+ tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices.
322
+ """
323
+ tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
324
+ dist.all_gather(tensor_list, tensor)
325
+ return tensor_list
326
+
327
+
328
+ def broadcast(tensor, src, group=None, async_op=False):
329
+ world_size = get_world_size()
330
+ if world_size < 2:
331
+ return tensor
332
+ dist.broadcast(tensor, src=src, group=group, async_op=async_op)
333
+
334
+
335
+ def dist_reduce_tensor(tensor, rank=0, reduce="mean"):
336
+ r"""Reduce to rank 0"""
337
+ world_size = get_world_size()
338
+ if world_size < 2:
339
+ return tensor
340
+ with torch.no_grad():
341
+ dist.reduce(tensor, dst=rank)
342
+ if get_rank() == rank:
343
+ if reduce == "mean":
344
+ tensor /= world_size
345
+ elif reduce == "sum":
346
+ pass
347
+ else:
348
+ raise NotImplementedError
349
+ return tensor
350
+
351
+
352
+ def sync_model_states(
353
+ model: torch.nn.Module,
354
+ process_group: dist.ProcessGroup | None = None,
355
+ src: int = 0,
356
+ params_and_buffers_to_ignore: Container[str] | None = None,
357
+ broadcast_buffers: bool = True,
358
+ ):
359
+ """
360
+ Modified from the DDP source code.
361
+ Synchronizes the parameters and buffers of a model across different processes in a distributed setting.
362
+
363
+ This function ensures that all processes in the specified process group have the same initial parameters and
364
+ buffers from the source rank, typically rank 0. It is useful when different processes start with different model
365
+ states and a synchronization is required to ensure consistency across all ranks.
366
+
367
+ Args:
368
+ model (nn.Module): The model whose parameters and buffers are to be synchronized.
369
+ process_group (dist.ProcessGroup, optional): The process group for communication. If None,
370
+ the default group is used. Defaults to None.
371
+ src (int, optional): The source rank from which parameters and buffers will be broadcasted.
372
+ Defaults to 0.
373
+ params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer
374
+ names to exclude from synchronization. Defaults to None, which means all parameters and buffers are
375
+ included.
376
+ broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True.
377
+
378
+ Side Effects:
379
+ This function modifies the state of the model in-place to synchronize it with the source rank's model state.
380
+
381
+ Raises:
382
+ RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised.
383
+
384
+ Examples:
385
+ >>> # Download the model weights on rank 0 only to avoid duplicated downloads and save network bandwidth;
386
+ >>> # this is useful when the model weights are huge.
387
+ >>> if dist.get_rank() == 0:
388
+ >>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path))
389
+ >>> dist.barrier()
390
+ >>> sync_model_states(model) # sync rank0 weights to other ranks
391
+ """
392
+ if not dist.is_available() or not dist.is_initialized():
393
+ return
394
+ if process_group is None:
395
+ process_group = _get_default_group()
396
+ if not params_and_buffers_to_ignore:
397
+ params_and_buffers_to_ignore = set()
398
+
399
+ log.info(
400
+ f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}."
401
+ )
402
+
403
+ # Build tuple of (module, parameter) for all parameters that require grads.
404
+ modules_and_parameters = [
405
+ (module, parameter)
406
+ for module_name, module in model.named_modules()
407
+ for parameter in [
408
+ param
409
+ # Note that we access module.named_parameters instead of
410
+ # parameters(module). parameters(module) is only needed in the
411
+ # single-process multi device case, where it accesses replicated
412
+ # parameters through _former_parameters.
413
+ for param_name, param in module.named_parameters(recurse=False)
414
+ if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
415
+ # if param.requires_grad
416
+ # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
417
+ ]
418
+ ]
419
+
420
+ # Deduplicate any parameters that might be shared across child modules.
421
+ memo = set()
422
+ modules_and_parameters = [
423
+ # "p not in memo" is the deduplication check.
424
+ # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
425
+ (m, p)
426
+ for m, p in modules_and_parameters
427
+ if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
428
+ ]
429
+
430
+ # Build list of parameters.
431
+ parameters = [parameter for _, parameter in modules_and_parameters]
432
+ if len(parameters) == 0:
433
+ return
434
+
435
+ _verify_param_shape_across_processes(process_group, parameters)
436
+
437
+ _sync_module_states(
438
+ module=model,
439
+ process_group=process_group,
440
+ broadcast_bucket_size=(250 * 1024 * 1024),
441
+ src=src,
442
+ params_and_buffers_to_ignore=params_and_buffers_to_ignore,
443
+ broadcast_buffers=broadcast_buffers,
444
+ )
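A common use of the `ddp_sync_grad()` context manager defined above is gradient accumulation: the gradient all-reduce is skipped on every micro-batch except the last one. The sketch below is illustrative only; `compute_loss` and the micro-batch list are placeholders rather than APIs from this repository.

```python
# Hedged sketch of gradient accumulation with ddp_sync_grad().
def accumulate_step(ddp_model, optimizer, micro_batches, compute_loss):
    optimizer.zero_grad()
    num = len(micro_batches)
    for i, batch in enumerate(micro_batches):
        is_last = i == num - 1
        # Gradients are all-reduced only on the last micro-batch; earlier ones accumulate locally.
        with ddp_sync_grad(ddp_model, enabled=is_last):
            loss = compute_loss(ddp_model, batch) / num
            loss.backward()
    optimizer.step()
```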
imaginaire/utils/easy_io/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
imaginaire/utils/easy_io/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (144 Bytes). View file
 
imaginaire/utils/easy_io/__pycache__/easy_io.cpython-310.pyc ADDED
Binary file (28.6 kB). View file
 
imaginaire/utils/easy_io/__pycache__/file_client.cpython-310.pyc ADDED
Binary file (14.9 kB). View file
 
imaginaire/utils/easy_io/backends/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
17
+ from imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
18
+ from imaginaire.utils.easy_io.backends.local_backend import LocalBackend
19
+ from imaginaire.utils.easy_io.backends.registry_utils import backends, prefix_to_backends, register_backend
20
+
21
+ __all__ = [
22
+ "BaseStorageBackend",
23
+ "HTTPBackend",
24
+ "LocalBackend",
25
+ "backends",
26
+ "prefix_to_backends",
27
+ "register_backend",
28
+ ]
imaginaire/utils/easy_io/backends/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (596 Bytes). View file
 
imaginaire/utils/easy_io/backends/__pycache__/base_backend.cpython-310.pyc ADDED
Binary file (1.69 kB). View file
 
imaginaire/utils/easy_io/backends/__pycache__/http_backend.cpython-310.pyc ADDED
Binary file (2.9 kB). View file
 
imaginaire/utils/easy_io/backends/__pycache__/local_backend.cpython-310.pyc ADDED
Binary file (18.4 kB). View file
 
imaginaire/utils/easy_io/backends/__pycache__/registry_utils.cpython-310.pyc ADDED
Binary file (3.57 kB). View file