surena26 committed on
Commit
cca1a90
·
verified ·
1 Parent(s): d36e9c5

Upload ComfyUI/cuda_malloc.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ComfyUI/cuda_malloc.py +90 -0
ComfyUI/cuda_malloc.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib.util
3
+ from comfy.cli_args import args
4
+ import subprocess
5
+
6
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
    """Return a set of GPU device-name strings without importing torch.

    On Windows the names come from the Win32 ``EnumDisplayDevicesA`` API;
    everywhere else they are parsed from ``nvidia-smi -L`` output.
    """
    if os.name != 'nt':
        # Non-Windows: one GPU per line, e.g.
        # "GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-...)" -> keep text before " (UUID".
        names = set()
        listing = subprocess.check_output(['nvidia-smi', '-L'])
        for line in listing.split(b'\n'):
            if line:
                names.add(line.decode('utf-8').split(' (UUID')[0])
        return names

    import ctypes

    # ctypes mirror of the Win32 DISPLAY_DEVICEA structure.
    class DISPLAY_DEVICEA(ctypes.Structure):
        _fields_ = [
            ('cb', ctypes.c_ulong),
            ('DeviceName', ctypes.c_char * 32),
            ('DeviceString', ctypes.c_char * 128),
            ('StateFlags', ctypes.c_ulong),
            ('DeviceID', ctypes.c_char * 128),
            ('DeviceKey', ctypes.c_char * 128)
        ]

    user32 = ctypes.windll.user32

    def enum_display_devices():
        # Walk adapter indices until EnumDisplayDevicesA reports no more devices.
        info = DISPLAY_DEVICEA()
        info.cb = ctypes.sizeof(info)
        index = 0
        found = set()
        while user32.EnumDisplayDevicesA(None, index, ctypes.byref(info), 0):
            index += 1
            found.add(info.DeviceString.decode('utf-8'))
        return found

    return enum_display_devices()
44
+
45
# GPU models known to misbehave with the cudaMallocAsync backend (largely
# Maxwell-era and earlier parts); cuda malloc stays disabled when one of
# these names is detected. Matched by substring in cuda_malloc_supported().
blacklist = set([
    "GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970",
    "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M",
    "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M",
    "GeForce GTX 750", "GeForce GTX 745", "Quadro K620",
    "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520",
    "Quadro M600", "Quadro M620", "Quadro M1000",
    "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000",
    "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
    "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M",
    "GeForce GTX 850M", "GeForce GTX 860M",
    "GeForce GTX 1650", "GeForce GTX 1630",
    "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60",
])
52
+
53
def cuda_malloc_supported():
    """Return False if a detected NVIDIA GPU is on the blacklist, else True.

    GPU detection is best-effort: if enumeration fails (e.g. nvidia-smi is
    missing), no GPUs are assumed and cuda malloc is considered supported.
    """
    try:
        names = get_gpu_names()
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit; narrowed while keeping the best-effort fallback.
        names = set()
    for name in names:
        if "NVIDIA" not in name:
            continue
        # Substring match so e.g. "NVIDIA GeForce GTX 960" hits "GeForce GTX 960".
        for bad in blacklist:
            if bad in name:
                return False
    return True
64
+
65
+
66
# If cuda malloc wasn't explicitly requested, enable it by default for
# torch >= 2.0 (provided the GPU isn't blacklisted). The torch version is
# read from torch/version.py directly instead of importing torch, because
# PYTORCH_CUDA_ALLOC_CONF must be set before the first `import torch`.
if not args.cuda_malloc:
    try:
        version = ""
        torch_spec = importlib.util.find_spec("torch")
        for folder in torch_spec.submodule_search_locations:
            ver_file = os.path.join(folder, "version.py")
            if os.path.isfile(ver_file):
                spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                version = module.__version__
        # Parse the whole major component ("2.1.0+cu118" -> 2) rather than
        # the original `int(version[0])`, which only read the first character
        # and would misparse any future multi-digit major version.
        if int(version.split('.')[0]) >= 2:  # enable by default for torch 2.0 and up
            args.cuda_malloc = cuda_malloc_supported()
    except Exception:
        # torch not installed or version unparsable: leave cuda_malloc as-is.
        # Was a bare `except:`; narrowed so Ctrl-C still interrupts startup.
        pass
81
+
82
+
83
# Opt into PyTorch's cudaMallocAsync allocator by appending the backend
# selection to PYTORCH_CUDA_ALLOC_CONF, preserving any user-supplied options.
# This runs at import time, before torch itself is ever imported.
if args.cuda_malloc and not args.disable_cuda_malloc:
    current = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    setting = (
        "backend:cudaMallocAsync"
        if current is None
        else current + ",backend:cudaMallocAsync"
    )
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = setting