DegMaTsu committed on
Commit
f44e53c
·
verified ·
1 Parent(s): 9fb17f6

Upload model_management.py

Browse files
Files changed (1) hide show
  1. comfy/model_management.py +1433 -1430
comfy/model_management.py CHANGED
@@ -1,1430 +1,1433 @@
1
- """
2
- This file is part of ComfyUI.
3
- Copyright (C) 2024 Comfy
4
-
5
- This program is free software: you can redistribute it and/or modify
6
- it under the terms of the GNU General Public License as published by
7
- the Free Software Foundation, either version 3 of the License, or
8
- (at your option) any later version.
9
-
10
- This program is distributed in the hope that it will be useful,
11
- but WITHOUT ANY WARRANTY; without even the implied warranty of
12
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
- GNU General Public License for more details.
14
-
15
- You should have received a copy of the GNU General Public License
16
- along with this program. If not, see <https://www.gnu.org/licenses/>.
17
- """
18
-
19
- import psutil
20
- import logging
21
- from enum import Enum
22
- from comfy.cli_args import args, PerformanceFeature
23
- import torch
24
- import sys
25
- import importlib
26
- import platform
27
- import weakref
28
- import gc
29
-
30
class VRAMState(Enum):
    """VRAM handling modes used throughout this module."""

    # No vram present: no need to move models to vram
    DISABLED = 0
    # Very low vram: enable all the options to save vram
    NO_VRAM = 1
    LOW_VRAM = 2
    NORMAL_VRAM = 3
    HIGH_VRAM = 4
    # No dedicated vram: memory shared between CPU and GPU but models still
    # need to be moved between both.
    SHARED = 5
37
-
38
class CPUState(Enum):
    """Which kind of compute backend the process runs on."""

    GPU = 0
    CPU = 1
    MPS = 2
42
-
43
# Determine VRAM State
# Initial values; the module-level detection code further down overrides them.
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
cpu_state = CPUState.GPU

# Recomputed (in MB) once the torch device is known.
total_vram = 0
49
-
50
def get_supported_float8_types():
    """Return the float8 dtypes defined by the installed torch build.

    Each candidate is looked up by name on the ``torch`` module and skipped
    when the running build does not define it. The returned list keeps the
    order of the candidate names below.
    """
    candidate_names = (
        "float8_e4m3fn",
        "float8_e4m3fnuz",
        "float8_e5m2",
        "float8_e5m2fnuz",
        "float8_e8m0fnu",
    )
    float8_types = []
    for name in candidate_names:
        # getattr with a default replaces five separate bare try/except blocks.
        dtype = getattr(torch, name, None)
        if dtype is not None:
            float8_types.append(dtype)
    return float8_types
73
-
74
FLOAT8_TYPES = get_supported_float8_types()

xpu_available = False
torch_version = ""
# Fallback so later comparisons against torch_version_numeric cannot raise
# NameError when the version string is missing or cannot be parsed.
torch_version_numeric = (0, 0)
try:
    torch_version = torch.version.__version__
    temp = torch_version.split(".")
    torch_version_numeric = (int(temp[0]), int(temp[1]))
except Exception:
    pass

lowvram_available = True
if args.deterministic:
    logging.info("Using deterministic algorithms for pytorch")
    torch.use_deterministic_algorithms(True, warn_only=True)

directml_enabled = False
if args.directml is not None:
    import torch_directml
    directml_enabled = True
    device_index = args.directml
    # A negative index selects torch_directml's default device.
    if device_index < 0:
        directml_device = torch_directml.device()
    else:
        directml_device = torch_directml.device(device_index)
    logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
    # torch_directml.disable_tiled_resources(True)
    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

# Optional backends: each probe is best-effort, a failure only means the
# backend is treated as unavailable.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
except Exception:
    pass

try:
    _ = torch.xpu.device_count()
    xpu_available = torch.xpu.is_available()
except Exception:
    xpu_available = False

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps
except Exception:
    pass

try:
    import torch_npu  # noqa: F401
    _ = torch.npu.device_count()
    npu_available = torch.npu.is_available()
except Exception:
    npu_available = False

try:
    import torch_mlu  # noqa: F401
    _ = torch.mlu.device_count()
    mlu_available = torch.mlu.is_available()
except Exception:
    mlu_available = False

try:
    ixuca_available = hasattr(torch, "corex")
except Exception:
    ixuca_available = False

# --cpu overrides whatever backend was detected above.
if args.cpu:
    cpu_state = CPUState.CPU
142
-
143
def is_intel_xpu():
    """Return True when the GPU path is active and an XPU was detected."""
    return bool(cpu_state == CPUState.GPU and xpu_available)
150
-
151
def is_ascend_npu():
    """Return True when the module-level ``npu_available`` flag is set."""
    return bool(npu_available)
156
-
157
def is_mlu():
    """Return True when the module-level ``mlu_available`` flag is set."""
    return bool(mlu_available)
162
-
163
def is_ixuca():
    """Return True when the module-level ``ixuca_available`` flag is set."""
    return bool(ixuca_available)
168
-
169
def get_torch_device():
    """Return the device object models should be loaded onto.

    Checked in order: DirectML, MPS, CPU, then Intel XPU, Ascend NPU, MLU,
    and finally the current CUDA device.
    """
    if directml_enabled:
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")

    if is_intel_xpu():
        return torch.device("xpu", torch.xpu.current_device())
    if is_ascend_npu():
        return torch.device("npu", torch.npu.current_device())
    if is_mlu():
        return torch.device("mlu", torch.mlu.current_device())
    return torch.device(torch.cuda.current_device())
188
-
189
def get_total_memory(dev=None, torch_total_too=False):
    """Return the total memory of `dev` in bytes.

    `dev` defaults to get_torch_device(). With `torch_total_too` set, a tuple
    ``(total, torch_total)`` is returned, where ``torch_total`` is the
    allocator's ``reserved_bytes.all.current`` figure for accelerator backends
    and equal to ``total`` for cpu, mps and DirectML.
    """
    if dev is None:
        dev = get_torch_device()

    if getattr(dev, 'type', None) in ('cpu', 'mps'):
        mem_total = mem_total_torch = psutil.virtual_memory().total
    elif directml_enabled:
        mem_total = mem_total_torch = 1024 * 1024 * 1024 #TODO
    elif is_intel_xpu():
        mem_total_torch = torch.xpu.memory_stats(dev)['reserved_bytes.all.current']
        mem_total = torch.xpu.get_device_properties(dev).total_memory
    else:
        # npu, mlu and cuda expose the same memory_stats / mem_get_info pair.
        if is_ascend_npu():
            backend = torch.npu
        elif is_mlu():
            backend = torch.mlu
        else:
            backend = torch.cuda
        mem_total_torch = backend.memory_stats(dev)['reserved_bytes.all.current']
        _, mem_total = backend.mem_get_info(dev)

    return (mem_total, mem_total_torch) if torch_total_too else mem_total
230
-
231
def mac_version():
    """Return the macOS version as a tuple of ints, or None.

    None is returned when ``platform.mac_ver()`` yields a release string that
    cannot be parsed as dot-separated integers (for example an empty string).
    """
    try:
        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
    except Exception:
        # Best-effort probe; unlike a bare except this lets
        # KeyboardInterrupt/SystemExit propagate.
        return None
236
-
237
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))

# Informational logging only; never allowed to abort the import.
try:
    logging.info("pytorch version: {}".format(torch_version))
    mac_ver = mac_version()
    if mac_ver is not None:
        logging.info("Mac Version {}".format(mac_ver))
except Exception:
    pass

# Fall back to the generic Exception when torch.cuda.OutOfMemoryError cannot be resolved.
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except Exception:
    OOM_EXCEPTION = Exception

XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
    XFORMERS_IS_AVAILABLE = False
else:
    try:
        import xformers
        import xformers.ops
        XFORMERS_IS_AVAILABLE = True
        try:
            XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
        except Exception:
            pass
        try:
            XFORMERS_VERSION = xformers.version.__version__
            logging.info("xformers version: {}".format(XFORMERS_VERSION))
            if XFORMERS_VERSION.startswith("0.0.18"):
                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
                XFORMERS_ENABLED_VAE = False
        except Exception:
            pass
    except Exception:
        # Any failure importing xformers simply disables it.
        XFORMERS_IS_AVAILABLE = False
278
-
279
def is_nvidia():
    """Return True when the GPU path is active and torch reports a CUDA version."""
    return bool(cpu_state == CPUState.GPU and torch.version.cuda)
285
-
286
def is_amd():
    """Return True when the GPU path is active and torch reports a HIP version."""
    return bool(cpu_state == CPUState.GPU and torch.version.hip)
292
-
293
def amd_min_version(device=None, min_rdna_version=0):
    """Return True if `device` is an AMD GPU of at least `min_rdna_version`.

    The version is derived from ``gcnArchName``: only names of exactly seven
    characters starting with ``gfx`` are considered, and the character at
    index 4 plus 2 is compared against `min_rdna_version`. Any other arch
    name, a non-AMD setup or a CPU device yields False.
    """
    if not is_amd() or is_device_cpu(device):
        return False

    arch = torch.cuda.get_device_properties(device).gcnArchName
    if not (arch.startswith('gfx') and len(arch) == 7):
        return False

    try:
        cmp_rdna_version = int(arch[4]) + 2
    except ValueError:
        # Non-numeric character at that position: treat as version 0.
        cmp_rdna_version = 0
    return cmp_rdna_version >= min_rdna_version
310
-
311
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
    MIN_WEIGHT_MEMORY_RATIO = 0.0

ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
    ENABLE_PYTORCH_ATTENTION = True
    XFORMERS_IS_AVAILABLE = False

# Enable pytorch attention by default where no other attention mode was requested.
try:
    if is_nvidia():
        if torch_version_numeric[0] >= 2:
            if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                ENABLE_PYTORCH_ATTENTION = True
    if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            ENABLE_PYTORCH_ATTENTION = True
except Exception:
    pass


SUPPORT_FP8_OPS = args.supports_fp8_compute
try:
    if is_amd():
        try:
            rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
        except Exception:
            rocm_version = (6, -1)
        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
        logging.info("AMD arch: {}".format(arch))
        logging.info("ROCm version: {}".format(rocm_version))
        if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
            if importlib.util.find_spec('triton') is not None:  # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
                if torch_version_numeric >= (2, 7):  # works on 2.6 but doesn't actually seem to improve much
                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]):  # TODO: more arches, TODO: gfx950
                        ENABLE_PYTORCH_ATTENTION = True
#                if torch_version_numeric >= (2, 8):
#                    if any((a in arch) for a in ["gfx1201"]):
#                        ENABLE_PYTORCH_ATTENTION = True
        if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
            if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]):  # TODO: more arches
                SUPPORT_FP8_OPS = True

except Exception:
    pass


if ENABLE_PYTORCH_ATTENTION:
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)


PRIORITIZE_FP16 = False  # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
        torch.backends.cuda.matmul.allow_fp16_accumulation = True
        PRIORITIZE_FP16 = True  # TODO: limit to cards where it actually boosts performance
        logging.info("Enabled fp16 accumulation.")
except Exception:
    pass

try:
    if torch_version_numeric >= (2, 5):
        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except Exception:
    logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")

# Command-line overrides for the VRAM mode.
if args.lowvram:
    set_vram_to = VRAMState.LOW_VRAM
    lowvram_available = True
elif args.novram:
    set_vram_to = VRAMState.NO_VRAM
elif args.highvram or args.gpu_only:
    vram_state = VRAMState.HIGH_VRAM

FORCE_FP32 = False
if args.force_fp32:
    logging.info("Forcing FP32, if this improves things please report it.")
    FORCE_FP32 = True

if lowvram_available:
    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
        vram_state = set_vram_to


if cpu_state != CPUState.GPU:
    vram_state = VRAMState.DISABLED

# MPS is checked after the generic non-GPU case so it ends up as SHARED.
if cpu_state == CPUState.MPS:
    vram_state = VRAMState.SHARED

logging.info(f"Set vram state to: {vram_state.name}")

DISABLE_SMART_MEMORY = args.disable_smart_memory

if DISABLE_SMART_MEMORY:
    logging.info("Disabling smart memory management")
409
-
410
def get_torch_device_name(device):
    """Return a human-readable description of `device` for logging.

    Objects with a ``type`` attribute are described from that type (cuda and
    xpu get the device name appended, others only the type string). Anything
    else is described through whichever backend is currently active.
    """
    if hasattr(device, 'type'):
        if device.type == "cuda":
            try:
                allocator_backend = torch.cuda.get_allocator_backend()
            except Exception:
                # Allocator backend is optional detail; leave it blank on failure.
                allocator_backend = ""
            return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
        if device.type == "xpu":
            return "{} {}".format(device, torch.xpu.get_device_name(device))
        return "{}".format(device.type)

    if is_intel_xpu():
        return "{} {}".format(device, torch.xpu.get_device_name(device))
    if is_ascend_npu():
        return "{} {}".format(device, torch.npu.get_device_name(device))
    if is_mlu():
        return "{} {}".format(device, torch.mlu.get_device_name(device))
    return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
430
-
431
try:
    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except Exception:
    logging.warning("Could not pick default device.")


# LoadedModel records for everything currently tracked; new entries are
# inserted at index 0 by load_models_gpu().
current_loaded_models = []
438
-
439
def module_size(module):
    """Return the byte size of every tensor in ``module.state_dict()`` combined."""
    return sum(
        tensor.nelement() * tensor.element_size()
        for tensor in module.state_dict().values()
    )
446
-
447
class LoadedModel:
    """Bookkeeping record for one model patcher tracked by this module.

    The patcher (and, once loaded, the module returned by model_load) is held
    through weak references only, so this record does not keep them alive.
    """

    def __init__(self, model):
        # Create the finalizer slots first: _set_model() may install a patcher
        # finalizer, and assigning None afterwards would throw that handle away.
        self._patcher_finalizer = None
        self.model_finalizer = None
        self._set_model(model)
        self.device = model.load_device
        self.real_model = None
        self.currently_used = True

    def _set_model(self, model):
        """Point this record at `model` and remember its parent, if any."""
        self._model = weakref.ref(model)
        if model.parent is not None:
            self._parent_model = weakref.ref(model.parent)
            # When `model` is collected, switch over to its parent.
            self._patcher_finalizer = weakref.finalize(model, self._switch_parent)

    def _switch_parent(self):
        """Finalizer callback: re-target this record to the parent patcher."""
        model = self._parent_model()
        if model is not None:
            self._set_model(model)

    @property
    def model(self):
        """The tracked patcher, or None once it has been collected."""
        return self._model()

    def model_memory(self):
        return self.model.model_size()

    def model_loaded_memory(self):
        return self.model.loaded_size()

    def model_offloaded_memory(self):
        return self.model.model_size() - self.model.loaded_size()

    def model_memory_required(self, device):
        """Memory still needed on `device`: only the offloaded part if already there."""
        if device == self.model.current_loaded_device():
            return self.model_offloaded_memory()
        else:
            return self.model_memory()

    def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
        """Load the model onto ``self.device`` and return the underlying module.

        A `lowvram_model_memory` of 0 means "no limit" and is replaced by 1e32
        before being passed on as the memory budget.
        """
        self.model.model_patches_to(self.device)
        self.model.model_patches_to(self.model.model_dtype())

        # if self.model.loaded_size() > 0:
        use_more_vram = lowvram_model_memory
        if use_more_vram == 0:
            use_more_vram = 1e32
        self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
        real_model = self.model.model

        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
            with torch.no_grad():
                real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)

        self.real_model = weakref.ref(real_model)
        # Prune dead records from current_loaded_models once the module is collected.
        self.model_finalizer = weakref.finalize(real_model, cleanup_models)
        return real_model

    def should_reload_model(self, force_patch_weights=False):
        if force_patch_weights and self.model.lowvram_patch_counter() > 0:
            return True
        return False

    def model_unload(self, memory_to_free=None, unpatch_weights=True):
        """Unload the model; return False if a partial unload was sufficient.

        When `memory_to_free` is smaller than what is loaded, a partial unload
        is tried first and the model stays tracked if it freed enough.
        """
        if memory_to_free is not None:
            if memory_to_free < self.model.loaded_size():
                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
                if freed >= memory_to_free:
                    return False
        self.model.detach(unpatch_weights)
        # model_finalizer is only set by model_load(); guard against unloading
        # a record that was never loaded.
        if self.model_finalizer is not None:
            self.model_finalizer.detach()
        self.model_finalizer = None
        self.real_model = None
        return True

    def model_use_more_vram(self, extra_memory, force_patch_weights=False):
        return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)

    def __eq__(self, other):
        # Two records are equal when they track the very same patcher object.
        if not isinstance(other, LoadedModel):
            return NotImplemented
        return self.model is other.model

    def __del__(self):
        if self._patcher_finalizer is not None:
            self._patcher_finalizer.detach()

    def is_dead(self):
        """True when the loaded module is still alive but its patcher is gone."""
        # real_model is None (not a weakref) before model_load / after model_unload.
        real_model = self.real_model() if self.real_model is not None else None
        return real_model is not None and self.model is None
534
-
535
-
536
def use_more_memory(extra_memory, loaded_models, device):
    """Hand `extra_memory` to the models on `device`, in list order.

    Each matching model is offered the remaining budget; whatever its
    model_use_more_vram() call returns is subtracted, and the walk stops once
    nothing is left.
    """
    remaining = extra_memory
    for loaded in loaded_models:
        if loaded.device == device:
            remaining -= loaded.model_use_more_vram(remaining)
            if remaining <= 0:
                return
542
-
543
def offloaded_memory(loaded_models, device):
    """Sum model_offloaded_memory() over the loaded models that live on `device`."""
    return sum(
        loaded.model_offloaded_memory()
        for loaded in loaded_models
        if loaded.device == device
    )
549
-
550
WINDOWS = any(platform.win32_ver())

# Bytes of VRAM left untouched for other applications.
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
    # Windows is higher because of the shared vram issue
    EXTRA_RESERVED_VRAM = 600 * 1024 * 1024
    # more extra reserved vram on 16GB+ cards
    if total_vram > (15 * 1024):
        EXTRA_RESERVED_VRAM += 100 * 1024 * 1024

# --reserve-vram is given in GB and replaces the defaults above.
if args.reserve_vram is not None:
    EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
    logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
561
-
562
def extra_reserved_memory():
    """Return the module-level EXTRA_RESERVED_VRAM value (bytes)."""
    reserved = EXTRA_RESERVED_VRAM
    return reserved
564
-
565
def minimum_inference_memory():
    """Return 0.8 GiB plus the extra reserved memory, in bytes."""
    base = (1024 * 1024 * 1024) * 0.8
    return base + extra_reserved_memory()
567
-
568
def free_memory(memory_required, device, keep_loaded=()):
    """Unload models from `device` until `memory_required` bytes are free.

    Models listed in `keep_loaded` and dead records are never touched. Every
    other model on `device` is marked as not currently used and becomes an
    unload candidate. Returns the list of LoadedModel records that were fully
    unloaded and removed from ``current_loaded_models``.
    """
    cleanup_models_gc()
    unloaded_indices = []
    can_unload = []
    unloaded_models = []

    for i in range(len(current_loaded_models) - 1, -1, -1):
        shift_model = current_loaded_models[i]
        if shift_model.device == device:
            if shift_model not in keep_loaded and not shift_model.is_dead():
                # Sort key: (-offloaded memory, refcount, size), with the list index carried last.
                can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
                shift_model.currently_used = False

    for x in sorted(can_unload):
        i = x[-1]
        memory_to_free = None
        if not DISABLE_SMART_MEMORY:
            free_mem = get_free_memory(device)
            if free_mem > memory_required:
                break
            memory_to_free = memory_required - free_mem
        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
        if current_loaded_models[i].model_unload(memory_to_free):
            unloaded_indices.append(i)

    # Pop from the highest index down so the remaining indices stay valid.
    for i in sorted(unloaded_indices, reverse=True):
        unloaded_models.append(current_loaded_models.pop(i))

    if len(unloaded_indices) > 0:
        soft_empty_cache()
    else:
        if vram_state != VRAMState.HIGH_VRAM:
            mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
            # Nothing was unloaded: still empty the cache when torch's own
            # free figure exceeds a quarter of the total free memory.
            if mem_free_torch > mem_free_total * 0.25:
                soft_empty_cache()
    return unloaded_models
604
-
605
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    """Load `models` (and the models their patches depend on) onto their devices.

    Existing records in ``current_loaded_models`` are reused, clones of the
    requested models are dropped from the list, memory is freed on every
    non-CPU target device, and each model is then loaded with a memory budget
    derived from the current VRAM state. Loaded records are inserted at the
    front of ``current_loaded_models``. Returns None.
    """
    cleanup_models_gc()

    inference_memory = minimum_inference_memory()
    extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
    if minimum_memory_required is None:
        minimum_memory_required = extra_mem
    else:
        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())

    # Expand the request with the models referenced by each model's patches.
    models_temp = set()
    for m in models:
        models_temp.add(m)
        for mm in m.model_patches_models():
            models_temp.add(mm)

    models = models_temp

    models_to_load = []

    for x in models:
        loaded_model = LoadedModel(x)
        try:
            loaded_model_index = current_loaded_models.index(loaded_model)
        except ValueError:
            # Not tracked yet.
            loaded_model_index = None

        if loaded_model_index is not None:
            loaded = current_loaded_models[loaded_model_index]
            loaded.currently_used = True
            models_to_load.append(loaded)
        else:
            if hasattr(x, "model"):
                logging.info(f"Requested to load {x.model.__class__.__name__}")
            models_to_load.append(loaded_model)

    for loaded_model in models_to_load:
        clone_indices = [
            i for i in range(len(current_loaded_models))
            if loaded_model.model.is_clone(current_loaded_models[i].model)
        ]
        # Pop from the highest index down so earlier indices stay valid.
        for i in reversed(clone_indices):
            current_loaded_models.pop(i).model.detach(unpatch_all=False)

    total_memory_required = {}
    for loaded_model in models_to_load:
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_mem = get_free_memory(device)
            if free_mem < minimum_memory_required:
                models_l = free_memory(minimum_memory_required, device)
                logging.info("{} models unloaded.".format(len(models_l)))

    for loaded_model in models_to_load:
        model = loaded_model.model
        torch_dev = model.load_device
        if is_device_cpu(torch_dev):
            vram_set_state = VRAMState.DISABLED
        else:
            vram_set_state = vram_state
        # 0 means "no limit" to LoadedModel.model_load().
        lowvram_model_memory = 0
        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
            loaded_memory = loaded_model.model_loaded_memory()
            current_free_mem = get_free_memory(torch_dev) + loaded_memory

            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)

        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 0.1

        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
        current_loaded_models.insert(0, loaded_model)
    return
686
-
687
def load_model_gpu(model):
    """Single-model convenience wrapper around load_models_gpu()."""
    single = [model]
    return load_models_gpu(single)
689
-
690
def loaded_models(only_currently_used=False):
    """Return the patchers of all tracked models.

    With `only_currently_used` set, records whose ``currently_used`` flag is
    false are left out.
    """
    return [
        record.model
        for record in current_loaded_models
        if not only_currently_used or record.currently_used
    ]
699
-
700
-
701
def cleanup_models_gc():
    """Run a full garbage collection if any tracked record is dead.

    After collecting and emptying the cache, every record that is still dead
    is reported with a warning.
    """
    first_dead = next((rec for rec in current_loaded_models if rec.is_dead()), None)
    if first_dead is None:
        return

    logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(first_dead.real_model().__class__.__name__))
    gc.collect()
    soft_empty_cache()

    for idx in range(len(current_loaded_models)):
        rec = current_loaded_models[idx]
        if rec.is_dead():
            logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(rec.real_model().__class__.__name__))
718
-
719
-
720
-
721
def cleanup_models():
    """Drop every record whose ``real_model`` weak reference is dead."""
    stale = [
        idx
        for idx in range(len(current_loaded_models))
        if current_loaded_models[idx].real_model() is None
    ]
    # Remove from the back so the collected indices stay valid.
    for idx in reversed(stale):
        current_loaded_models.pop(idx)
730
-
731
def dtype_size(dtype):
    """Return the size in bytes of one element of `dtype`.

    float16/bfloat16 and float32 are answered directly; everything else uses
    ``dtype.itemsize`` and falls back to 4 when that attribute is missing.
    """
    if dtype == torch.float16 or dtype == torch.bfloat16:
        return 2
    if dtype == torch.float32:
        return 4
    # Old pytorch doesn't have .itemsize
    return getattr(dtype, "itemsize", 4)
743
-
744
def unet_offload_device():
    """Return the torch device in HIGH_VRAM mode, otherwise the CPU device."""
    keep_on_device = vram_state == VRAMState.HIGH_VRAM
    return get_torch_device() if keep_on_device else torch.device("cpu")
749
-
750
def unet_inital_load_device(parameters, dtype):
    """Pick the device a unet with `parameters` weights of `dtype` starts on.

    HIGH_VRAM and SHARED always use the torch device; disabled smart memory or
    NO_VRAM always use the CPU. Otherwise the torch device is chosen only when
    it has more free memory than the CPU and the model fits into it.
    """
    target = get_torch_device()
    if vram_state in (VRAMState.HIGH_VRAM, VRAMState.SHARED):
        return target

    cpu = torch.device("cpu")
    if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
        return cpu

    model_bytes = dtype_size(dtype) * parameters

    free_on_target = get_free_memory(target)
    free_on_cpu = get_free_memory(cpu)
    fits_on_target = free_on_target > free_on_cpu and model_bytes < free_on_target
    return target if fits_on_target else cpu
767
-
768
def maximum_vram_for_weights(device=None):
    """Return 88% of the device's total memory minus the minimum inference memory."""
    usable = get_total_memory(device) * 0.88
    return usable - minimum_inference_memory()
770
-
771
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
    """Choose the dtype unet weights should be stored in.

    Command-line overrides win. Float8 weights are kept as-is when fp8 compute
    is supported or when a 2-byte copy would not fit into the weight budget.
    Otherwise the first usable entry of `supported_dtypes` is returned, first
    without and then with manual casting, falling back to float32.
    """
    if model_params < 0:
        model_params = 1000000000000000000000

    # Explicit command-line selection, checked in a fixed order.
    if args.fp32_unet:
        return torch.float32
    if args.fp64_unet:
        return torch.float64
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
        return torch.float16
    if args.fp8_e4m3fn_unet:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
    if args.fp8_e8m0fnu_unet:
        return torch.float8_e8m0fnu

    if weight_dtype in FLOAT8_TYPES:
        # if fp8 compute is supported the casting is most likely not expensive
        if supports_fp8_compute(device):
            return weight_dtype

        weight_budget = maximum_vram_for_weights(device)
        if model_params * 2 > weight_budget:
            return weight_dtype

    if PRIORITIZE_FP16 or weight_dtype == torch.float16:
        if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
            return torch.float16

    # First pass: native support. Second pass: allow manual casting.
    for cast_kwargs in ({}, {"manual_cast": True}):
        for candidate in supported_dtypes:
            if candidate == torch.float16 and should_use_fp16(device=device, model_params=model_params, **cast_kwargs):
                return torch.float16
            if candidate == torch.bfloat16 and should_use_bf16(device, model_params=model_params, **cast_kwargs):
                return torch.bfloat16

    return torch.float32
822
-
823
- # None means no manual cast
824
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
    """Return the dtype to cast to for inference, or None if no cast is needed.

    No cast is needed for float32/float64 weights, nor for float16/bfloat16
    weights the inference device can use directly.
    """
    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
        return None

    can_fp16 = should_use_fp16(inference_device, prioritize_performance=False)
    if can_fp16 and weight_dtype == torch.float16:
        return None

    can_bf16 = should_use_bf16(inference_device)
    if can_bf16 and weight_dtype == torch.bfloat16:
        return None

    # Re-evaluate fp16 with performance taken into account for the cast target.
    can_fp16 = should_use_fp16(inference_device, prioritize_performance=True)
    if PRIORITIZE_FP16 and can_fp16 and torch.float16 in supported_dtypes:
        return torch.float16

    for candidate in supported_dtypes:
        if candidate == torch.float16 and can_fp16:
            return torch.float16
        if candidate == torch.bfloat16 and can_bf16:
            return torch.bfloat16

    return torch.float32
847
-
848
def text_encoder_offload_device():
    """Return the torch device with --gpu-only, otherwise the CPU device."""
    return get_torch_device() if args.gpu_only else torch.device("cpu")
853
-
854
def text_encoder_device():
    """Return the device text encoders run on.

    The torch device is used with --gpu-only, or in HIGH_VRAM/NORMAL_VRAM mode
    when fp16 is usable; in every other case the CPU is used.
    """
    if args.gpu_only:
        return get_torch_device()

    roomy = vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM
    if roomy and should_use_fp16(prioritize_performance=False):
        return get_torch_device()
    return torch.device("cpu")
864
-
865
def text_encoder_initial_device(load_device, offload_device, model_size=0):
    """Pick where a text encoder of `model_size` bytes is initially placed.

    Models of at most 1 GiB, or setups where both devices are the same, start
    on the offload device. MPS load devices are used directly. Otherwise the
    load device is chosen when its free memory exceeds half of the offload
    device's and the model (with 20% headroom) fits.
    """
    one_gib = 1024 * 1024 * 1024
    if load_device == offload_device or model_size <= one_gib:
        return offload_device

    if is_device_mps(load_device):
        return load_device

    free_load = get_free_memory(load_device)
    free_offload = get_free_memory(offload_device)
    fits = free_load > (free_offload * 0.5) and model_size * 1.2 < free_load
    return load_device if fits else offload_device
878
-
879
def text_encoder_dtype(device=None):
    """Return the text encoder dtype: a command-line override or float16."""
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_text_enc:
        return torch.float8_e5m2
    if args.fp16_text_enc:
        return torch.float16
    if args.bf16_text_enc:
        return torch.bfloat16
    if args.fp32_text_enc:
        return torch.float32

    # Both the CPU and the non-CPU case currently resolve to float16.
    if is_device_cpu(device):
        return torch.float16
    return torch.float16
895
-
896
-
897
def intermediate_device():
    """Return the torch device with --gpu-only, otherwise the CPU device."""
    return get_torch_device() if args.gpu_only else torch.device("cpu")
902
-
903
def vae_device():
    """Return the CPU device with --cpu-vae, otherwise the torch device."""
    return torch.device("cpu") if args.cpu_vae else get_torch_device()
907
-
908
def vae_offload_device():
    """Return the torch device with --gpu-only, otherwise the CPU device."""
    return get_torch_device() if args.gpu_only else torch.device("cpu")
913
-
914
def vae_dtype(device=None, allowed_dtypes=[]):
    """Choose the VAE dtype: a command-line override, the first usable entry
    of `allowed_dtypes`, or float32."""
    if args.fp16_vae:
        return torch.float16
    if args.bf16_vae:
        return torch.bfloat16
    if args.fp32_vae:
        return torch.float32

    for candidate in allowed_dtypes:
        if candidate == torch.float16 and should_use_fp16(device):
            return candidate

        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
        # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
        # also a problem on RDNA4 except fp32 is also slow there.
        # This is due to large bf16 convolutions being extremely slow.
        bf16_allowed_here = (not is_amd()) or amd_min_version(device, min_rdna_version=4)
        if candidate == torch.bfloat16 and bf16_allowed_here and should_use_bf16(device):
            return candidate

    return torch.float32
934
-
935
def get_autocast_device(dev):
    """Return ``dev.type`` when `dev` has that attribute, otherwise "cuda"."""
    return getattr(dev, 'type', "cuda")
939
-
940
def supports_dtype(device, dtype): #TODO
    """Return whether `device` can hold `dtype`.

    float32 is always accepted; CPU devices accept nothing else; other devices
    additionally accept float16 and bfloat16.
    """
    if dtype == torch.float32:
        return True
    if is_device_cpu(device):
        return False
    return dtype == torch.float16 or dtype == torch.bfloat16
950
-
951
def supports_cast(device, dtype): #TODO
    """Report whether weights can be cast to ``dtype`` on ``device``.

    The checks are order dependent: float32 and float16 are always accepted,
    DirectML rejects everything else, bfloat16 is accepted next, MPS rejects
    what remains, and finally the two fp8 formats are accepted.
    """
    if dtype in (torch.float32, torch.float16):
        return True
    if directml_enabled: #TODO: test this
        return False
    if dtype == torch.bfloat16:
        return True
    if is_device_mps(device):
        return False
    return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
967
-
968
def pick_weight_dtype(dtype, fallback_dtype, device=None):
    """Choose between a requested weight dtype and a fallback.

    The fallback is used when no dtype was requested, when the requested dtype
    is larger than the fallback, or when the device cannot cast to the chosen
    dtype.
    """
    if dtype is None or dtype_size(dtype) > dtype_size(fallback_dtype):
        dtype = fallback_dtype

    return dtype if supports_cast(device, dtype) else fallback_dtype
978
-
979
def device_supports_non_blocking(device):
    """Report whether non-blocking transfers may be used with ``device``.

    --force-non-blocking overrides every other consideration.
    """
    if args.force_non_blocking:
        return True

    unsupported = (
        is_device_mps(device)  # pytorch bug? mps doesn't support non blocking
        or is_intel_xpu()  # xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
        or args.deterministic  # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
        or directml_enabled
    )
    return not unsupported
991
-
992
def device_should_use_non_blocking(device):
    """Report whether non-blocking transfers should actually be used.

    Currently always returns False: the positive branch is commented out (see
    the TODO below), so the capability check does not change the result.
    """
    if not device_supports_non_blocking(device):
        return False
    return False
    # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
997
-
998
def force_channels_last():
    """Return True only when --force-channels-last was passed."""
    #TODO
    return True if args.force_channels_last else False
1004
-
1005
-
1006
# Per-device list of offload streams, created lazily by get_offload_stream().
STREAMS = {}
# Number of offload streams per device; 1 disables stream based offloading.
NUM_STREAMS = 1
if args.async_offload:
    NUM_STREAMS = 2
    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))

# Per-device index of the stream that the next get_offload_stream() call returns.
stream_counters = {}
def get_offload_stream(device):
    """Return the next offload stream for ``device`` in round-robin order.

    Returns None when only one stream is configured or when the device is
    neither CUDA nor XPU. Streams are created on first use and cached in
    STREAMS.
    """
    stream_counter = stream_counters.get(device, 0)
    if NUM_STREAMS <= 1:
        return None

    if device in STREAMS:
        ss = STREAMS[device]
        s = ss[stream_counter]
        stream_counter = (stream_counter + 1) % len(ss)
        # The stream that will be handed out next is made to wait on the
        # device's current stream before the counter is stored.
        if is_device_cuda(device):
            ss[stream_counter].wait_stream(torch.cuda.current_stream())
        elif is_device_xpu(device):
            ss[stream_counter].wait_stream(torch.xpu.current_stream())
        stream_counters[device] = stream_counter
        return s
    elif is_device_cuda(device):
        # First request for this CUDA device: create the streams.
        ss = []
        for k in range(NUM_STREAMS):
            ss.append(torch.cuda.Stream(device=device, priority=0))
        STREAMS[device] = ss
        s = ss[stream_counter]
        stream_counter = (stream_counter + 1) % len(ss)
        stream_counters[device] = stream_counter
        return s
    elif is_device_xpu(device):
        # First request for this XPU device: create the streams.
        ss = []
        for k in range(NUM_STREAMS):
            ss.append(torch.xpu.Stream(device=device, priority=0))
        STREAMS[device] = ss
        s = ss[stream_counter]
        stream_counter = (stream_counter + 1) % len(ss)
        stream_counters[device] = stream_counter
        return s
    return None
1047
-
1048
def sync_stream(device, stream):
    """Make the device's current stream wait for ``stream``.

    Does nothing when ``stream`` is None or the device is neither CUDA nor XPU.
    """
    if stream is not None:
        if is_device_cuda(device):
            torch.cuda.current_stream().wait_stream(stream)
        elif is_device_xpu(device):
            torch.xpu.current_stream().wait_stream(stream)
1055
-
1056
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
    """Cast and/or move ``weight``, optionally inside ``stream``.

    When no device change is needed, ``weight`` itself is returned if no copy
    was requested and the dtype already matches; otherwise ``weight.to`` does
    the conversion. When the device differs, a new tensor is allocated on the
    target device and filled with ``copy_``.

    Args:
        weight: Source tensor.
        dtype: Target dtype, or None to keep the current one.
        device: Target device, or None to keep the current one.
        non_blocking: Forwarded to ``copy_`` for cross-device transfers.
        copy: Force a new tensor even when nothing would change.
        stream: Optional stream used as a context manager around the work.

    Returns:
        The resulting tensor (possibly ``weight`` itself).
    """
    stays_on_device = device is None or weight.device == device
    if stays_on_device:
        if not copy and (dtype is None or weight.dtype == dtype):
            return weight
        if stream is None:
            return weight.to(dtype=dtype, copy=copy)
        with stream:
            return weight.to(dtype=dtype, copy=copy)

    if stream is None:
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
        return r

    with stream:
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
    return r
1074
-
1075
def cast_to_device(tensor, device, dtype, copy=False):
    """Cast ``tensor`` to ``dtype`` on ``device``, using a non-blocking transfer when the device supports it."""
    return cast_to(
        tensor,
        dtype=dtype,
        device=device,
        non_blocking=device_supports_non_blocking(device),
        copy=copy,
    )
1078
-
1079
def sage_attention_enabled():
    """Return the value of the --use-sage-attention command line flag."""
    return args.use_sage_attention

def flash_attention_enabled():
    """Return the value of the --use-flash-attention command line flag."""
    return args.use_flash_attention
1084
-
1085
def xformers_enabled():
    """Report whether xformers attention can be used.

    It is never used outside GPU mode, nor on Intel XPU, Ascend NPU, MLU,
    ixuca or DirectML. Otherwise the result is XFORMERS_IS_AVAILABLE.
    """
    if cpu_state != CPUState.GPU:
        return False
    if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca() or directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE
1101
-
1102
-
1103
def xformers_enabled_vae():
    """Report whether xformers may be used for the VAE.

    Requires xformers to be enabled in general; the result is then
    XFORMERS_ENABLED_VAE.
    """
    if xformers_enabled():
        return XFORMERS_ENABLED_VAE
    return False
1109
-
1110
def pytorch_attention_enabled():
    """Return the module-level ENABLE_PYTORCH_ATTENTION flag."""
    return ENABLE_PYTORCH_ATTENTION
1113
-
1114
def pytorch_attention_enabled_vae():
    """Report whether pytorch attention may be used for the VAE (never on AMD)."""
    # enabling pytorch attention on AMD currently causes crash when doing high res
    return False if is_amd() else pytorch_attention_enabled()
1118
-
1119
def pytorch_attention_flash_attention():
    """Report whether pytorch attention is enabled on a backend treated as having flash attention.

    The backends are checked in the listed order and the first match decides.
    """
    if not ENABLE_PYTORCH_ATTENTION:
        return False

    #TODO: more reliable way of checking for flash attention?
    backend_checks = (
        is_nvidia,
        is_intel_xpu,
        is_ascend_npu,
        is_mlu,
        is_amd,  # if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
        is_ixuca,
    )
    for backend_matches in backend_checks:
        if backend_matches():
            return True
    return False
1136
-
1137
def force_upcast_attention_dtype():
    """Return the dtype upcast map for attention, or None when no upcast is needed.

    Upcasting float16 to float32 is requested by --force-upcast-attention and
    is always applied on macOS 14.5 or newer.
    """
    macos_version = mac_version()
    # black image bug on recent versions of macOS, I don't think it's ever getting fixed
    affected_macos = macos_version is not None and ((14, 5) <= macos_version)

    if args.force_upcast_attention or affected_macos:
        return {torch.float16: torch.float32}
    return None
1148
-
1149
def get_free_memory(dev=None, torch_free_too=False):
    """Return the amount of free memory, in bytes, on a device.

    Args:
        dev: Device to query; defaults to get_torch_device().
        torch_free_too: When True, also return the memory that torch has
            reserved but is not actively using.

    Returns:
        ``mem_free_total``, or the tuple ``(mem_free_total, mem_free_torch)``
        when ``torch_free_too`` is True.
    """
    global directml_enabled
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        # CPU and MPS report available system RAM for both values.
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    else:
        if directml_enabled:
            # Fixed 1 GiB placeholder; no real query is made here.
            mem_free_total = 1024 * 1024 * 1024 #TODO
            mem_free_torch = mem_free_total
        elif is_intel_xpu():
            stats = torch.xpu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            # XPU derives the device-free amount from total memory minus reserved.
            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_xpu + mem_free_torch
        elif is_ascend_npu():
            stats = torch.npu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_npu, _ = torch.npu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_npu + mem_free_torch
        elif is_mlu():
            stats = torch.mlu.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_mlu + mem_free_torch
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_active = stats['active_bytes.all.current']
            mem_reserved = stats['reserved_bytes.all.current']
            mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
            # Reserved-but-inactive memory is counted as free in addition to
            # what the device itself reports as free.
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_cuda + mem_free_torch

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    else:
        return mem_free_total
1194
-
1195
def cpu_mode():
    """Return True when the module-level cpu_state is CPUState.CPU."""
    return CPUState.CPU == cpu_state

def mps_mode():
    """Return True when the module-level cpu_state is CPUState.MPS."""
    return CPUState.MPS == cpu_state
1202
-
1203
def is_device_type(device, type):
    """Return True when ``device`` has a ``type`` attribute equal to ``type``."""
    if not hasattr(device, 'type'):
        return False
    return True if device.type == type else False

def is_device_cpu(device):
    """Return True when ``device`` is a 'cpu' device."""
    return is_device_type(device, 'cpu')

def is_device_mps(device):
    """Return True when ``device`` is an 'mps' device."""
    return is_device_type(device, 'mps')

def is_device_xpu(device):
    """Return True when ``device`` is an 'xpu' device."""
    return is_device_type(device, 'xpu')

def is_device_cuda(device):
    """Return True when ``device`` is a 'cuda' device."""
    return is_device_type(device, 'cuda')
1220
-
1221
def is_directml_enabled():
    """Return True when the module-level directml_enabled flag is truthy."""
    return True if directml_enabled else False
1227
-
1228
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether fp16 should be used on ``device``.

    The checks run in a fixed order and the first one that applies decides, so
    their order matters.

    Args:
        device: Device to check; None skips the per-device CPU/MPS checks.
        model_params: Parameter count of the model; only read in the
            ``manual_cast`` memory check.
        prioritize_performance: When False, the ``manual_cast`` check returns
            True regardless of model size.
        manual_cast: Whether the caller is asking on behalf of manual casting.

    Returns:
        True when fp16 should be used, otherwise False.
    """
    if device is not None:
        if is_device_cpu(device):
            return False

    if args.force_fp16:
        return True

    if FORCE_FP32:
        return False

    if is_directml_enabled():
        return True

    if (device is not None and is_device_mps(device)) or mps_mode():
        return True

    if cpu_mode():
        return False

    if is_intel_xpu():
        if torch_version_numeric < (2, 3):
            return True
        else:
            return torch.xpu.get_device_properties(device).has_fp16

    if is_ascend_npu():
        return True

    if is_mlu():
        return True

    if is_ixuca():
        return True

    if torch.version.hip:
        return True

    # From here on the device is queried through the torch.cuda API.
    props = torch.cuda.get_device_properties(device)
    if props.major >= 8:
        return True

    if props.major < 6:
        return False

    #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
    for x in nvidia_10_series:
        if x in props.name.lower():
            if WINDOWS or manual_cast:
                return True
            else:
                return False #weird linux behavior where fp32 is faster

    if manual_cast:
        free_model_memory = maximum_vram_for_weights(device)
        # model_params * 4 is compared against the weight budget; presumably
        # the fp32 size in bytes -- confirm against callers.
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    if props.major < 7:
        return False

    #FP16 is just broken on these cards
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
    for x in nvidia_16_series:
        # Unlike the 10 series check above, this match is case-sensitive.
        if x in props.name:
            return False

    return True
1297
-
1298
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether bf16 should be used on ``device``.

    The checks run in a fixed order and the first one that applies decides.

    Args:
        device: Device to check; None skips the per-device CPU/MPS checks.
        model_params: Parameter count of the model; only read in the
            ``manual_cast`` memory check.
        prioritize_performance: When False, the ``manual_cast`` check returns
            True regardless of model size.
        manual_cast: Whether the caller is asking on behalf of manual casting.

    Returns:
        True when bf16 should be used, otherwise False.
    """
    if device is not None:
        if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
            return False

    if FORCE_FP32:
        return False

    if directml_enabled:
        return False

    if (device is not None and is_device_mps(device)) or mps_mode():
        macos_version = mac_version()
        # mac_version() returns None when the version cannot be parsed.
        # Comparing None with a tuple raises TypeError, so an unknown version
        # is treated as "too old" and bf16 is not used.
        if macos_version is None or macos_version < (14,):
            return False
        return True

    if cpu_mode():
        return False

    if is_intel_xpu():
        if torch_version_numeric < (2, 3):
            return True
        else:
            return torch.xpu.is_bf16_supported()

    if is_ascend_npu():
        return True

    if is_ixuca():
        return True

    if is_amd():
        arch = torch.cuda.get_device_properties(device).gcnArchName
        if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
            if manual_cast:
                return True
            return False

    props = torch.cuda.get_device_properties(device)

    if is_mlu():
        if props.major > 3:
            return True

    if props.major >= 8:
        return True

    bf16_works = torch.cuda.is_bf16_supported()

    if bf16_works and manual_cast:
        free_model_memory = maximum_vram_for_weights(device)
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    return False
1353
-
1354
def supports_fp8_compute(device=None):
    """Report whether fp8 compute can be used on ``device``.

    SUPPORT_FP8_OPS forces True. Otherwise an NVIDIA device with compute
    capability 9.x or newer qualifies outright, and capability 8.9 qualifies
    when torch is at least 2.3 (2.4 on Windows).
    """
    if SUPPORT_FP8_OPS:
        return True
    if not is_nvidia():
        return False

    props = torch.cuda.get_device_properties(device)
    if props.major >= 9:
        return True
    if props.major < 8 or props.minor < 9:
        return False

    minimum_torch = (2, 4) if WINDOWS else (2, 3)
    return not (torch_version_numeric < minimum_torch)
1377
-
1378
def extended_fp16_support():
    """Return True when the installed torch is version 2.7 or newer."""
    # TODO: check why some models work with fp16 on newer torch versions but not on older
    return not (torch_version_numeric < (2, 7))
1384
-
1385
def soft_empty_cache(force=False):
    """Release cached allocator memory on the active backend.

    Exactly one backend is handled per call, chosen in this order: MPS, Intel
    XPU, Ascend NPU, MLU, CUDA. Nothing happens when none of them applies.

    Args:
        force: Not read by this function body.
    """
    global cpu_state
    if cpu_state == CPUState.MPS:
        torch.mps.empty_cache()
    elif is_intel_xpu():
        torch.xpu.empty_cache()
    elif is_ascend_npu():
        torch.npu.empty_cache()
    elif is_mlu():
        torch.mlu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
1398
-
1399
def unload_all_models():
    """Ask free_memory() for an effectively unbounded amount (1e30) on the main torch device."""
    free_memory(1e30, get_torch_device())
1401
-
1402
-
1403
#TODO: might be cleaner to put this somewhere else
import threading

class InterruptProcessingException(Exception):
    """Raised by throw_exception_if_processing_interrupted() when an interrupt is pending."""
    pass

# Guards every read and write of the interrupt_processing flag.
interrupt_processing_mutex = threading.RLock()

interrupt_processing = False
def interrupt_current_processing(value=True):
    """Set (or, with ``value=False``, clear) the pending-interrupt flag."""
    global interrupt_processing
    interrupt_processing_mutex.acquire()
    try:
        interrupt_processing = value
    finally:
        interrupt_processing_mutex.release()

def processing_interrupted():
    """Return the current value of the pending-interrupt flag."""
    interrupt_processing_mutex.acquire()
    try:
        return interrupt_processing
    finally:
        interrupt_processing_mutex.release()

def throw_exception_if_processing_interrupted():
    """Raise InterruptProcessingException if an interrupt is pending.

    The flag is cleared before raising, so each interrupt request raises once.
    """
    global interrupt_processing
    interrupt_processing_mutex.acquire()
    try:
        if not interrupt_processing:
            return
        interrupt_processing = False
        raise InterruptProcessingException()
    finally:
        interrupt_processing_mutex.release()
 
 
 
 
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Comfy
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import psutil
20
+ import logging
21
+ from enum import Enum
22
+ from comfy.cli_args import args, PerformanceFeature
23
+ import torch
24
+ import sys
25
+ import importlib
26
+ import platform
27
+ import weakref
28
+ import gc
29
+
30
class VRAMState(Enum):
    """How model weights are managed with respect to device memory."""
    DISABLED = 0    #No vram present: no need to move models to vram
    NO_VRAM = 1     #Very low vram: enable all the options to save vram
    LOW_VRAM = 2
    NORMAL_VRAM = 3
    HIGH_VRAM = 4
    SHARED = 5      #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
37
+
38
class CPUState(Enum):
    """Which kind of compute device the process runs on."""
    GPU = 0
    CPU = 1
    MPS = 2
42
+
43
+ # Determine VRAM State
44
+ vram_state = VRAMState.NORMAL_VRAM
45
+ set_vram_to = VRAMState.NORMAL_VRAM
46
+ cpu_state = CPUState.GPU
47
+
48
+ total_vram = 0
49
+
50
def get_supported_float8_types():
    """Return the float8 dtypes that the installed torch build exposes.

    Each candidate is looked up by name on the ``torch`` module. Names that are
    missing (older torch versions) are skipped. The returned order follows the
    candidate list below.

    Returns:
        A list of torch dtypes; empty when torch has no float8 support.
    """
    candidate_names = (
        "float8_e4m3fn",
        "float8_e4m3fnuz",
        "float8_e5m2",
        "float8_e5m2fnuz",
        "float8_e8m0fnu",
    )
    return [getattr(torch, name) for name in candidate_names if hasattr(torch, name)]
73
+
74
+ FLOAT8_TYPES = get_supported_float8_types()
75
+
76
xpu_available = False
torch_version = ""
# Fallback so later comparisons such as `torch_version_numeric < (2, 3)` do not
# raise NameError when the version string cannot be parsed.
torch_version_numeric = (0, 0)
try:
    torch_version = torch.version.__version__
    temp = torch_version.split(".")
    torch_version_numeric = (int(temp[0]), int(temp[1]))
except (AttributeError, IndexError, ValueError):
    logging.warning("Could not parse pytorch version: {}".format(torch_version))
84
+
85
+ lowvram_available = True
86
+ if args.deterministic:
87
+ logging.info("Using deterministic algorithms for pytorch")
88
+ torch.use_deterministic_algorithms(True, warn_only=True)
89
+
90
+ directml_enabled = False
91
+ if args.directml is not None:
92
+ import torch_directml
93
+ directml_enabled = True
94
+ device_index = args.directml
95
+ if device_index < 0:
96
+ directml_device = torch_directml.device()
97
+ else:
98
+ directml_device = torch_directml.device(device_index)
99
+ logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
100
+ # torch_directml.disable_tiled_resources(True)
101
+ lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
102
+
103
+ try:
104
+ import intel_extension_for_pytorch as ipex # noqa: F401
105
+ except:
106
+ pass
107
+
108
+ try:
109
+ _ = torch.xpu.device_count()
110
+ xpu_available = torch.xpu.is_available()
111
+ except:
112
+ xpu_available = False
113
+
114
+ try:
115
+ if torch.backends.mps.is_available():
116
+ cpu_state = CPUState.MPS
117
+ import torch.mps
118
+ except:
119
+ pass
120
+
121
+ try:
122
+ import torch_npu # noqa: F401
123
+ _ = torch.npu.device_count()
124
+ npu_available = torch.npu.is_available()
125
+ except:
126
+ npu_available = False
127
+
128
+ try:
129
+ import torch_mlu # noqa: F401
130
+ _ = torch.mlu.device_count()
131
+ mlu_available = torch.mlu.is_available()
132
+ except:
133
+ mlu_available = False
134
+
135
+ try:
136
+ ixuca_available = hasattr(torch, "corex")
137
+ except:
138
+ ixuca_available = False
139
+
140
+ if args.cpu:
141
+ cpu_state = CPUState.CPU
142
+
143
def is_intel_xpu():
    """Return True when running in GPU mode and an XPU device was detected at import time."""
    if cpu_state != CPUState.GPU:
        return False
    return True if xpu_available else False
150
+
151
def is_ascend_npu():
    """Return True when the module-level npu_available flag is truthy."""
    return True if npu_available else False

def is_mlu():
    """Return True when the module-level mlu_available flag is truthy."""
    return True if mlu_available else False

def is_ixuca():
    """Return True when the module-level ixuca_available flag is truthy."""
    return True if ixuca_available else False
168
+
169
def get_torch_device():
    """Return the torch device that models should run on.

    Selection order: DirectML, MPS, forced CPU, Intel XPU, Ascend NPU, MLU,
    CUDA. If none of the accelerators is usable the CPU is returned.

    The CUDA availability check is applied only to the CUDA branch. Wrapping
    the XPU/NPU/MLU branches in it as well would send those backends to the
    CPU whenever CUDA is unavailable.
    """
    global directml_enabled
    global cpu_state
    if directml_enabled:
        global directml_device
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")

    if is_intel_xpu():
        return torch.device("xpu", torch.xpu.current_device())
    if is_ascend_npu():
        return torch.device("npu", torch.npu.current_device())
    if is_mlu():
        return torch.device("mlu", torch.mlu.current_device())
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())

    # Fall back to the CPU when CUDA is not available.
    return torch.device("cpu")
191
+
192
def get_total_memory(dev=None, torch_total_too=False):
    """Return the total amount of memory, in bytes, of a device.

    Args:
        dev: Device to query; defaults to get_torch_device().
        torch_total_too: When True, also return the amount torch currently has
            reserved on the device.

    Returns:
        ``mem_total``, or the tuple ``(mem_total, mem_total_torch)`` when
        ``torch_total_too`` is True.
    """
    global directml_enabled
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        # CPU and MPS report total system RAM for both values.
        mem_total = psutil.virtual_memory().total
        mem_total_torch = mem_total
    else:
        if directml_enabled:
            # Fixed 1 GiB placeholder; no real query is made here.
            mem_total = 1024 * 1024 * 1024 #TODO
            mem_total_torch = mem_total
        elif is_intel_xpu():
            stats = torch.xpu.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
            mem_total_torch = mem_reserved
            mem_total = mem_total_xpu
        elif is_ascend_npu():
            stats = torch.npu.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            _, mem_total_npu = torch.npu.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_npu
        elif is_mlu():
            stats = torch.mlu.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            _, mem_total_mlu = torch.mlu.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_mlu
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            # mem_get_info returns a pair; the second element is used as the total.
            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_cuda

    if torch_total_too:
        return (mem_total, mem_total_torch)
    else:
        return mem_total
233
+
234
def mac_version():
    """Return the macOS version as a tuple of ints, or None when unavailable.

    ``platform.mac_ver()`` reports an empty version string on non-macOS
    systems. Converting that (or any non-numeric component) raises ValueError,
    which is reported as None.
    """
    try:
        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
    except ValueError:
        return None
239
+
240
+ total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
241
+ total_ram = psutil.virtual_memory().total / (1024 * 1024)
242
+ logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
243
+
244
+ try:
245
+ logging.info("pytorch version: {}".format(torch_version))
246
+ mac_ver = mac_version()
247
+ if mac_ver is not None:
248
+ logging.info("Mac Version {}".format(mac_ver))
249
+ except:
250
+ pass
251
+
252
+ try:
253
+ OOM_EXCEPTION = torch.cuda.OutOfMemoryError
254
+ except:
255
+ OOM_EXCEPTION = Exception
256
+
257
+ XFORMERS_VERSION = ""
258
+ XFORMERS_ENABLED_VAE = True
259
+ if args.disable_xformers:
260
+ XFORMERS_IS_AVAILABLE = False
261
+ else:
262
+ try:
263
+ import xformers
264
+ import xformers.ops
265
+ XFORMERS_IS_AVAILABLE = True
266
+ try:
267
+ XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
268
+ except:
269
+ pass
270
+ try:
271
+ XFORMERS_VERSION = xformers.version.__version__
272
+ logging.info("xformers version: {}".format(XFORMERS_VERSION))
273
+ if XFORMERS_VERSION.startswith("0.0.18"):
274
+ logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
275
+ logging.warning("Please downgrade or upgrade xformers to a different version.\n")
276
+ XFORMERS_ENABLED_VAE = False
277
+ except:
278
+ pass
279
+ except:
280
+ XFORMERS_IS_AVAILABLE = False
281
+
282
def is_nvidia():
    """Return True when running in GPU mode on a torch build with CUDA (``torch.version.cuda`` is set)."""
    if cpu_state != CPUState.GPU:
        return False
    return True if torch.version.cuda else False
288
+
289
def is_amd():
    """Return True when running in GPU mode on a torch build with ROCm/HIP (``torch.version.hip`` is set)."""
    if cpu_state != CPUState.GPU:
        return False
    return True if torch.version.hip else False
295
+
296
def amd_min_version(device=None, min_rdna_version=0):
    """Report whether an AMD device is at least the given RDNA generation.

    Args:
        device: Device to inspect; a CPU device always yields False.
        min_rdna_version: Minimum generation number required.

    Returns:
        True when the device's architecture name has the form ``gfxNNNN``
        (seven characters) and its derived generation is at least
        ``min_rdna_version``; False otherwise, including on non-AMD setups.
    """
    if not is_amd():
        return False

    if is_device_cpu(device):
        return False

    arch = torch.cuda.get_device_properties(device).gcnArchName
    if arch.startswith('gfx') and len(arch) == 7:
        try:
            # The generation is taken as the fifth character plus two
            # (e.g. "gfx1100" -> 3); presumably the RDNA numbering -- confirm.
            cmp_rdna_version = int(arch[4]) + 2
        except ValueError:
            cmp_rdna_version = 0
        if cmp_rdna_version >= min_rdna_version:
            return True

    return False
313
+
314
+ MIN_WEIGHT_MEMORY_RATIO = 0.4
315
+ if is_nvidia():
316
+ MIN_WEIGHT_MEMORY_RATIO = 0.0
317
+
318
+ ENABLE_PYTORCH_ATTENTION = False
319
+ if args.use_pytorch_cross_attention:
320
+ ENABLE_PYTORCH_ATTENTION = True
321
+ XFORMERS_IS_AVAILABLE = False
322
+
323
+ try:
324
+ if is_nvidia():
325
+ if torch_version_numeric[0] >= 2:
326
+ if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
327
+ ENABLE_PYTORCH_ATTENTION = True
328
+ if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca():
329
+ if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
330
+ ENABLE_PYTORCH_ATTENTION = True
331
+ except:
332
+ pass
333
+
334
+
335
+ SUPPORT_FP8_OPS = args.supports_fp8_compute
336
+ try:
337
+ if is_amd():
338
+ try:
339
+ rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
340
+ except:
341
+ rocm_version = (6, -1)
342
+ arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
343
+ logging.info("AMD arch: {}".format(arch))
344
+ logging.info("ROCm version: {}".format(rocm_version))
345
+ if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
346
+ if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
347
+ if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
348
+ if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
349
+ ENABLE_PYTORCH_ATTENTION = True
350
+ # if torch_version_numeric >= (2, 8):
351
+ # if any((a in arch) for a in ["gfx1201"]):
352
+ # ENABLE_PYTORCH_ATTENTION = True
353
+ if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
354
+ if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
355
+ SUPPORT_FP8_OPS = True
356
+
357
+ except:
358
+ pass
359
+
360
+
361
+ if ENABLE_PYTORCH_ATTENTION:
362
+ torch.backends.cuda.enable_math_sdp(True)
363
+ torch.backends.cuda.enable_flash_sdp(True)
364
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
365
+
366
+
367
+ PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
368
+ try:
369
+ if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
370
+ torch.backends.cuda.matmul.allow_fp16_accumulation = True
371
+ PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
372
+ logging.info("Enabled fp16 accumulation.")
373
+ except:
374
+ pass
375
+
376
+ try:
377
+ if torch_version_numeric >= (2, 5):
378
+ torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
379
+ except:
380
+ logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
381
+
382
+ if args.lowvram:
383
+ set_vram_to = VRAMState.LOW_VRAM
384
+ lowvram_available = True
385
+ elif args.novram:
386
+ set_vram_to = VRAMState.NO_VRAM
387
+ elif args.highvram or args.gpu_only:
388
+ vram_state = VRAMState.HIGH_VRAM
389
+
390
+ FORCE_FP32 = False
391
+ if args.force_fp32:
392
+ logging.info("Forcing FP32, if this improves things please report it.")
393
+ FORCE_FP32 = True
394
+
395
+ if lowvram_available:
396
+ if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
397
+ vram_state = set_vram_to
398
+
399
+
400
+ if cpu_state != CPUState.GPU:
401
+ vram_state = VRAMState.DISABLED
402
+
403
+ if cpu_state == CPUState.MPS:
404
+ vram_state = VRAMState.SHARED
405
+
406
+ logging.info(f"Set vram state to: {vram_state.name}")
407
+
408
+ DISABLE_SMART_MEMORY = args.disable_smart_memory
409
+
410
+ if DISABLE_SMART_MEMORY:
411
+ logging.info("Disabling smart memory management")
412
+
413
def get_torch_device_name(device):
    """Build a human-readable description of ``device`` for logging.

    Objects with a ``type`` attribute are described from that type (with the
    device name for cuda and xpu). Anything else is resolved through the
    detected backend, defaulting to CUDA.
    """
    if not hasattr(device, 'type'):
        if is_intel_xpu():
            return "{} {}".format(device, torch.xpu.get_device_name(device))
        if is_ascend_npu():
            return "{} {}".format(device, torch.npu.get_device_name(device))
        if is_mlu():
            return "{} {}".format(device, torch.mlu.get_device_name(device))
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

    if device.type == "cuda":
        # The allocator backend query is best effort; an empty string is used on failure.
        try:
            allocator_backend = torch.cuda.get_allocator_backend()
        except:
            allocator_backend = ""
        return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
    if device.type == "xpu":
        return "{} {}".format(device, torch.xpu.get_device_name(device))
    return "{}".format(device.type)
433
+
434
+ try:
435
+ logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
436
+ except:
437
+ logging.warning("Could not pick default device.")
438
+
439
+
440
+ current_loaded_models = []
441
+
442
def module_size(module):
    """Return the size in bytes of every tensor in ``module.state_dict()``, summed."""
    tensors = module.state_dict().values()
    return sum(t.nelement() * t.element_size() for t in tensors)
449
+
450
+ class LoadedModel:
451
+ def __init__(self, model):
452
+ self._set_model(model)
453
+ self.device = model.load_device
454
+ self.real_model = None
455
+ self.currently_used = True
456
+ self.model_finalizer = None
457
+ self._patcher_finalizer = None
458
+
459
+ def _set_model(self, model):
460
+ self._model = weakref.ref(model)
461
+ if model.parent is not None:
462
+ self._parent_model = weakref.ref(model.parent)
463
+ self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
464
+
465
+ def _switch_parent(self):
466
+ model = self._parent_model()
467
+ if model is not None:
468
+ self._set_model(model)
469
+
470
+ @property
471
+ def model(self):
472
+ return self._model()
473
+
474
+ def model_memory(self):
475
+ return self.model.model_size()
476
+
477
+ def model_loaded_memory(self):
478
+ return self.model.loaded_size()
479
+
480
+ def model_offloaded_memory(self):
481
+ return self.model.model_size() - self.model.loaded_size()
482
+
483
+ def model_memory_required(self, device):
484
+ if device == self.model.current_loaded_device():
485
+ return self.model_offloaded_memory()
486
+ else:
487
+ return self.model_memory()
488
+
489
+ def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
490
+ self.model.model_patches_to(self.device)
491
+ self.model.model_patches_to(self.model.model_dtype())
492
+
493
+ # if self.model.loaded_size() > 0:
494
+ use_more_vram = lowvram_model_memory
495
+ if use_more_vram == 0:
496
+ use_more_vram = 1e32
497
+ self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
498
+ real_model = self.model.model
499
+
500
+ if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
501
+ with torch.no_grad():
502
+ real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
503
+
504
+ self.real_model = weakref.ref(real_model)
505
+ self.model_finalizer = weakref.finalize(real_model, cleanup_models)
506
+ return real_model
507
+
508
+ def should_reload_model(self, force_patch_weights=False):
509
+ if force_patch_weights and self.model.lowvram_patch_counter() > 0:
510
+ return True
511
+ return False
512
+
513
def model_unload(self, memory_to_free=None, unpatch_weights=True):
    """Unload the tracked model, partially when that is enough.

    Args:
        memory_to_free: number of bytes the caller wants back, or ``None`` to
            force a full unload. When it is smaller than the currently loaded
            size a partial unload to the patcher's offload device is tried
            first.
        unpatch_weights: forwarded to the patcher's ``detach``.

    Returns:
        ``False`` when a partial unload freed at least ``memory_to_free`` (the
        model stays tracked); ``True`` when the model was fully detached.
    """
    patcher = self.model
    if memory_to_free is not None and memory_to_free < patcher.loaded_size():
        freed = patcher.partially_unload(patcher.offload_device, memory_to_free)
        if freed >= memory_to_free:
            return False

    patcher.detach(unpatch_weights)
    # The finalizer is reset to None below, so a second unload (or an unload
    # before any load) must not try to detach a missing finalizer.
    if self.model_finalizer is not None:
        self.model_finalizer.detach()
    self.model_finalizer = None
    self.real_model = None
    return True
524
+
525
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
    """Ask the patcher to ``partially_load`` up to ``extra_memory`` onto ``self.device``.

    Returns whatever ``partially_load`` returns.
    """
    patcher = self.model
    return patcher.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
527
+
528
def __eq__(self, other):
    """Two entries are equal when they track the very same patcher object.

    Objects without a ``model`` attribute are not comparable; returning
    ``NotImplemented`` lets Python fall back to its default comparison
    instead of raising ``AttributeError``.
    """
    if not hasattr(other, "model"):
        return NotImplemented
    return self.model is other.model
530
+
531
def __del__(self):
    """Detach the patcher finalizer when this entry is destroyed.

    ``getattr`` is used because ``__del__`` also runs for instances whose
    construction failed before ``_patcher_finalizer`` was assigned.
    """
    finalizer = getattr(self, "_patcher_finalizer", None)
    if finalizer is not None:
        finalizer.detach()
534
+
535
def is_dead(self):
    """Return True when the real model is still alive but its patcher is gone.

    ``real_model`` is reset to ``None`` by ``model_unload``; in that state
    there is nothing left to leak, so the entry is reported as not dead
    instead of failing on ``None()``.
    """
    real_model_ref = self.real_model
    if real_model_ref is None:
        return False
    return real_model_ref() is not None and self.model is None
537
+
538
+
539
def use_more_memory(extra_memory, loaded_models, device):
    """Spend ``extra_memory`` on the entries of ``loaded_models`` that sit on ``device``.

    Each matching entry is asked to use more VRAM and the amount it reports
    is subtracted from the budget; iteration stops once the budget is
    exhausted.
    """
    remaining = extra_memory
    for entry in loaded_models:
        if entry.device == device:
            remaining -= entry.model_use_more_vram(remaining)
        if remaining <= 0:
            break
545
+
546
def offloaded_memory(loaded_models, device):
    """Sum ``model_offloaded_memory()`` over the entries located on ``device``."""
    return sum(
        entry.model_offloaded_memory()
        for entry in loaded_models
        if entry.device == device
    )
552
+
553
# True when platform.win32_ver() reports any non-empty field.
WINDOWS = any(platform.win32_ver())

# Default amount of VRAM (bytes) kept free for other applications.
if WINDOWS:
    EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
    if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
        EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
else:
    EXTRA_RESERVED_VRAM = 400 * 1024 * 1024

# An explicit --reserve-vram value (in GB) overrides the defaults above.
if args.reserve_vram is not None:
    EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
    logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
564
+
565
def extra_reserved_memory():
    """Return the module-level ``EXTRA_RESERVED_VRAM`` value (bytes)."""
    reserved = EXTRA_RESERVED_VRAM
    return reserved
567
+
568
def minimum_inference_memory():
    """Return 0.8 GiB plus ``extra_reserved_memory()``, in bytes."""
    base = (1024 * 1024 * 1024) * 0.8
    return base + extra_reserved_memory()
570
+
571
def free_memory(memory_required, device, keep_loaded=[]):
    """Unload models from ``device`` until ``memory_required`` bytes are free.

    Args:
        memory_required: amount of free memory wanted on ``device``.
        device: only entries of ``current_loaded_models`` on this device are
            considered.
        keep_loaded: entries that must not be unloaded.

    Returns:
        The list of entries that were removed from ``current_loaded_models``.
    """
    cleanup_models_gc()
    unloaded_indices = []
    candidates = []
    unloaded_models = []

    # Walk the list back to front and collect a sortable key per candidate.
    for i in reversed(range(len(current_loaded_models))):
        shift_model = current_loaded_models[i]
        if shift_model.device != device:
            continue
        if shift_model not in keep_loaded and not shift_model.is_dead():
            candidates.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
            shift_model.currently_used = False

    for key in sorted(candidates):
        i = key[-1]
        memory_to_free = None
        if not DISABLE_SMART_MEMORY:
            # Stop as soon as enough memory is available.
            free_mem = get_free_memory(device)
            if free_mem > memory_required:
                break
            memory_to_free = memory_required - free_mem
        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
        if current_loaded_models[i].model_unload(memory_to_free):
            unloaded_indices.append(i)

    # Pop from the highest index down so earlier indices stay valid.
    for i in sorted(unloaded_indices, reverse=True):
        unloaded_models.append(current_loaded_models.pop(i))

    if len(unloaded_indices) > 0:
        soft_empty_cache()
    elif vram_state != VRAMState.HIGH_VRAM:
        mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
        if mem_free_torch > mem_free_total * 0.25:
            soft_empty_cache()
    return unloaded_models
607
+
608
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    """Load ``models`` (and the models their patches need) onto their devices.

    Args:
        models: iterable of patchers to load.
        memory_required: extra memory the caller needs for inference.
        force_patch_weights: forwarded to each entry's ``model_load``.
        minimum_memory_required: lower bound of memory that must stay free;
            defaults to the computed ``extra_mem``.
        force_full_load: when true the low-VRAM budget computation is skipped.

    Returns:
        None. ``current_loaded_models`` is updated in place.
    """
    cleanup_models_gc()
    global vram_state

    inference_memory = minimum_inference_memory()
    extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
    if minimum_memory_required is None:
        minimum_memory_required = extra_mem
    else:
        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())

    # Requested models plus every model referenced by their patches.
    requested = set()
    for m in models:
        requested.add(m)
        requested.update(m.model_patches_models())
    models = requested

    models_to_load = []
    for patcher in models:
        candidate = LoadedModel(patcher)
        try:
            existing_index = current_loaded_models.index(candidate)
        except ValueError:
            # list.index signals "not found" with ValueError; anything else
            # (e.g. KeyboardInterrupt) should propagate.
            existing_index = None

        if existing_index is not None:
            loaded = current_loaded_models[existing_index]
            loaded.currently_used = True
            models_to_load.append(loaded)
        else:
            if hasattr(patcher, "model"):
                logging.info(f"Requested to load {patcher.model.__class__.__name__}")
            models_to_load.append(candidate)

    # Drop tracked clones of what is about to be loaded; pop from the end so
    # the remaining indices stay valid.
    for loaded_model in models_to_load:
        clone_indices = [
            i for i, entry in enumerate(current_loaded_models)
            if loaded_model.model.is_clone(entry.model)
        ]
        for i in reversed(clone_indices):
            current_loaded_models.pop(i).model.detach(unpatch_all=False)

    total_memory_required = {}
    for loaded_model in models_to_load:
        dev = loaded_model.device
        total_memory_required[dev] = total_memory_required.get(dev, 0) + loaded_model.model_memory_required(dev)

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_mem = get_free_memory(device)
            if free_mem < minimum_memory_required:
                models_l = free_memory(minimum_memory_required, device)
                logging.info("{} models unloaded.".format(len(models_l)))

    for loaded_model in models_to_load:
        model = loaded_model.model
        torch_dev = model.load_device
        if is_device_cpu(torch_dev):
            vram_set_state = VRAMState.DISABLED
        else:
            vram_set_state = vram_state

        lowvram_model_memory = 0
        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
            loaded_memory = loaded_model.model_loaded_memory()
            current_free_mem = get_free_memory(torch_dev) + loaded_memory

            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
            lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)

        if vram_set_state == VRAMState.NO_VRAM:
            lowvram_model_memory = 0.1

        loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
        current_loaded_models.insert(0, loaded_model)
    return
689
+
690
def load_model_gpu(model):
    """Load a single model by delegating to ``load_models_gpu``."""
    single = [model]
    return load_models_gpu(single)
692
+
693
def loaded_models(only_currently_used=False):
    """Return the patchers of the tracked entries.

    With ``only_currently_used`` set, entries whose ``currently_used`` flag
    is false are left out.
    """
    return [
        entry.model
        for entry in current_loaded_models
        if not only_currently_used or entry.currently_used
    ]
702
+
703
+
704
def cleanup_models_gc():
    """Run a full garbage collection when a tracked entry looks leaked.

    If any entry reports ``is_dead()``, the collector and ``soft_empty_cache``
    are run, and every entry that is still dead afterwards is logged as a
    memory leak.
    """
    first_dead = next((entry for entry in current_loaded_models if entry.is_dead()), None)
    if first_dead is None:
        return

    logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(first_dead.real_model().__class__.__name__))
    gc.collect()
    soft_empty_cache()

    for entry in current_loaded_models:
        if entry.is_dead():
            logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(entry.real_model().__class__.__name__))
721
+
722
+
723
+
724
def cleanup_models():
    """Remove tracked entries whose real model has been garbage collected."""
    stale = [
        i for i, entry in enumerate(current_loaded_models)
        if entry.real_model() is None
    ]
    # Highest index first so the remaining indices stay valid while popping.
    for i in reversed(stale):
        current_loaded_models.pop(i)
733
+
734
+ def dtype_size(dtype):
735
+ dtype_size = 4
736
+ if dtype == torch.float16 or dtype == torch.bfloat16:
737
+ dtype_size = 2
738
+ elif dtype == torch.float32:
739
+ dtype_size = 4
740
+ else:
741
+ try:
742
+ dtype_size = dtype.itemsize
743
+ except: #Old pytorch doesn't have .itemsize
744
+ pass
745
+ return dtype_size
746
+
747
def unet_offload_device():
    """Return the device unet weights are offloaded to.

    In HIGH_VRAM state that is the torch device itself, otherwise the CPU.
    """
    if vram_state != VRAMState.HIGH_VRAM:
        return torch.device("cpu")
    return get_torch_device()
752
+
753
def unet_inital_load_device(parameters, dtype):
    """Pick the device a unet with ``parameters`` weights of ``dtype`` is first loaded on.

    HIGH_VRAM and SHARED states always use the torch device; disabled smart
    memory or NO_VRAM always use the CPU. Otherwise the torch device is used
    only when it has more free memory than the CPU and the model fits in it.
    """
    torch_dev = get_torch_device()
    if vram_state in (VRAMState.HIGH_VRAM, VRAMState.SHARED):
        return torch_dev

    cpu_dev = torch.device("cpu")
    if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
        return cpu_dev

    model_size = dtype_size(dtype) * parameters

    mem_dev = get_free_memory(torch_dev)
    mem_cpu = get_free_memory(cpu_dev)
    fits_on_device = mem_dev > mem_cpu and model_size < mem_dev
    return torch_dev if fits_on_device else cpu_dev
770
+
771
def maximum_vram_for_weights(device=None):
    """Return 88% of the device's total memory minus ``minimum_inference_memory()``."""
    usable = get_total_memory(device) * 0.88
    return usable - minimum_inference_memory()
773
+
774
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
    """Choose the dtype in which unet weights are kept.

    Args:
        device: device the model will run on.
        model_params: parameter count; a negative value is treated as a huge
            model.
        supported_dtypes: dtypes the model supports, in order of preference.
        weight_dtype: dtype of the stored weights, if known.

    Returns:
        A ``torch.dtype``; ``torch.float32`` when nothing else qualifies.
    """
    if model_params < 0:
        model_params = 1000000000000000000000

    # Command-line overrides, checked in this fixed order.
    if args.fp32_unet:
        return torch.float32
    if args.fp64_unet:
        return torch.float64
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
        return torch.float16
    if args.fp8_e4m3fn_unet:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
    if args.fp8_e8m0fnu_unet:
        return torch.float8_e8m0fnu

    fp8_dtype = weight_dtype if weight_dtype in FLOAT8_TYPES else None
    if fp8_dtype is not None:
        if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
            return fp8_dtype

        # Keep fp8 when a 2-byte copy of the weights would not fit.
        free_model_memory = maximum_vram_for_weights(device)
        if model_params * 2 > free_model_memory:
            return fp8_dtype

    if PRIORITIZE_FP16 or weight_dtype == torch.float16:
        if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
            return torch.float16

    # First pass: native support only. Second pass: allow manual casting.
    for manual_cast in (False, True):
        for dt in supported_dtypes:
            if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=manual_cast):
                return torch.float16
            if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=manual_cast):
                return torch.bfloat16

    return torch.float32
825
+
826
+ # None means no manual cast
827
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
    """Return the dtype weights must be cast to for inference, or None.

    ``None`` means the weights can be used as they are: they are float32 or
    float64, or they are fp16/bf16 and the device handles that dtype.
    Otherwise the first usable entry of ``supported_dtypes`` is returned,
    falling back to ``torch.float32``.
    """
    if weight_dtype in (torch.float32, torch.float64):
        return None

    fp16_ok = should_use_fp16(inference_device, prioritize_performance=False)
    if fp16_ok and weight_dtype == torch.float16:
        return None

    bf16_ok = should_use_bf16(inference_device)
    if bf16_ok and weight_dtype == torch.bfloat16:
        return None

    # Re-evaluate fp16 with performance taken into account for the cast target.
    fp16_ok = should_use_fp16(inference_device, prioritize_performance=True)
    if PRIORITIZE_FP16 and fp16_ok and torch.float16 in supported_dtypes:
        return torch.float16

    for candidate in supported_dtypes:
        if candidate == torch.float16 and fp16_ok:
            return torch.float16
        if candidate == torch.bfloat16 and bf16_ok:
            return torch.bfloat16

    return torch.float32
850
+
851
def text_encoder_offload_device():
    """Return the torch device with ``--gpu-only``, otherwise the CPU."""
    if not args.gpu_only:
        return torch.device("cpu")
    return get_torch_device()
856
+
857
def text_encoder_device():
    """Return the device text encoders run on.

    The torch device is used with ``--gpu-only``, or in HIGH_VRAM/NORMAL_VRAM
    state when fp16 is usable; every other case falls back to the CPU.
    """
    if args.gpu_only:
        return get_torch_device()

    enough_vram = vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM)
    if enough_vram and should_use_fp16(prioritize_performance=False):
        return get_torch_device()
    return torch.device("cpu")
867
+
868
def text_encoder_initial_device(load_device, offload_device, model_size=0):
    """Pick where a text encoder of ``model_size`` bytes is first placed.

    Models of at most 1 GiB, or configurations where both devices are the
    same, start on the offload device. MPS load devices are always used
    directly. Otherwise the load device is chosen only when it has more than
    half the offload device's free memory and 1.2x the model size fits.
    """
    one_gib = 1024 * 1024 * 1024
    if load_device == offload_device or model_size <= one_gib:
        return offload_device

    if is_device_mps(load_device):
        return load_device

    mem_l = get_free_memory(load_device)
    mem_o = get_free_memory(offload_device)
    prefer_load = mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l
    return load_device if prefer_load else offload_device
881
+
882
def text_encoder_dtype(device=None):
    """Return the dtype used for text encoder weights.

    A command-line override wins (checked in the order fp8_e4m3fn, fp8_e5m2,
    fp16, bf16, fp32). Without one the result is ``torch.float16`` for every
    device, so ``device`` currently does not influence the result; the
    parameter is kept for callers that pass it.
    """
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn
    elif args.fp8_e5m2_text_enc:
        return torch.float8_e5m2
    elif args.fp16_text_enc:
        return torch.float16
    elif args.bf16_text_enc:
        return torch.bfloat16
    elif args.fp32_text_enc:
        return torch.float32

    # The former CPU-specific branch returned the same value, so it was dropped.
    return torch.float16
898
+
899
+
900
def intermediate_device():
    """Return the torch device with ``--gpu-only``, otherwise the CPU."""
    if not args.gpu_only:
        return torch.device("cpu")
    return get_torch_device()
905
+
906
def vae_device():
    """Return the CPU when ``--cpu-vae`` is set, otherwise the torch device."""
    if not args.cpu_vae:
        return get_torch_device()
    return torch.device("cpu")
910
+
911
def vae_offload_device():
    """Return the torch device with ``--gpu-only``, otherwise the CPU."""
    if not args.gpu_only:
        return torch.device("cpu")
    return get_torch_device()
916
+
917
def vae_dtype(device=None, allowed_dtypes=[]):
    """Choose the dtype the VAE runs in.

    A command-line override wins. Otherwise the first entry of
    ``allowed_dtypes`` that the device can use is returned, with
    ``torch.float32`` as the fallback.
    """
    if args.fp16_vae:
        return torch.float16
    if args.bf16_vae:
        return torch.bfloat16
    if args.fp32_vae:
        return torch.float32

    for candidate in allowed_dtypes:
        if candidate == torch.float16 and should_use_fp16(device):
            return candidate

        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
        # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
        # also a problem on RDNA4 except fp32 is also slow there.
        # This is due to large bf16 convolutions being extremely slow.
        if candidate == torch.bfloat16:
            if ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
                return candidate

    return torch.float32
937
+
938
def get_autocast_device(dev):
    """Return ``dev.type`` when ``dev`` has that attribute, else ``"cuda"``."""
    if not hasattr(dev, 'type'):
        return "cuda"
    return dev.type
942
+
943
def supports_dtype(device, dtype): #TODO
    """Report whether ``device`` supports ``dtype``.

    float32 is always supported; on non-CPU devices float16 and bfloat16
    are as well. Everything else is reported as unsupported.
    """
    if dtype == torch.float32:
        return True
    if is_device_cpu(device):
        return False
    return bool(dtype == torch.float16 or dtype == torch.bfloat16)
953
+
954
def supports_cast(device, dtype): #TODO
    """Report whether weights can be cast to ``dtype`` on ``device``.

    The checks run in a fixed order: float32/float16 are always fine,
    DirectML rejects everything else, bfloat16 is fine, MPS rejects the
    rest, and finally only the two fp8 formats checked here are accepted.
    """
    if dtype == torch.float32 or dtype == torch.float16:
        return True
    if directml_enabled: #TODO: test this
        return False
    if dtype == torch.bfloat16:
        return True
    if is_device_mps(device):
        return False
    return bool(dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2)
970
+
971
def pick_weight_dtype(dtype, fallback_dtype, device=None):
    """Return ``dtype`` unless the fallback has to be used instead.

    The fallback is chosen when ``dtype`` is None, when it is larger than
    ``fallback_dtype``, or when ``device`` cannot cast to it.
    """
    chosen = dtype
    if chosen is None or dtype_size(chosen) > dtype_size(fallback_dtype):
        chosen = fallback_dtype

    if not supports_cast(device, chosen):
        chosen = fallback_dtype

    return chosen
981
+
982
def device_supports_non_blocking(device):
    """Report whether non-blocking transfers may be used with ``device``.

    ``--force-non-blocking`` wins over every other check.
    """
    if args.force_non_blocking:
        return True

    if is_device_mps(device):
        return False #pytorch bug? mps doesn't support non blocking
    if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
        return False
    if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
        return False
    return not directml_enabled
994
+
995
def device_should_use_non_blocking(device):
    """Currently always returns False, whatever ``device`` supports.

    The support check is still performed so the function keeps calling
    ``device_supports_non_blocking``.
    """
    supported = device_supports_non_blocking(device)
    if not supported:
        return False
    return False
    # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
1000
+
1001
def force_channels_last():
    """Return True only when ``--force-channels-last`` was given."""
    #TODO
    return True if args.force_channels_last else False
1007
+
1008
+
1009
# Offload streams created so far, keyed by device.
STREAMS = {}
# Async offloading uses two streams; a value of 1 disables offload streams.
NUM_STREAMS = 2 if args.async_offload else 1
if args.async_offload:
    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
1014
+
1015
# Index of the next stream to hand out, per device.
stream_counters = {}
def get_offload_stream(device):
    """Return the next offload stream for ``device``, or None.

    None is returned when offload streams are disabled (``NUM_STREAMS <= 1``)
    or when ``device`` is neither CUDA nor XPU. Streams are created on first
    use and handed out round-robin.
    """
    position = stream_counters.get(device, 0)
    if NUM_STREAMS <= 1:
        return None

    if device in STREAMS:
        streams = STREAMS[device]
        chosen = streams[position]
        position = (position + 1) % len(streams)
        # The stream that will be handed out next waits for the current stream.
        if is_device_cuda(device):
            streams[position].wait_stream(torch.cuda.current_stream())
        elif is_device_xpu(device):
            streams[position].wait_stream(torch.xpu.current_stream())
        stream_counters[device] = position
        return chosen

    if is_device_cuda(device):
        make_stream = torch.cuda.Stream
    elif is_device_xpu(device):
        make_stream = torch.xpu.Stream
    else:
        return None

    streams = [make_stream(device=device, priority=0) for _ in range(NUM_STREAMS)]
    STREAMS[device] = streams
    chosen = streams[position]
    stream_counters[device] = (position + 1) % len(streams)
    return chosen
1050
+
1051
def sync_stream(device, stream):
    """Make the device's current stream wait for ``stream``.

    Does nothing when ``stream`` is None or the device is neither CUDA nor XPU.
    """
    if stream is None:
        return
    if is_device_cuda(device):
        current = torch.cuda.current_stream()
    elif is_device_xpu(device):
        current = torch.xpu.current_stream()
    else:
        return
    current.wait_stream(stream)
1058
+
1059
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
    """Return ``weight`` converted to ``dtype`` and/or moved to ``device``.

    Args:
        weight: source tensor.
        dtype: target dtype, or None to keep the current one.
        device: target device, or None to stay on the current one.
        non_blocking: forwarded to ``copy_`` for cross-device copies.
        copy: force a new tensor even when nothing would change.
        stream: optional stream used as a context manager around the work.

    Returns:
        ``weight`` itself when no change and no copy is needed, otherwise a
        new tensor.
    """
    if device is None or weight.device == device:
        if not copy and (dtype is None or weight.dtype == dtype):
            return weight
        if stream is None:
            return weight.to(dtype=dtype, copy=copy)
        with stream:
            return weight.to(dtype=dtype, copy=copy)

    # Different device: allocate on the target, then copy into it.
    if stream is None:
        result = torch.empty_like(weight, dtype=dtype, device=device)
        result.copy_(weight, non_blocking=non_blocking)
        return result

    with stream:
        result = torch.empty_like(weight, dtype=dtype, device=device)
        result.copy_(weight, non_blocking=non_blocking)
    return result
1077
+
1078
def cast_to_device(tensor, device, dtype, copy=False):
    """Cast ``tensor`` via ``cast_to``, non-blocking when the device supports it."""
    return cast_to(
        tensor,
        dtype=dtype,
        device=device,
        non_blocking=device_supports_non_blocking(device),
        copy=copy,
    )
1081
+
1082
def sage_attention_enabled():
    """Return the value of the ``use_sage_attention`` command-line option."""
    enabled = args.use_sage_attention
    return enabled
1084
+
1085
def flash_attention_enabled():
    """Return the value of the ``use_flash_attention`` command-line option."""
    enabled = args.use_flash_attention
    return enabled
1087
+
1088
def xformers_enabled():
    """Report whether xformers may be used.

    It is ruled out when not running on a GPU, on Intel XPU, Ascend NPU, MLU
    or ixuca backends, and under DirectML; otherwise the module-level
    ``XFORMERS_IS_AVAILABLE`` value is returned.
    """
    global directml_enabled
    global cpu_state
    if cpu_state != CPUState.GPU:
        return False
    # Checked in order; the first backend that matches disables xformers.
    if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca() or directml_enabled:
        return False
    return XFORMERS_IS_AVAILABLE
1104
+
1105
+
1106
def xformers_enabled_vae():
    """Return ``XFORMERS_ENABLED_VAE`` when xformers is enabled, else False."""
    if xformers_enabled():
        return XFORMERS_ENABLED_VAE
    return False
1112
+
1113
def pytorch_attention_enabled():
    """Return the module-level ``ENABLE_PYTORCH_ATTENTION`` flag."""
    global ENABLE_PYTORCH_ATTENTION
    enabled = ENABLE_PYTORCH_ATTENTION
    return enabled
1116
+
1117
def pytorch_attention_enabled_vae():
    """Like ``pytorch_attention_enabled`` but always False on AMD."""
    on_amd = is_amd()
    if on_amd:
        return False # enabling pytorch attention on AMD currently causes crash when doing high res
    return pytorch_attention_enabled()
1121
+
1122
def pytorch_attention_flash_attention():
    """Report whether pytorch attention is enabled on a recognised backend.

    Returns False when pytorch attention is disabled or the backend is none
    of nvidia, Intel XPU, Ascend NPU, MLU, AMD or ixuca.
    """
    global ENABLE_PYTORCH_ATTENTION
    if not ENABLE_PYTORCH_ATTENTION:
        return False

    #TODO: more reliable way of checking for flash attention?
    # AMD note: if you have pytorch attention enabled on AMD it probably
    # supports at least mem efficient attention
    recognised = (
        is_nvidia()
        or is_intel_xpu()
        or is_ascend_npu()
        or is_mlu()
        or is_amd()
        or is_ixuca()
    )
    return True if recognised else False
1139
+
1140
def force_upcast_attention_dtype():
    """Return the attention upcast mapping, or None when no upcast is needed.

    Upcasting is requested by ``--force-upcast-attention`` or forced on
    macOS 14.5 and later; the mapping sends float16 to float32.
    """
    needs_upcast = args.force_upcast_attention

    macos_version = mac_version()
    # black image bug on recent versions of macOS, I don't think it's ever getting fixed
    if macos_version is not None and macos_version >= (14, 5):
        needs_upcast = True

    if not needs_upcast:
        return None
    return {torch.float16: torch.float32}
1151
+
1152
def get_free_memory(dev=None, torch_free_too=False):
    """Return the free memory on ``dev`` in bytes.

    Args:
        dev: device to query; defaults to ``get_torch_device()``.
        torch_free_too: when true, return ``(total_free, torch_free)`` where
            ``torch_free`` is memory reserved by torch but not active.

    For cpu/mps devices the available system memory is reported. Under
    DirectML a fixed 1 GiB placeholder is returned.
    """
    global directml_enabled
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and dev.type in ('cpu', 'mps'):
        mem_free_total = psutil.virtual_memory().available
        mem_free_torch = mem_free_total
    elif directml_enabled:
        mem_free_total = 1024 * 1024 * 1024 #TODO
        mem_free_torch = mem_free_total
    else:
        # Select the torch backend module; XPU has no mem_get_info call here
        # and derives device-free memory from the device's total memory.
        on_xpu = False
        if is_intel_xpu():
            backend = torch.xpu
            on_xpu = True
        elif is_ascend_npu():
            backend = torch.npu
        elif is_mlu():
            backend = torch.mlu
        else:
            backend = torch.cuda

        stats = backend.memory_stats(dev)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        if on_xpu:
            mem_free_device = backend.get_device_properties(dev).total_memory - mem_reserved
        else:
            mem_free_device, _ = backend.mem_get_info(dev)
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_device + mem_free_torch

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    return mem_free_total
1197
+
1198
def cpu_mode():
    """Return True when the module-level ``cpu_state`` is ``CPUState.CPU``."""
    global cpu_state
    running_on_cpu = cpu_state == CPUState.CPU
    return running_on_cpu
1201
+
1202
def mps_mode():
    """Return True when the module-level ``cpu_state`` is ``CPUState.MPS``."""
    global cpu_state
    running_on_mps = cpu_state == CPUState.MPS
    return running_on_mps
1205
+
1206
def is_device_type(device, type):
    """Return True when ``device`` has a ``type`` attribute equal to ``type``."""
    if not hasattr(device, 'type'):
        return False
    return True if device.type == type else False
1211
+
1212
def is_device_cpu(device):
    """Return True when ``device`` is of type ``'cpu'``."""
    kind = 'cpu'
    return is_device_type(device, kind)
1214
+
1215
def is_device_mps(device):
    """Return True when ``device`` is of type ``'mps'``."""
    kind = 'mps'
    return is_device_type(device, kind)
1217
+
1218
def is_device_xpu(device):
    """Return True when ``device`` is of type ``'xpu'``."""
    kind = 'xpu'
    return is_device_type(device, kind)
1220
+
1221
def is_device_cuda(device):
    """Return True when ``device`` is of type ``'cuda'``."""
    kind = 'cuda'
    return is_device_type(device, kind)
1223
+
1224
def is_directml_enabled():
    """Return True when the module-level ``directml_enabled`` flag is truthy."""
    global directml_enabled
    return True if directml_enabled else False
1230
+
1231
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether fp16 should be used on ``device``.

    Args:
        device: device to check; CPU devices never use fp16.
        model_params: parameter count, used for the memory-pressure check.
        prioritize_performance: when false, manual-cast mode accepts fp16
            without looking at memory pressure.
        manual_cast: whether the caller will cast weights manually.

    Returns:
        bool (or the backend's ``has_fp16`` property on Intel XPU).
    """
    if device is not None and is_device_cpu(device):
        return False

    if args.force_fp16:
        return True

    if FORCE_FP32:
        return False

    if is_directml_enabled():
        return True

    if (device is not None and is_device_mps(device)) or mps_mode():
        return True

    if cpu_mode():
        return False

    if is_intel_xpu():
        if torch_version_numeric < (2, 3):
            return True
        return torch.xpu.get_device_properties(device).has_fp16

    if is_ascend_npu():
        return True

    if is_mlu():
        return True

    if is_ixuca():
        return True

    if torch.version.hip:
        return True

    props = torch.cuda.get_device_properties(device)
    if props.major >= 8:
        return True

    if props.major < 6:
        return False

    #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
    nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
    lowered_name = props.name.lower()
    if any(tag in lowered_name for tag in nvidia_10_series):
        # weird linux behavior where fp32 is faster
        return True if (WINDOWS or manual_cast) else False

    if manual_cast:
        free_model_memory = maximum_vram_for_weights(device)
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    if props.major < 7:
        return False

    #FP16 is just broken on these cards
    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
    if any(tag in props.name for tag in nvidia_16_series):
        return False

    return True
1300
+
1301
def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
    """Decide whether bf16 should be used on ``device``.

    Args:
        device: device to check; CPU devices never use bf16 here.
        model_params: parameter count, used for the memory-pressure check.
        prioritize_performance: when false, manual-cast mode accepts bf16
            without looking at memory pressure.
        manual_cast: whether the caller will cast weights manually.
    """
    if device is not None and is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
        return False

    if FORCE_FP32:
        return False

    if directml_enabled:
        return False

    if (device is not None and is_device_mps(device)) or mps_mode():
        return not (mac_version() < (14,))

    if cpu_mode():
        return False

    if is_intel_xpu():
        if torch_version_numeric < (2, 3):
            return True
        return torch.xpu.is_bf16_supported()

    if is_ascend_npu():
        return True

    if is_ixuca():
        return True

    if is_amd():
        arch = torch.cuda.get_device_properties(device).gcnArchName
        # RDNA2 and older don't support bf16
        old_archs = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
        if any((a in arch) for a in old_archs):
            return True if manual_cast else False

    props = torch.cuda.get_device_properties(device)

    if is_mlu() and props.major > 3:
        return True

    if props.major >= 8:
        return True

    bf16_works = torch.cuda.is_bf16_supported()

    if bf16_works and manual_cast:
        free_model_memory = maximum_vram_for_weights(device)
        if (not prioritize_performance) or model_params * 4 > free_model_memory:
            return True

    return False
1356
+
1357
def supports_fp8_compute(device=None):
    """Report whether fp8 compute can be used on ``device``.

    ``SUPPORT_FP8_OPS`` forces True. Otherwise an nvidia device with compute
    capability 9.x+, or 8.9, is required together with torch >= 2.3
    (>= 2.4 on Windows).
    """
    if SUPPORT_FP8_OPS:
        return True

    if not is_nvidia():
        return False

    props = torch.cuda.get_device_properties(device)
    if props.major >= 9:
        return True
    if props.major < 8 or props.minor < 9:
        return False

    minimum_torch = (2, 3)
    if torch_version_numeric < minimum_torch:
        return False

    if WINDOWS and torch_version_numeric < (2, 4):
        return False

    return True
1380
+
1381
def extended_fp16_support():
    """Return True when the running torch version is at least 2.7."""
    # TODO: check why some models work with fp16 on newer torch versions but not on older
    too_old = torch_version_numeric < (2, 7)
    return not too_old
1387
+
1388
def soft_empty_cache(force=False):
    """Release cached allocator memory on the active backend.

    ``force`` is accepted but not used by this function.
    """
    global cpu_state
    if cpu_state == CPUState.MPS:
        torch.mps.empty_cache()
        return
    if is_intel_xpu():
        torch.xpu.empty_cache()
        return
    if is_ascend_npu():
        torch.npu.empty_cache()
        return
    if is_mlu():
        torch.mlu.empty_cache()
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
1401
+
1402
def unload_all_models():
    """Unload every model from the torch device by requesting 1e30 bytes."""
    everything = 1e30
    free_memory(everything, get_torch_device())
1404
+
1405
+
1406
+ #TODO: might be cleaner to put this somewhere else
1407
+ import threading
1408
+
1409
class InterruptProcessingException(Exception):
    """Raised to abort the current processing run after an interrupt request."""


# Guards every read and write of the interrupt flag below.
interrupt_processing_mutex = threading.RLock()

# True while an interrupt has been requested and not yet consumed.
interrupt_processing = False


def interrupt_current_processing(value=True):
    """Set (or, with ``value=False``, clear) the interrupt flag."""
    global interrupt_processing
    with interrupt_processing_mutex:
        interrupt_processing = value


def processing_interrupted():
    """Return the current value of the interrupt flag."""
    with interrupt_processing_mutex:
        return interrupt_processing


def throw_exception_if_processing_interrupted():
    """Raise ``InterruptProcessingException`` if an interrupt is pending.

    The flag is cleared before raising, so one request aborts only once.
    """
    global interrupt_processing
    with interrupt_processing_mutex:
        if not interrupt_processing:
            return
        interrupt_processing = False
        raise InterruptProcessingException()