fffiloni commited on
Commit
3e6b0ce
1 Parent(s): 27a9419

we worked to handle gpu operations

Browse files
Files changed (1) hide show
  1. main.py +74 -45
main.py CHANGED
@@ -15,6 +15,36 @@ from rewards import get_reward_losses
15
  from training import LatentNoiseTrainer, get_optimizer
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def setup(args, loaded_model_setup=None):
19
  seed_everything(args.seed)
20
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
@@ -52,14 +82,10 @@ def setup(args, loaded_model_setup=None):
52
  os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
53
 
54
  device = torch.device("cuda")
55
- if args.dtype == "float32":
56
- dtype = torch.float32
57
- elif args.dtype == "float16":
58
- dtype = torch.float16
59
 
60
  # If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
61
  if loaded_model_setup and args.model == loaded_model_setup[0].model:
62
- # Reuse the trainer and pipe from the loaded model setup
63
  print(f"Reusing model {args.model} from loaded setup.")
64
  trainer = loaded_model_setup[1] # Trainer is at position 1 in loaded_model_setup
65
 
@@ -97,10 +123,13 @@ def setup(args, loaded_model_setup=None):
97
  width // trainer.model.vae_scale_factor,
98
  )
99
 
100
- multi_apply_fn = loaded_model_setup[6]
101
  enable_grad = not args.no_optim
102
 
103
- return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
 
 
 
104
 
105
  # Proceed with full model loading if args.model is different
106
  print(f"Loading new model: {args.model}")
@@ -113,27 +142,8 @@ def setup(args, loaded_model_setup=None):
113
  args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
114
  )
115
 
116
- # Attempt to move the model to GPU or keep it on CPU if offloading is enabled
117
- try:
118
- if not args.cpu_offloading:
119
- pipe.to(device)
120
- except RuntimeError as e:
121
- if 'out of memory' in str(e):
122
- print("CUDA OOM error. Attempting to handle OOM situation.")
123
- # Attempt to clear memory and retry moving to GPU
124
- torch.cuda.empty_cache() # Free up cached memory
125
- gc.collect()
126
- try:
127
- # Retry loading after clearing cache
128
- if not args.cpu_offloading:
129
- pipe.to(device)
130
- except RuntimeError as e:
131
- print("Still facing OOM issues. Keeping model on CPU.")
132
- args.cpu_offloading = True # Force CPU offloading
133
- else:
134
- raise e # Re-raise the exception if it's not OOM
135
-
136
- torch.cuda.empty_cache() # Free up cached memory
137
  gc.collect()
138
 
139
  trainer = LatentNoiseTrainer(
@@ -180,28 +190,47 @@ def setup(args, loaded_model_setup=None):
180
  torch.cuda.empty_cache() # Free up cached memory
181
  gc.collect()
182
 
183
- if args.enable_multi_apply:
184
- multi_apply_fn = get_multi_apply_fn(
185
- model_type=args.multi_step_model,
186
- seed=args.seed,
187
- pipe=pipe,
188
- cache_dir=args.cache_dir,
189
- device=device if not args.cpu_offloading else 'cpu',
190
- dtype=dtype,
191
- )
192
- else:
193
- multi_apply_fn = None
194
 
195
- torch.cuda.empty_cache() # Free up cached memory
196
- gc.collect()
197
 
198
- return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
199
 
200
 
201
 
202
- def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback=None):
203
 
204
  if args.task == "single":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  init_latents = torch.randn(shape, device=device, dtype=dtype)
206
  latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
207
  optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
@@ -383,8 +412,8 @@ def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_f
383
 
384
  def main():
385
  args = parse_args()
386
- args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings = setup(args, loaded_model_setup=None)
387
- execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings)
388
 
389
  if __name__ == "__main__":
390
  main()
 
15
  from training import LatentNoiseTrainer, get_optimizer
16
 
17
 
18
+ import torch
19
+ import gc
20
+
21
+ def clear_gpu():
22
+ """Clear GPU memory by removing tensors, freeing cache, and moving data to CPU."""
23
+ # List memory usage before clearing
24
+ print(f"Memory allocated before clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
25
+ print(f"Memory reserved before clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
26
+
27
+ # Force the garbage collector to free unreferenced objects
28
+ gc.collect()
29
+
30
+ # Move any bound tensors back to CPU if needed
31
+ if torch.cuda.is_available():
32
+ torch.cuda.empty_cache() # Free up the cached memory
33
+ torch.cuda.ipc_collect() # Clear any cross-process memory
34
+
35
+ print(f"Memory allocated after clearing: {torch.cuda.memory_allocated() / (1024 ** 2)} MB")
36
+ print(f"Memory reserved after clearing: {torch.cuda.memory_reserved() / (1024 ** 2)} MB")
37
+
38
+ def unload_previous_model_if_needed(loaded_model_setup):
39
+ """Unload the current model from the GPU and free resources if a new model is being loaded."""
40
+ if loaded_model_setup is not None:
41
+ print("Unloading previous model from GPU to free memory.")
42
+ previous_model = loaded_model_setup[7] # Assuming pipe is at position [7] in the setup
43
+ if hasattr(previous_model, 'to') and loaded_model_setup[0].model != "flux":
44
+ previous_model.to('cpu') # Move model to CPU to free GPU memory
45
+ del previous_model # Delete the reference to the model
46
+ clear_gpu() # Clear all remaining GPU memory
47
+
48
  def setup(args, loaded_model_setup=None):
49
  seed_everything(args.seed)
50
  bf.makedirs(f"{args.save_dir}/logs/{args.task}")
 
82
  os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
83
 
84
  device = torch.device("cuda")
85
+ dtype = torch.float16 if args.dtype == "float16" else torch.float32
 
 
 
86
 
87
  # If args.model is the same as the one in loaded_model_setup, reuse the trainer and pipe
88
  if loaded_model_setup and args.model == loaded_model_setup[0].model:
 
89
  print(f"Reusing model {args.model} from loaded setup.")
90
  trainer = loaded_model_setup[1] # Trainer is at position 1 in loaded_model_setup
91
 
 
123
  width // trainer.model.vae_scale_factor,
124
  )
125
 
126
+ pipe = loaded_model_setup[7]
127
  enable_grad = not args.no_optim
128
 
129
+ return args, trainer, device, dtype, shape, enable_grad, settings, pipe
130
+
131
+ # Unload previous model and clear GPU resources
132
+ unload_previous_model_if_needed(loaded_model_setup)
133
 
134
  # Proceed with full model loading if args.model is different
135
  print(f"Loading new model: {args.model}")
 
142
  args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
143
  )
144
 
145
+ # Final memory cleanup after model loading
146
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  gc.collect()
148
 
149
  trainer = LatentNoiseTrainer(
 
190
  torch.cuda.empty_cache() # Free up cached memory
191
  gc.collect()
192
 
193
+
 
 
 
 
 
 
 
 
 
 
194
 
195
+ return args, trainer, device, dtype, shape, enable_grad, settings, pipe
 
196
 
 
197
 
198
 
199
 
200
+ def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback=None):
201
 
202
  if args.task == "single":
203
+ # Attempt to move the model to GPU if model is not Flux
204
+ if args.model != "flux":
205
+ if pipe.device != torch.device('cuda'):
206
+ pipe.to(device, dtype)
207
+ else:
208
+ print(f"PIPE:{pipe}")
209
+
210
+
211
+ if args.cpu_offloading:
212
+ pipe.enable_sequential_cpu_offload()
213
+
214
+ #if pipe.device != torch.device('cuda'):
215
+ # pipe.to(device, dtype)
216
+
217
+ if args.enable_multi_apply:
218
+
219
+ multi_apply_fn = get_multi_apply_fn(
220
+ model_type=args.multi_step_model,
221
+ seed=args.seed,
222
+ pipe=pipe,
223
+ cache_dir=args.cache_dir,
224
+ device=device if not args.cpu_offloading else 'cpu',
225
+ dtype=dtype,
226
+ )
227
+ else:
228
+ multi_apply_fn = None
229
+
230
+ torch.cuda.empty_cache() # Free up cached memory
231
+ gc.collect()
232
+
233
+
234
  init_latents = torch.randn(shape, device=device, dtype=dtype)
235
  latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
236
  optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
 
412
 
413
  def main():
414
  args = parse_args()
415
+ args, trainer, device, dtype, shape, enable_grad, settings, pipe = setup(args, loaded_model_setup=None)
416
+ execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe)
417
 
418
  if __name__ == "__main__":
419
  main()