we can load the model and handle OOM errors on model switch
main.py CHANGED
@@ -112,7 +112,27 @@ def setup(args, loaded_model_setup=None):
     pipe = get_model(
         args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
     )
-
+
+    # Attempt to move the model to GPU or keep it on CPU if offloading is enabled
+    try:
+        if not args.cpu_offloading:
+            pipe.to(device)
+    except RuntimeError as e:
+        if 'out of memory' in str(e):
+            print("CUDA OOM error. Attempting to handle OOM situation.")
+            # Attempt to clear memory and retry moving to GPU
+            torch.cuda.empty_cache()  # Free up cached memory
+            gc.collect()
+            try:
+                # Retry loading after clearing cache
+                if not args.cpu_offloading:
+                    pipe.to(device)
+            except RuntimeError as e:
+                print("Still facing OOM issues. Keeping model on CPU.")
+                args.cpu_offloading = True  # Force CPU offloading
+        else:
+            raise e  # Re-raise the exception if it's not OOM
+
     torch.cuda.empty_cache()  # Free up cached memory
     gc.collect()
 
@@ -123,7 +143,7 @@ def setup(args, loaded_model_setup=None):
         n_inference_steps=args.n_inference_steps,
         seed=args.seed,
         save_all_images=args.save_all_images,
-        device=device,
+        device=device if not args.cpu_offloading else 'cpu',  # Use CPU if offloading is enabled
         no_optim=args.no_optim,
         regularize=args.enable_reg,
         regularization_weight=args.reg_weight,
@@ -132,9 +152,6 @@ def setup(args, loaded_model_setup=None):
         imageselect=args.imageselect,
     )
 
-    torch.cuda.empty_cache()  # Free up cached memory
-    gc.collect()
-
     # Create latents
     if args.model == "flux":
         shape = (1, 16 * 64, 64)
@@ -159,16 +176,17 @@ def setup(args, loaded_model_setup=None):
 
     enable_grad = not args.no_optim
 
+    # Final memory cleanup
     torch.cuda.empty_cache()  # Free up cached memory
     gc.collect()
-
+
     if args.enable_multi_apply:
         multi_apply_fn = get_multi_apply_fn(
             model_type=args.multi_step_model,
             seed=args.seed,
             pipe=pipe,
             cache_dir=args.cache_dir,
-            device=device,
+            device=device if not args.cpu_offloading else 'cpu',
             dtype=dtype,
         )
     else:
@@ -179,6 +197,8 @@ def setup(args, loaded_model_setup=None):
 
     return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
 
+
+
 def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback=None):
 
     if args.task == "single":