fffiloni committed
Commit 38ce166
1 Parent(s): c839ac6

we can load model and handle OOM errors on model switch

Files changed (1)
  1. main.py +27 -7
main.py CHANGED
@@ -112,7 +112,27 @@ def setup(args, loaded_model_setup=None):
     pipe = get_model(
         args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
     )
-
+
+    # Attempt to move the model to GPU or keep it on CPU if offloading is enabled
+    try:
+        if not args.cpu_offloading:
+            pipe.to(device)
+    except RuntimeError as e:
+        if 'out of memory' in str(e):
+            print("CUDA OOM error. Attempting to handle OOM situation.")
+            # Attempt to clear memory and retry moving to GPU
+            torch.cuda.empty_cache()  # Free up cached memory
+            gc.collect()
+            try:
+                # Retry loading after clearing cache
+                if not args.cpu_offloading:
+                    pipe.to(device)
+            except RuntimeError as e:
+                print("Still facing OOM issues. Keeping model on CPU.")
+                args.cpu_offloading = True  # Force CPU offloading
+        else:
+            raise e  # Re-raise the exception if it's not OOM
+
     torch.cuda.empty_cache()  # Free up cached memory
     gc.collect()
 
@@ -123,7 +143,7 @@ def setup(args, loaded_model_setup=None):
         n_inference_steps=args.n_inference_steps,
         seed=args.seed,
         save_all_images=args.save_all_images,
-        device=device,
+        device=device if not args.cpu_offloading else 'cpu',  # Use CPU if offloading is enabled
         no_optim=args.no_optim,
         regularize=args.enable_reg,
         regularization_weight=args.reg_weight,
@@ -132,9 +152,6 @@ def setup(args, loaded_model_setup=None):
         imageselect=args.imageselect,
     )
 
-    torch.cuda.empty_cache()  # Free up cached memory
-    gc.collect()
-
     # Create latents
     if args.model == "flux":
         shape = (1, 16 * 64, 64)
@@ -159,16 +176,17 @@ def setup(args, loaded_model_setup=None):
 
     enable_grad = not args.no_optim
 
+    # Final memory cleanup
     torch.cuda.empty_cache()  # Free up cached memory
     gc.collect()
-
+
     if args.enable_multi_apply:
         multi_apply_fn = get_multi_apply_fn(
             model_type=args.multi_step_model,
             seed=args.seed,
             pipe=pipe,
             cache_dir=args.cache_dir,
-            device=device,
+            device=device if not args.cpu_offloading else 'cpu',
             dtype=dtype,
         )
     else:
@@ -179,6 +197,8 @@ def setup(args, loaded_model_setup=None):
 
     return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
 
+
+
 def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback=None):
 
     if args.task == "single":
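
Note: the change above boils down to a try / clear-cache / retry / fall-back-to-CPU pattern around pipe.to(device). Below is a minimal standalone sketch of that pattern, assuming a diffusers-style pipeline object that exposes .to(device); the function name move_pipe_with_oom_fallback and its boolean return value are illustrative only and are not part of main.py.

import gc

import torch


def move_pipe_with_oom_fallback(pipe, device, cpu_offloading=False):
    """Try to place `pipe` on `device`; fall back to CPU on CUDA OOM.

    Returns the (possibly updated) cpu_offloading flag.
    """
    if cpu_offloading:
        # Offloading already requested: leave the model on CPU.
        return True
    try:
        pipe.to(device)
        return False
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise  # unrelated error, propagate it
        # Free cached allocations and retry once.
        torch.cuda.empty_cache()
        gc.collect()
        try:
            pipe.to(device)
            return False
        except RuntimeError:
            # Still OOM after clearing the cache: keep the model on CPU.
            return True

In setup() this corresponds to forcing args.cpu_offloading = True when the retry also fails; the downstream `device=device if not args.cpu_offloading else 'cpu'` arguments then keep the trainer and get_multi_apply_fn on CPU so the run can continue instead of crashing.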