Spaces:
Running
on
Zero
Running
on
Zero
Upload 18 files
Browse files- spaces/__init__.py +30 -0
- spaces/config.py +37 -0
- spaces/gradio.py +55 -0
- spaces/utils.py +85 -0
- spaces/zero/__init__.py +21 -0
- spaces/zero/api.py +156 -0
- spaces/zero/client.py +239 -0
- spaces/zero/decorator.py +113 -0
- spaces/zero/gradio.py +150 -0
- spaces/zero/torch/__init__.py +42 -0
- spaces/zero/torch/bitsandbytes.py +162 -0
- spaces/zero/torch/packing.py +209 -0
- spaces/zero/torch/patching.py +386 -0
- spaces/zero/torch/patching_legacy.py +266 -0
- spaces/zero/torch/types.py +23 -0
- spaces/zero/tqdm.py +24 -0
- spaces/zero/types.py +49 -0
- spaces/zero/wrappers.py +418 -0
spaces/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
import sys
|
5 |
+
|
6 |
+
|
7 |
+
if sys.version_info.minor < 8: # pragma: no cover
|
8 |
+
raise RuntimeError("Importing PySpaces requires Python 3.8+")
|
9 |
+
|
10 |
+
|
11 |
+
# Prevent gradio from importing spaces
|
12 |
+
if (gr := sys.modules.get('gradio')) is not None: # pragma: no cover
|
13 |
+
try:
|
14 |
+
gr.Blocks
|
15 |
+
except AttributeError:
|
16 |
+
raise ImportError
|
17 |
+
|
18 |
+
|
19 |
+
from .zero.decorator import GPU
|
20 |
+
from .gradio import gradio_auto_wrap
|
21 |
+
from .gradio import disable_gradio_auto_wrap
|
22 |
+
from .gradio import enable_gradio_auto_wrap
|
23 |
+
|
24 |
+
|
25 |
+
__all__ = [
|
26 |
+
'GPU',
|
27 |
+
'gradio_auto_wrap',
|
28 |
+
'disable_gradio_auto_wrap',
|
29 |
+
'enable_gradio_auto_wrap',
|
30 |
+
]
|
spaces/config.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import os
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
from .utils import boolean
|
9 |
+
|
10 |
+
|
11 |
+
ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')
|
12 |
+
|
13 |
+
|
14 |
+
class Settings:
|
15 |
+
def __init__(self):
|
16 |
+
self.zero_gpu = boolean(
|
17 |
+
os.getenv('SPACES_ZERO_GPU'))
|
18 |
+
self.zero_device_api_url = (
|
19 |
+
os.getenv('SPACES_ZERO_DEVICE_API_URL'))
|
20 |
+
self.gradio_auto_wrap = boolean(
|
21 |
+
os.getenv('SPACES_GRADIO_AUTO_WRAP'))
|
22 |
+
self.zero_patch_torch_device = boolean(
|
23 |
+
os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
|
24 |
+
self.zero_gpu_v2 = boolean(
|
25 |
+
os.getenv('ZEROGPU_V2'))
|
26 |
+
self.zerogpu_offload_dir = (
|
27 |
+
os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT))
|
28 |
+
|
29 |
+
|
30 |
+
Config = Settings()
|
31 |
+
|
32 |
+
|
33 |
+
if Config.zero_gpu:
|
34 |
+
assert Config.zero_device_api_url is not None, (
|
35 |
+
'SPACES_ZERO_DEVICE_API_URL env must be set '
|
36 |
+
'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
|
37 |
+
)
|
spaces/gradio.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
from typing import Callable
|
6 |
+
from typing import Generator
|
7 |
+
from typing import TypeVar
|
8 |
+
from typing import overload
|
9 |
+
from typing_extensions import ParamSpec
|
10 |
+
|
11 |
+
from .config import Config
|
12 |
+
from .zero.decorator import GPU
|
13 |
+
|
14 |
+
|
15 |
+
Param = ParamSpec('Param')
|
16 |
+
Res = TypeVar('Res')
|
17 |
+
|
18 |
+
|
19 |
+
gradio_auto_wrap_enabled = Config.gradio_auto_wrap
|
20 |
+
|
21 |
+
|
22 |
+
def disable_gradio_auto_wrap():
|
23 |
+
global gradio_auto_wrap_enabled
|
24 |
+
gradio_auto_wrap_enabled = False
|
25 |
+
|
26 |
+
def enable_gradio_auto_wrap():
|
27 |
+
global gradio_auto_wrap_enabled
|
28 |
+
gradio_auto_wrap_enabled = True
|
29 |
+
|
30 |
+
|
31 |
+
@overload
|
32 |
+
def gradio_auto_wrap(
|
33 |
+
task:
|
34 |
+
Callable[Param, Res],
|
35 |
+
) -> Callable[Param, Res]:
|
36 |
+
...
|
37 |
+
@overload
|
38 |
+
def gradio_auto_wrap(
|
39 |
+
task:
|
40 |
+
None,
|
41 |
+
) -> None:
|
42 |
+
...
|
43 |
+
def gradio_auto_wrap(
|
44 |
+
task:
|
45 |
+
Callable[Param, Res]
|
46 |
+
| None,
|
47 |
+
) -> (Callable[Param, Res]
|
48 |
+
| None):
|
49 |
+
"""
|
50 |
+
"""
|
51 |
+
if not gradio_auto_wrap_enabled:
|
52 |
+
return task
|
53 |
+
if not callable(task):
|
54 |
+
return task
|
55 |
+
return GPU(task) # type: ignore
|
spaces/utils.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import ctypes
|
6 |
+
import sys
|
7 |
+
from functools import lru_cache as cache
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import multiprocessing
|
11 |
+
from multiprocessing.queues import SimpleQueue as _SimpleQueue
|
12 |
+
from pathlib import Path
|
13 |
+
from pickle import PicklingError
|
14 |
+
from typing import Callable
|
15 |
+
from typing import TypeVar
|
16 |
+
|
17 |
+
|
18 |
+
GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"
|
19 |
+
|
20 |
+
|
21 |
+
T = TypeVar('T')
|
22 |
+
|
23 |
+
|
24 |
+
@cache
|
25 |
+
def self_cgroup_device_path() -> str:
|
26 |
+
cgroup_content = Path('/proc/self/cgroup').read_text()
|
27 |
+
for line in cgroup_content.strip().split('\n'):
|
28 |
+
contents = line.split(':devices:')
|
29 |
+
if len(contents) != 2:
|
30 |
+
continue # pragma: no cover
|
31 |
+
return contents[1]
|
32 |
+
raise Exception # pragma: no cover
|
33 |
+
|
34 |
+
|
35 |
+
if sys.version_info.minor < 9: # pragma: no cover
|
36 |
+
_SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore
|
37 |
+
|
38 |
+
class SimpleQueue(_SimpleQueue[T]):
|
39 |
+
def __init__(self, *args):
|
40 |
+
super().__init__(*args, ctx=multiprocessing.get_context('fork'))
|
41 |
+
def put(self, obj: T):
|
42 |
+
try:
|
43 |
+
super().put(obj)
|
44 |
+
except PicklingError:
|
45 |
+
raise # pragma: no cover
|
46 |
+
# https://bugs.python.org/issue29187
|
47 |
+
except Exception as e:
|
48 |
+
message = str(e)
|
49 |
+
if not "pickle" in message:
|
50 |
+
raise # pragma: no cover
|
51 |
+
raise PicklingError(message)
|
52 |
+
def close(self): # Python 3.8 static typing trick
|
53 |
+
super().close() # type: ignore
|
54 |
+
def wlock_release(self):
|
55 |
+
if (lock := getattr(self, '_wlock', None)) is None:
|
56 |
+
return # pragma: no cover
|
57 |
+
try:
|
58 |
+
lock.release()
|
59 |
+
except ValueError:
|
60 |
+
pass
|
61 |
+
|
62 |
+
|
63 |
+
def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
|
64 |
+
def drop(*args):
|
65 |
+
return fn()
|
66 |
+
return drop
|
67 |
+
|
68 |
+
|
69 |
+
def boolean(value: str | None) -> bool:
|
70 |
+
return value is not None and value.lower() in ("1", "t", "true")
|
71 |
+
|
72 |
+
|
73 |
+
def gradio_request_var():
|
74 |
+
try:
|
75 |
+
from gradio.context import LocalContext
|
76 |
+
except ImportError: # pragma: no cover
|
77 |
+
raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
|
78 |
+
return LocalContext.request
|
79 |
+
|
80 |
+
|
81 |
+
def malloc_trim():
|
82 |
+
ctypes.CDLL("libc.so.6").malloc_trim(0)
|
83 |
+
|
84 |
+
|
85 |
+
debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
|
spaces/zero/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from ..config import Config
|
7 |
+
|
8 |
+
|
9 |
+
if Config.zero_gpu:
|
10 |
+
|
11 |
+
from . import gradio
|
12 |
+
from . import torch
|
13 |
+
|
14 |
+
if torch.is_in_bad_fork():
|
15 |
+
raise RuntimeError(
|
16 |
+
"CUDA has been initialized before importing the `spaces` package"
|
17 |
+
)
|
18 |
+
|
19 |
+
torch.patch()
|
20 |
+
gradio.one_launch(torch.pack)
|
21 |
+
Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)
|
spaces/zero/api.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Synced with huggingface/pyspaces:spaces/zero/api.py
|
3 |
+
"""
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
from datetime import timedelta
|
7 |
+
from typing import Any
|
8 |
+
from typing import Generator
|
9 |
+
from typing import Literal
|
10 |
+
from typing import NamedTuple
|
11 |
+
from typing import Optional
|
12 |
+
from typing import overload
|
13 |
+
|
14 |
+
import httpx
|
15 |
+
from pydantic import BaseModel
|
16 |
+
from typing_extensions import assert_never
|
17 |
+
|
18 |
+
|
19 |
+
AllowToken = str
|
20 |
+
NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
|
21 |
+
NvidiaUUID = str
|
22 |
+
CGroupPath = str
|
23 |
+
VisitorId = str
|
24 |
+
Score = float
|
25 |
+
|
26 |
+
AuthLevel = Literal['regular', 'pro']
|
27 |
+
|
28 |
+
|
29 |
+
AUTHENTICATED_HEADER = 'X-Authenticated'
|
30 |
+
|
31 |
+
|
32 |
+
class ScheduleResponse(BaseModel):
|
33 |
+
idle: bool
|
34 |
+
nvidiaIndex: int
|
35 |
+
nvidiaUUID: str
|
36 |
+
allowToken: str
|
37 |
+
|
38 |
+
|
39 |
+
class QuotaInfos(BaseModel):
|
40 |
+
left: int
|
41 |
+
wait: timedelta
|
42 |
+
|
43 |
+
|
44 |
+
class ReportUsageMonitoringParams(NamedTuple):
|
45 |
+
nvidia_index: int
|
46 |
+
visitor_id: str
|
47 |
+
duration: timedelta
|
48 |
+
|
49 |
+
|
50 |
+
class QueueEvent(BaseModel):
|
51 |
+
event: Literal['ping', 'failed', 'succeeded']
|
52 |
+
data: Optional[ScheduleResponse] = None
|
53 |
+
|
54 |
+
|
55 |
+
def sse_parse(text: str):
|
56 |
+
event, *data = text.strip().splitlines()
|
57 |
+
assert event.startswith('event:')
|
58 |
+
event = event[6:].strip()
|
59 |
+
if event in ('ping', 'failed'):
|
60 |
+
return QueueEvent(event=event)
|
61 |
+
assert event == 'succeeded'
|
62 |
+
(data,) = data
|
63 |
+
assert data.startswith('data:')
|
64 |
+
data = data[5:].strip()
|
65 |
+
return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
|
66 |
+
|
67 |
+
|
68 |
+
def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
|
69 |
+
for text in res.iter_text():
|
70 |
+
if len(text) == 0:
|
71 |
+
break # pragma: no cover
|
72 |
+
try:
|
73 |
+
yield sse_parse(text)
|
74 |
+
except GeneratorExit:
|
75 |
+
res.close()
|
76 |
+
break
|
77 |
+
|
78 |
+
|
79 |
+
class APIClient:
|
80 |
+
|
81 |
+
def __init__(self, client: httpx.Client):
|
82 |
+
self.client = client
|
83 |
+
|
84 |
+
def startup_report(self) -> httpx.codes:
|
85 |
+
res = self.client.post('/startup-report')
|
86 |
+
return httpx.codes(res.status_code)
|
87 |
+
|
88 |
+
def schedule(
|
89 |
+
self,
|
90 |
+
cgroup_path: str,
|
91 |
+
task_id: int = 0,
|
92 |
+
token: str | None = None,
|
93 |
+
duration_seconds: int | None = None,
|
94 |
+
enable_queue: bool = True,
|
95 |
+
):
|
96 |
+
params: dict[str, str | int | bool] = {
|
97 |
+
'cgroupPath': cgroup_path,
|
98 |
+
'taskId': task_id,
|
99 |
+
'enableQueue': enable_queue,
|
100 |
+
}
|
101 |
+
if duration_seconds is not None:
|
102 |
+
params['durationSeconds'] = duration_seconds
|
103 |
+
if token is not None:
|
104 |
+
params['token'] = token
|
105 |
+
res = self.client.send(
|
106 |
+
request=self.client.build_request(
|
107 |
+
method='POST',
|
108 |
+
url='/schedule',
|
109 |
+
params=params,
|
110 |
+
),
|
111 |
+
stream=True,
|
112 |
+
)
|
113 |
+
status = httpx.codes(res.status_code)
|
114 |
+
auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
|
115 |
+
if (status is not httpx.codes.OK and
|
116 |
+
status is not httpx.codes.TOO_MANY_REQUESTS
|
117 |
+
):
|
118 |
+
res.close()
|
119 |
+
return status, auth
|
120 |
+
if "text/event-stream" in res.headers['content-type']:
|
121 |
+
return sse_stream(res), auth
|
122 |
+
res.read()
|
123 |
+
if status is httpx.codes.TOO_MANY_REQUESTS:
|
124 |
+
return QuotaInfos(**res.json()), auth # pragma: no cover
|
125 |
+
if status is httpx.codes.OK:
|
126 |
+
return ScheduleResponse(**res.json()), auth
|
127 |
+
assert_never(status)
|
128 |
+
|
129 |
+
def allow(
|
130 |
+
self,
|
131 |
+
allow_token: str,
|
132 |
+
pid: int,
|
133 |
+
):
|
134 |
+
res = self.client.post('/allow', params={
|
135 |
+
'allowToken': allow_token,
|
136 |
+
'pid': pid,
|
137 |
+
})
|
138 |
+
return httpx.codes(res.status_code)
|
139 |
+
|
140 |
+
def release(
|
141 |
+
self,
|
142 |
+
allow_token: str,
|
143 |
+
fail: bool = False,
|
144 |
+
) -> httpx.codes:
|
145 |
+
res = self.client.post('/release', params={
|
146 |
+
'allowToken': allow_token,
|
147 |
+
'fail': fail,
|
148 |
+
})
|
149 |
+
return httpx.codes(res.status_code)
|
150 |
+
|
151 |
+
def get_queue_size(self) -> int:
|
152 |
+
res = self.client.get('/queue-size')
|
153 |
+
assert res.status_code == 200, res.status_code
|
154 |
+
size = res.json()
|
155 |
+
assert isinstance(size, int)
|
156 |
+
return size
|
spaces/zero/client.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
import warnings
|
8 |
+
from datetime import timedelta
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import httpx
|
12 |
+
from packaging import version
|
13 |
+
from typing_extensions import assert_never
|
14 |
+
|
15 |
+
from .. import utils
|
16 |
+
from ..config import Config
|
17 |
+
from .api import APIClient
|
18 |
+
from .api import AuthLevel
|
19 |
+
from .api import QuotaInfos
|
20 |
+
from .api import ScheduleResponse
|
21 |
+
from .gradio import HTMLError
|
22 |
+
from .gradio import get_event
|
23 |
+
from .gradio import supports_auth
|
24 |
+
|
25 |
+
|
26 |
+
TOKEN_HEADER = 'X-IP-Token'
|
27 |
+
DEFAULT_SCHEDULE_DURATION = 60
|
28 |
+
|
29 |
+
QUOTA_MESSAGE = "You have exceeded your GPU quota"
|
30 |
+
UNUSED_MESSAGE = "GPU device not used"
|
31 |
+
NO_GPU_MESSAGE_REGULAR = "No GPU was available"
|
32 |
+
NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60s"
|
33 |
+
|
34 |
+
SIGNUP_ON_HF_TXT = "Create a free account"
|
35 |
+
SIGNUP_ON_HF_URL = "https://huggingface.co/join"
|
36 |
+
SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro"
|
37 |
+
SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription"
|
38 |
+
|
39 |
+
|
40 |
+
def api_client():
|
41 |
+
assert Config.zero_device_api_url is not None
|
42 |
+
httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
|
43 |
+
return APIClient(httpx_client)
|
44 |
+
|
45 |
+
|
46 |
+
def startup_report():
|
47 |
+
retries, max_retries = 0, 2
|
48 |
+
client = api_client()
|
49 |
+
while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
|
50 |
+
time.sleep(1)
|
51 |
+
if (retries := retries + 1) > max_retries:
|
52 |
+
raise RuntimeError("Error while initializing ZeroGPU: NotFound")
|
53 |
+
if status is not httpx.codes.OK: # pragma: no cover
|
54 |
+
raise RuntimeError("Error while initializing ZeroGPU: Unknown")
|
55 |
+
|
56 |
+
|
57 |
+
def html_string(html_contents: str, text_contents: str): # pragma: no cover
|
58 |
+
class HTMLString(str):
|
59 |
+
def __str__(self):
|
60 |
+
return text_contents
|
61 |
+
return HTMLString(html_contents)
|
62 |
+
|
63 |
+
|
64 |
+
def _toast_action(
|
65 |
+
auth: AuthLevel | None,
|
66 |
+
supports_html: bool,
|
67 |
+
pro_message: str,
|
68 |
+
unlogged_desc: str,
|
69 |
+
logged_desc: str,
|
70 |
+
ending: str,
|
71 |
+
) -> tuple[str, str]: # pragma: no cover
|
72 |
+
if not supports_auth() or auth == 'pro':
|
73 |
+
return pro_message, pro_message
|
74 |
+
html = ""
|
75 |
+
link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL
|
76 |
+
text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT
|
77 |
+
desc = unlogged_desc if auth is None else logged_desc
|
78 |
+
desc += f" {ending}."
|
79 |
+
style = ";".join([
|
80 |
+
"white-space: nowrap",
|
81 |
+
"text-underline-offset: 2px",
|
82 |
+
"color: var(--body-text-color)",
|
83 |
+
])
|
84 |
+
if supports_html:
|
85 |
+
html += f'<a style="{style}" href="{link}">'
|
86 |
+
html += text
|
87 |
+
if supports_html:
|
88 |
+
html += '</a> '
|
89 |
+
html += desc
|
90 |
+
markdown = f'[{text}]({link}) {desc}'
|
91 |
+
return html, markdown
|
92 |
+
|
93 |
+
|
94 |
+
def schedule(
|
95 |
+
task_id: int,
|
96 |
+
request: gr.Request | None = None,
|
97 |
+
duration: timedelta | None = None,
|
98 |
+
_first_attempt: bool = True,
|
99 |
+
) -> ScheduleResponse:
|
100 |
+
|
101 |
+
if not (gradio_version := version.parse(gr.__version__)).major >= 4: # pragma: no cover
|
102 |
+
raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")
|
103 |
+
|
104 |
+
GRADIO_HTML_TOASTS = gradio_version.minor >= 39
|
105 |
+
|
106 |
+
res, auth = api_client().schedule(
|
107 |
+
cgroup_path=utils.self_cgroup_device_path(),
|
108 |
+
task_id=task_id,
|
109 |
+
token=_get_token(request),
|
110 |
+
duration_seconds=duration.seconds if duration is not None else None,
|
111 |
+
)
|
112 |
+
|
113 |
+
if isinstance(res, ScheduleResponse):
|
114 |
+
return res
|
115 |
+
|
116 |
+
if isinstance(res, QuotaInfos): # pragma: no cover
|
117 |
+
requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
|
118 |
+
if res.wait < timedelta(0):
|
119 |
+
raise gr.Error(
|
120 |
+
f"The requested GPU duration ({requested}s) "
|
121 |
+
f"is larger than the maximum allowed"
|
122 |
+
)
|
123 |
+
else:
|
124 |
+
gpu = "Pro GPU" if auth == 'pro' else ("free GPU" if auth == 'regular' else "GPU")
|
125 |
+
message = (
|
126 |
+
f"You have exceeded your {gpu} quota "
|
127 |
+
f"({requested}s requested vs. {res.left}s left)."
|
128 |
+
)
|
129 |
+
details_html, details_markdown = _toast_action(
|
130 |
+
auth=auth,
|
131 |
+
supports_html=GRADIO_HTML_TOASTS,
|
132 |
+
pro_message=f"Try again in {res.wait}",
|
133 |
+
unlogged_desc="to get more",
|
134 |
+
logged_desc="to get 5x more",
|
135 |
+
ending="usage quota",
|
136 |
+
)
|
137 |
+
message_html = f"{message} {details_html}"
|
138 |
+
message_text = f"{message} {details_markdown}"
|
139 |
+
raise HTMLError(html_string(message_html, message_text))
|
140 |
+
|
141 |
+
if not isinstance(res, httpx.codes): # pragma: no cover
|
142 |
+
gr.Info("Waiting for a GPU to become available")
|
143 |
+
# TODO: Sign-up message if not authenticated (after some time ?)
|
144 |
+
connection_event = get_event()
|
145 |
+
if connection_event is None and request is not None:
|
146 |
+
warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
|
147 |
+
while True:
|
148 |
+
try:
|
149 |
+
event = next(res)
|
150 |
+
except StopIteration:
|
151 |
+
raise RuntimeError("Unexpected end of stream")
|
152 |
+
except httpx.RemoteProtocolError:
|
153 |
+
if not _first_attempt:
|
154 |
+
raise RuntimeError("Error while re-trying after queue disconnect")
|
155 |
+
return schedule(task_id, request, duration, _first_attempt=False)
|
156 |
+
if event.event == 'ping':
|
157 |
+
if connection_event is not None and not connection_event.alive:
|
158 |
+
res.close()
|
159 |
+
raise RuntimeError("Connection closed by visitor while queueing")
|
160 |
+
continue
|
161 |
+
if event.event == 'failed':
|
162 |
+
details_html, details_markdown = _toast_action(
|
163 |
+
auth=auth,
|
164 |
+
supports_html=GRADIO_HTML_TOASTS,
|
165 |
+
pro_message="Retry later",
|
166 |
+
unlogged_desc="to get a higher",
|
167 |
+
logged_desc="to get the highest",
|
168 |
+
ending="priority in ZeroGPU queues",
|
169 |
+
)
|
170 |
+
message_html = f"{NO_GPU_MESSAGE_INQUEUE}. {details_html}"
|
171 |
+
message_text = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}"
|
172 |
+
raise HTMLError(html_string(message_html, message_text))
|
173 |
+
if event.event == 'succeeded':
|
174 |
+
assert event.data is not None
|
175 |
+
if connection_event is not None and not connection_event.alive:
|
176 |
+
release(event.data.allowToken)
|
177 |
+
raise RuntimeError("Connection closed by visitor on queue success")
|
178 |
+
gr.Info("Successfully acquired a GPU")
|
179 |
+
return event.data
|
180 |
+
|
181 |
+
if res is httpx.codes.SERVICE_UNAVAILABLE:
|
182 |
+
raise gr.Error(NO_GPU_MESSAGE_REGULAR)
|
183 |
+
|
184 |
+
# TODO: Find a way to log 'detail' response field
|
185 |
+
raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
|
186 |
+
|
187 |
+
|
188 |
+
def allow(allow_token: str) -> None:
|
189 |
+
pid = os.getpid()
|
190 |
+
assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
|
191 |
+
assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK
|
192 |
+
|
193 |
+
|
194 |
+
def release(
|
195 |
+
allow_token: str, *,
|
196 |
+
fail: bool = False,
|
197 |
+
allow_404: bool = False,
|
198 |
+
) -> None:
|
199 |
+
|
200 |
+
res = api_client().release(
|
201 |
+
allow_token=allow_token,
|
202 |
+
fail=fail,
|
203 |
+
)
|
204 |
+
|
205 |
+
if res is httpx.codes.NO_CONTENT: # pragma: no cover
|
206 |
+
try:
|
207 |
+
gr.Warning(UNUSED_MESSAGE)
|
208 |
+
except AttributeError:
|
209 |
+
pass
|
210 |
+
warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
|
211 |
+
return None
|
212 |
+
|
213 |
+
if res is httpx.codes.NOT_FOUND:
|
214 |
+
if not allow_404:
|
215 |
+
warnings.warn("ZeroGPU API /release warning: 404 Not Found")
|
216 |
+
return None
|
217 |
+
|
218 |
+
if httpx.codes.is_success(res):
|
219 |
+
return None
|
220 |
+
|
221 |
+
# TODO: Find a way to log 'detail' response field
|
222 |
+
# TODO: Only raise in dev environment. Simply warn in production ?
|
223 |
+
raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
|
224 |
+
|
225 |
+
|
226 |
+
def _get_token(request: gr.Request | None) -> str | None:
|
227 |
+
|
228 |
+
if request is None:
|
229 |
+
return None
|
230 |
+
|
231 |
+
headers = getattr(request, 'headers', None)
|
232 |
+
if headers is None or not hasattr(headers, '__dict__'):
|
233 |
+
raise gr.Error("Internal Gradio error")
|
234 |
+
|
235 |
+
# Compatibility trick
|
236 |
+
if not hasattr(headers, 'get'):
|
237 |
+
headers = headers.__dict__ # pragma: no cover
|
238 |
+
|
239 |
+
return headers.get(TOKEN_HEADER.lower())
|
spaces/zero/decorator.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import inspect
|
6 |
+
import sys
|
7 |
+
import warnings
|
8 |
+
from datetime import timedelta
|
9 |
+
from functools import partial
|
10 |
+
from typing import Callable
|
11 |
+
from typing import TypeVar
|
12 |
+
from typing import overload
|
13 |
+
from typing_extensions import ParamSpec
|
14 |
+
from typing_extensions import Unpack
|
15 |
+
|
16 |
+
from ..config import Config
|
17 |
+
from .types import DynamicDuration
|
18 |
+
from .types import EmptyKwargs
|
19 |
+
|
20 |
+
|
21 |
+
P = ParamSpec('P')
|
22 |
+
R = TypeVar('R')
|
23 |
+
|
24 |
+
|
25 |
+
decorated_cache: dict[Callable, Callable] = {}
|
26 |
+
|
27 |
+
|
28 |
+
@overload
|
29 |
+
def GPU(
|
30 |
+
task: None = None, *,
|
31 |
+
duration: DynamicDuration[P] = None,
|
32 |
+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
33 |
+
...
|
34 |
+
@overload
|
35 |
+
def GPU(
|
36 |
+
task: Callable[P, R], *,
|
37 |
+
duration: DynamicDuration[P] = None,
|
38 |
+
) -> Callable[P, R]:
|
39 |
+
...
|
40 |
+
def GPU(
|
41 |
+
task: Callable[P, R] | None = None, *,
|
42 |
+
duration: DynamicDuration[P] = None,
|
43 |
+
**kwargs: Unpack[EmptyKwargs],
|
44 |
+
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
|
45 |
+
"""
|
46 |
+
ZeroGPU decorator
|
47 |
+
|
48 |
+
Basic usage:
|
49 |
+
```
|
50 |
+
@spaces.GPU
|
51 |
+
def fn(...):
|
52 |
+
# CUDA is available here
|
53 |
+
pass
|
54 |
+
```
|
55 |
+
|
56 |
+
With custom duration:
|
57 |
+
```
|
58 |
+
@spaces.GPU(duration=45) # Expressed in seconds
|
59 |
+
def fn(...):
|
60 |
+
# CUDA is available here
|
61 |
+
pass
|
62 |
+
```
|
63 |
+
|
64 |
+
Args:
|
65 |
+
task (`Callable | None`): Python function that requires CUDA
|
66 |
+
duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
`Callable`: GPU-ready function
|
70 |
+
"""
|
71 |
+
if "enable_queue" in kwargs:
|
72 |
+
warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
|
73 |
+
if task is None:
|
74 |
+
return partial(_GPU, duration=duration)
|
75 |
+
return _GPU(task, duration)
|
76 |
+
|
77 |
+
|
78 |
+
def _GPU(
|
79 |
+
task: Callable[P, R],
|
80 |
+
duration: DynamicDuration[P],
|
81 |
+
) -> Callable[P, R]:
|
82 |
+
|
83 |
+
if not Config.zero_gpu:
|
84 |
+
return task
|
85 |
+
|
86 |
+
from . import client
|
87 |
+
from .wrappers import regular_function_wrapper
|
88 |
+
from .wrappers import generator_function_wrapper
|
89 |
+
|
90 |
+
if sys.version_info.minor < 9: # pragma: no cover
|
91 |
+
raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")
|
92 |
+
|
93 |
+
if task in decorated_cache:
|
94 |
+
# TODO: Assert same duration ?
|
95 |
+
return decorated_cache[task] # type: ignore
|
96 |
+
|
97 |
+
if inspect.iscoroutinefunction(task):
|
98 |
+
raise NotImplementedError
|
99 |
+
|
100 |
+
if inspect.isgeneratorfunction(task):
|
101 |
+
decorated = generator_function_wrapper(task, duration)
|
102 |
+
else:
|
103 |
+
decorated = regular_function_wrapper(task, duration)
|
104 |
+
|
105 |
+
setattr(decorated, 'zerogpu', None)
|
106 |
+
|
107 |
+
client.startup_report()
|
108 |
+
decorated_cache.update({
|
109 |
+
task: decorated,
|
110 |
+
decorated: decorated,
|
111 |
+
})
|
112 |
+
|
113 |
+
return decorated # type: ignore
|
spaces/zero/gradio.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
from functools import wraps
|
6 |
+
from packaging import version
|
7 |
+
from typing import Callable
|
8 |
+
from typing import NamedTuple
|
9 |
+
from typing import TYPE_CHECKING
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
from gradio.context import Context
|
14 |
+
from gradio.context import LocalContext
|
15 |
+
from gradio.helpers import Progress
|
16 |
+
from gradio.helpers import TrackedIterable
|
17 |
+
from gradio.queueing import Queue
|
18 |
+
from typing_extensions import ParamSpec
|
19 |
+
|
20 |
+
from ..utils import SimpleQueue
|
21 |
+
from .types import GeneratorResQueueResult
|
22 |
+
from .types import GradioQueueEvent
|
23 |
+
from .types import RegularResQueueResult
|
24 |
+
|
25 |
+
|
26 |
+
QUEUE_RPC_METHODS = [
|
27 |
+
"set_progress",
|
28 |
+
"log_message",
|
29 |
+
]
|
30 |
+
|
31 |
+
|
32 |
+
class GradioPartialContext(NamedTuple):
|
33 |
+
event_id: str | None
|
34 |
+
in_event_listener: bool
|
35 |
+
progress: Progress | None
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def get():
|
39 |
+
TrackedIterable.__reduce__ = tracked_iterable__reduce__
|
40 |
+
return GradioPartialContext(
|
41 |
+
event_id=LocalContext.event_id.get(),
|
42 |
+
in_event_listener=LocalContext.in_event_listener.get(),
|
43 |
+
progress=LocalContext.progress.get(),
|
44 |
+
)
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def apply(context: 'GradioPartialContext'):
|
48 |
+
LocalContext.event_id.set(context.event_id)
|
49 |
+
LocalContext.in_event_listener.set(context.in_event_listener)
|
50 |
+
LocalContext.progress.set(context.progress)
|
51 |
+
|
52 |
+
|
53 |
+
def get_queue_instance():
|
54 |
+
blocks = LocalContext.blocks.get()
|
55 |
+
if blocks is None: # pragma: no cover
|
56 |
+
return None
|
57 |
+
return blocks._queue
|
58 |
+
|
59 |
+
|
60 |
+
def get_event():
|
61 |
+
queue = get_queue_instance()
|
62 |
+
event_id = LocalContext.event_id.get()
|
63 |
+
if queue is None:
|
64 |
+
return None
|
65 |
+
if event_id is None: # pragma: no cover
|
66 |
+
return None
|
67 |
+
for job in queue.active_jobs:
|
68 |
+
if job is None: # pragma: no cover
|
69 |
+
continue
|
70 |
+
for event in job:
|
71 |
+
if event._id == event_id:
|
72 |
+
return event
|
73 |
+
|
74 |
+
|
75 |
+
def get_server_port() -> int | None:
|
76 |
+
from_request_context = True
|
77 |
+
if (blocks := LocalContext.blocks.get()) is None: # Request
|
78 |
+
from_request_context = False
|
79 |
+
if (blocks := Context.root_block) is None: # Caching
|
80 |
+
return None
|
81 |
+
if (server := getattr(blocks, 'server', None)) is None:
|
82 |
+
if from_request_context:
|
83 |
+
warnings.warn("Gradio: No blocks.server inside a request") # pragma: no cover
|
84 |
+
return -1
|
85 |
+
if TYPE_CHECKING:
|
86 |
+
assert (server := blocks.server)
|
87 |
+
return server.config.port
|
88 |
+
|
89 |
+
|
90 |
+
def try_process_queue_event(method_name: str, *args, **kwargs):
|
91 |
+
queue = get_queue_instance()
|
92 |
+
if queue is None: # pragma: no cover
|
93 |
+
warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
|
94 |
+
return
|
95 |
+
method = getattr(queue, method_name, None)
|
96 |
+
assert callable(method)
|
97 |
+
method(*args, **kwargs)
|
98 |
+
|
99 |
+
|
100 |
+
def patch_gradio_queue(
|
101 |
+
res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
|
102 |
+
):
|
103 |
+
|
104 |
+
def rpc_method(method_name: str):
|
105 |
+
def method(*args, **kwargs):
|
106 |
+
if args and isinstance(args[0], Queue):
|
107 |
+
args = args[1:] # drop `self`
|
108 |
+
res_queue.put(GradioQueueEvent(method_name, args, kwargs))
|
109 |
+
return method
|
110 |
+
|
111 |
+
for method_name in QUEUE_RPC_METHODS:
|
112 |
+
if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
|
113 |
+
warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
|
114 |
+
continue
|
115 |
+
if not callable(method): # pragma: no cover
|
116 |
+
warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
|
117 |
+
continue
|
118 |
+
setattr(Queue, method_name, rpc_method(method_name))
|
119 |
+
|
120 |
+
TrackedIterable.__reduce__ = tracked_iterable__reduce__
|
121 |
+
|
122 |
+
|
123 |
+
def tracked_iterable__reduce__(self):
|
124 |
+
res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
|
125 |
+
cls, base, state, *_ = res
|
126 |
+
return cls, base,{**state, **{
|
127 |
+
'iterable': None,
|
128 |
+
'_tqdm': None,
|
129 |
+
}}
|
130 |
+
|
131 |
+
|
132 |
+
def supports_auth():
|
133 |
+
return version.parse(gr.__version__) >= version.Version('4.27.0')
|
134 |
+
|
135 |
+
|
136 |
+
Param = ParamSpec('Param')
|
137 |
+
|
138 |
+
def one_launch(task: Callable[Param, None], *task_args: Param.args, **task_kwargs: Param.kwargs):
|
139 |
+
_launch = gr.Blocks.launch
|
140 |
+
@wraps(gr.Blocks.launch)
|
141 |
+
def launch(*args, **kwargs):
|
142 |
+
task(*task_args, **task_kwargs)
|
143 |
+
gr.Blocks.launch = _launch
|
144 |
+
return gr.Blocks.launch(*args, **kwargs)
|
145 |
+
gr.Blocks.launch = launch
|
146 |
+
|
147 |
+
|
148 |
+
class HTMLError(gr.Error):
|
149 |
+
def __str__(self): # pragma: no cover
|
150 |
+
return self.message
|
spaces/zero/torch/__init__.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
from ...config import Config
|
5 |
+
|
6 |
+
|
7 |
+
try:
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
except ImportError:
|
12 |
+
|
13 |
+
_patch = lambda *args, **kwargs: None
|
14 |
+
_unpatch = lambda *args, **kwargs: None
|
15 |
+
_pack = lambda *args, **kwargs: None
|
16 |
+
_init = lambda *args, **kwargs: None
|
17 |
+
_size = lambda *args, **kwargs: 0
|
18 |
+
_move = lambda *args, **kwargs: None
|
19 |
+
_is_in_bad_fork = lambda *args, **kwargs: False
|
20 |
+
|
21 |
+
else:
|
22 |
+
|
23 |
+
if Config.zero_gpu_v2:
|
24 |
+
from . import patching as _patching
|
25 |
+
else: # pragma: no cover
|
26 |
+
from . import patching_legacy as _patching
|
27 |
+
|
28 |
+
_patch = _patching.patch
|
29 |
+
_unpatch = _patching.unpatch
|
30 |
+
_pack = _patching.pack
|
31 |
+
_init = _patching.init
|
32 |
+
_size = _patching.size
|
33 |
+
_move = _patching.move
|
34 |
+
_is_in_bad_fork = _patching.is_in_bad_fork
|
35 |
+
|
36 |
+
patch = _patch
|
37 |
+
unpatch = _unpatch
|
38 |
+
pack = _pack
|
39 |
+
init = _init
|
40 |
+
size = _size
|
41 |
+
move = _move
|
42 |
+
is_in_bad_fork = _is_in_bad_fork
|
spaces/zero/torch/bitsandbytes.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
# pyright: reportPrivateImportUsage=false
|
4 |
+
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import importlib
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from importlib import metadata
|
10 |
+
from types import ModuleType
|
11 |
+
from typing import TYPE_CHECKING
|
12 |
+
from typing import Tuple
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from packaging import version
|
16 |
+
|
17 |
+
if TYPE_CHECKING:
|
18 |
+
import torch as Torch
|
19 |
+
|
20 |
+
|
21 |
+
@contextmanager
|
22 |
+
def cuda_unavailable(torch: ModuleType):
|
23 |
+
_is_available = torch.cuda.is_available
|
24 |
+
torch.cuda.is_available = lambda: False
|
25 |
+
yield
|
26 |
+
torch.cuda.is_available = _is_available
|
27 |
+
|
28 |
+
|
29 |
+
def maybe_import_bitsandbytes():
|
30 |
+
try:
|
31 |
+
import torch
|
32 |
+
except ImportError: # pragma: no cover
|
33 |
+
return None
|
34 |
+
with cuda_unavailable(torch):
|
35 |
+
try:
|
36 |
+
import bitsandbytes
|
37 |
+
except ImportError:
|
38 |
+
bitsandbytes = None
|
39 |
+
else:
|
40 |
+
if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
|
41 |
+
raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
|
42 |
+
print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
|
43 |
+
return bitsandbytes
|
44 |
+
|
45 |
+
|
46 |
+
if (bnb := maybe_import_bitsandbytes()):
|
47 |
+
|
48 |
+
from torch.utils.weak import WeakTensorKeyDictionary
|
49 |
+
|
50 |
+
with cuda_unavailable(torch):
|
51 |
+
from bitsandbytes import cextension
|
52 |
+
from bitsandbytes import functional
|
53 |
+
try: # bitsandbytes < 0.44
|
54 |
+
from bitsandbytes.cuda_setup.main import CUDASetup
|
55 |
+
except ModuleNotFoundError: # pragma: no cover
|
56 |
+
CUDASetup = None
|
57 |
+
from bitsandbytes.nn import Int8Params
|
58 |
+
from bitsandbytes.nn import Params4bit
|
59 |
+
|
60 |
+
_param_to_8bit = Int8Params.to # type: ignore
|
61 |
+
_param_cuda_8bit = Int8Params.cuda
|
62 |
+
_param_to_4bit = Params4bit.to # type: ignore
|
63 |
+
_param_cuda_4bit = Params4bit.cuda
|
64 |
+
|
65 |
+
TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
|
66 |
+
|
67 |
+
to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
|
68 |
+
to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
|
69 |
+
|
70 |
+
def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
|
71 |
+
parsed = torch._C._nn._parse_to(*args, **kwargs)
|
72 |
+
device, *_ = parsed
|
73 |
+
if not isinstance(device, torch.device): # pragma: no cover
|
74 |
+
return _param_to_8bit(self, *args, **kwargs)
|
75 |
+
if device.type != 'cuda':
|
76 |
+
return _param_to_8bit(self, *args, **kwargs)
|
77 |
+
to_ops_8bit[self] = parsed
|
78 |
+
return self
|
79 |
+
|
80 |
+
def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
|
81 |
+
parsed = torch._C._nn._parse_to(*args, **kwargs)
|
82 |
+
device, *_ = parsed
|
83 |
+
if not isinstance(device, torch.device): # pragma: no cover
|
84 |
+
return _param_to_4bit(self, *args, **kwargs)
|
85 |
+
if device.type != 'cuda':
|
86 |
+
return _param_to_4bit(self, *args, **kwargs)
|
87 |
+
to_ops_4bit[self] = parsed
|
88 |
+
return self
|
89 |
+
|
90 |
+
def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
|
91 |
+
if device is None: # pragma: no cover
|
92 |
+
return True
|
93 |
+
if isinstance(device, int):
|
94 |
+
return True
|
95 |
+
if isinstance(device, str): # pragma: no cover
|
96 |
+
device = torch.device(device)
|
97 |
+
return device.type == 'cuda' # pragma: no cover
|
98 |
+
|
99 |
+
def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
|
100 |
+
if not _cuda_op_arg_check(device): # pragma: no cover
|
101 |
+
# Let PyTorch handle the fail
|
102 |
+
return _param_cuda_8bit(self, device, **kwargs)
|
103 |
+
to_ops_8bit[self] = None
|
104 |
+
return self
|
105 |
+
|
106 |
+
def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
|
107 |
+
if not _cuda_op_arg_check(device): # pragma: no cover
|
108 |
+
# Let PyTorch handle the fail
|
109 |
+
return _param_cuda_4bit(self, device, **kwargs)
|
110 |
+
to_ops_4bit[self] = None
|
111 |
+
return self
|
112 |
+
|
113 |
+
def _patch():
|
114 |
+
Int8Params.to = _to_op_register_8bit # type: ignore
|
115 |
+
Int8Params.cuda = _cuda_op_register_8bit # type: ignore
|
116 |
+
Params4bit.to = _to_op_register_4bit # type: ignore
|
117 |
+
Params4bit.cuda = _cuda_op_register_4bit # type: ignore
|
118 |
+
|
119 |
+
def _unpatch():
|
120 |
+
Int8Params.to = _param_to_8bit # type: ignore
|
121 |
+
Int8Params.cuda = _param_cuda_8bit
|
122 |
+
Params4bit.to = _param_to_4bit # type: ignore
|
123 |
+
Params4bit.cuda = _param_cuda_4bit
|
124 |
+
|
125 |
+
def _move():
|
126 |
+
if CUDASetup is not None:
|
127 |
+
CUDASetup._instance = None
|
128 |
+
importlib.reload(cextension)
|
129 |
+
functional.lib = cextension.lib
|
130 |
+
for op in to_ops_8bit.items():
|
131 |
+
tensor, parsed_args = op
|
132 |
+
if parsed_args:
|
133 |
+
_, dtype, _, memory_format = parsed_args
|
134 |
+
else:
|
135 |
+
dtype, memory_format = None, None
|
136 |
+
tensor.data = _param_to_8bit(tensor,
|
137 |
+
device='cuda',
|
138 |
+
dtype=dtype,
|
139 |
+
memory_format=memory_format,
|
140 |
+
) # type: ignore
|
141 |
+
for op in to_ops_4bit.items():
|
142 |
+
tensor, parsed_args = op
|
143 |
+
if parsed_args:
|
144 |
+
_, dtype, _, memory_format = parsed_args
|
145 |
+
else:
|
146 |
+
dtype, memory_format = None, None
|
147 |
+
tensor.data = _param_to_4bit(tensor,
|
148 |
+
device='cuda',
|
149 |
+
dtype=dtype,
|
150 |
+
memory_format=memory_format,
|
151 |
+
) # type: ignore
|
152 |
+
|
153 |
+
else:
|
154 |
+
|
155 |
+
_patch = lambda: None
|
156 |
+
_unpatch = lambda: None
|
157 |
+
_move = lambda: None
|
158 |
+
|
159 |
+
|
160 |
+
patch = _patch
|
161 |
+
unpatch = _unpatch
|
162 |
+
move = _move
|
spaces/zero/torch/packing.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import time
|
6 |
+
|
7 |
+
import ctypes
|
8 |
+
import os
|
9 |
+
from concurrent.futures import as_completed
|
10 |
+
from concurrent.futures import ThreadPoolExecutor
|
11 |
+
from contextvars import copy_context
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from queue import Queue
|
14 |
+
from typing import Callable
|
15 |
+
|
16 |
+
from ...utils import debug
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from typing_extensions import TypeAlias
|
20 |
+
|
21 |
+
|
22 |
+
PAGE_SIZE = 4096
|
23 |
+
TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
|
24 |
+
VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2)
|
25 |
+
|
26 |
+
BUFFER_SIZE = 64 * 2**20
|
27 |
+
BUFFER_COUNT = 2
|
28 |
+
|
29 |
+
|
30 |
+
TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]'
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class ZeroGPUTensorPack:
|
34 |
+
base_dir: str
|
35 |
+
batches: list[list[TensorWithSizes]]
|
36 |
+
big_tensors: list[TensorWithSizes]
|
37 |
+
fakes: dict[torch.Tensor, list[torch.Tensor]]
|
38 |
+
total_size: int
|
39 |
+
def path(self):
|
40 |
+
return f'{self.base_dir}/{id(self)}'
|
41 |
+
def __del__(self):
|
42 |
+
try:
|
43 |
+
os.remove(self.path())
|
44 |
+
except FileNotFoundError: # pragma: no cover
|
45 |
+
pass
|
46 |
+
|
47 |
+
|
48 |
+
def write(fd: int, tensor: torch.Tensor):
|
49 |
+
clone = torch.empty_like(tensor)
|
50 |
+
size = clone.untyped_storage().size() # pyright: ignore [reportAttributeAccessIssue]
|
51 |
+
buffer = torch.UntypedStorage(VM_MAX_SIZE)
|
52 |
+
buffer_ptr = buffer.data_ptr()
|
53 |
+
offset = -buffer_ptr % PAGE_SIZE
|
54 |
+
padding = -size % PAGE_SIZE
|
55 |
+
clone.set_(buffer[offset:offset+size], 0, clone.shape, clone.stride()) # pyright: ignore [reportArgumentType]
|
56 |
+
clone.copy_(tensor)
|
57 |
+
mv = memoryview((ctypes.c_char * (size+padding)).from_address(buffer_ptr+offset))
|
58 |
+
written_bytes = 0
|
59 |
+
while written_bytes < size:
|
60 |
+
written_bytes += os.write(fd, mv[written_bytes:])
|
61 |
+
|
62 |
+
|
63 |
+
def pack_tensors(
|
64 |
+
tensors: set[torch.Tensor],
|
65 |
+
fakes: dict[torch.Tensor, list[torch.Tensor]],
|
66 |
+
offload_dir: str,
|
67 |
+
callback: Callable[[int]] | None = None,
|
68 |
+
):
|
69 |
+
|
70 |
+
callback = (lambda bytes: None) if callback is None else callback
|
71 |
+
|
72 |
+
batches: list[list[TensorWithSizes]] = []
|
73 |
+
big_tensors: list[TensorWithSizes] = []
|
74 |
+
|
75 |
+
tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = []
|
76 |
+
for tensor in tensors:
|
77 |
+
size = tensor.numel() * tensor.element_size()
|
78 |
+
aligned_size = size + (-size % PAGE_SIZE)
|
79 |
+
tensors_with_sizes += [(tensor, size, aligned_size)]
|
80 |
+
|
81 |
+
current_batch, current_size = [], 0
|
82 |
+
for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]):
|
83 |
+
if aligned_size > BUFFER_SIZE:
|
84 |
+
big_tensors += [(tensor, size, aligned_size)]
|
85 |
+
continue
|
86 |
+
current_size += aligned_size
|
87 |
+
if current_size > BUFFER_SIZE:
|
88 |
+
batches += [current_batch]
|
89 |
+
current_batch, current_size = [(tensor, size, aligned_size)], aligned_size
|
90 |
+
else:
|
91 |
+
current_batch += [(tensor, size, aligned_size)]
|
92 |
+
|
93 |
+
if current_batch:
|
94 |
+
batches += [current_batch]
|
95 |
+
|
96 |
+
get_meta = {tensor: torch.empty_like(tensor) for tensor in tensors}
|
97 |
+
batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches]
|
98 |
+
big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors]
|
99 |
+
fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()}
|
100 |
+
|
101 |
+
pack = ZeroGPUTensorPack(
|
102 |
+
base_dir=offload_dir,
|
103 |
+
batches=batches_meta,
|
104 |
+
big_tensors=big_tensors_meta,
|
105 |
+
fakes=fakes_meta,
|
106 |
+
total_size=sum([size for _, size, _ in tensors_with_sizes]),
|
107 |
+
)
|
108 |
+
|
109 |
+
fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT)
|
110 |
+
try:
|
111 |
+
total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch])
|
112 |
+
total_asize += sum([aligned_size for *_, aligned_size in big_tensors])
|
113 |
+
if total_asize > 0:
|
114 |
+
os.posix_fallocate(fd, 0, total_asize)
|
115 |
+
for batch in batches:
|
116 |
+
for tensor, size, _ in batch:
|
117 |
+
write(fd, tensor)
|
118 |
+
callback(size)
|
119 |
+
for tensor, size, _ in big_tensors:
|
120 |
+
write(fd, tensor)
|
121 |
+
callback(size)
|
122 |
+
return pack
|
123 |
+
finally:
|
124 |
+
os.close(fd)
|
125 |
+
|
126 |
+
|
127 |
+
def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int]] | None = None):
|
128 |
+
|
129 |
+
callback = (lambda bytes: None) if callback is None else callback
|
130 |
+
|
131 |
+
free_buffers: Queue[torch.Tensor] = Queue()
|
132 |
+
read_buffers: Queue[torch.Tensor] = Queue()
|
133 |
+
|
134 |
+
for _ in range(BUFFER_COUNT):
|
135 |
+
free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory())
|
136 |
+
|
137 |
+
def read(fd: int, buffer: torch.Tensor, size: int):
|
138 |
+
mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr()))
|
139 |
+
read_bytes = 0
|
140 |
+
while read_bytes < size:
|
141 |
+
read_bytes += os.readv(fd, [mv[read_bytes:]])
|
142 |
+
|
143 |
+
def disk_to_pin(fd: int):
|
144 |
+
for batch in pack.batches:
|
145 |
+
buffer = free_buffers.get()
|
146 |
+
batch_size = sum([aligned_size for *_, aligned_size in batch])
|
147 |
+
read(fd, buffer, batch_size)
|
148 |
+
read_buffers.put(buffer)
|
149 |
+
for *_, aligned_size in pack.big_tensors:
|
150 |
+
read_bytes = 0
|
151 |
+
while read_bytes < aligned_size:
|
152 |
+
buffer = free_buffers.get()
|
153 |
+
read_size = min(BUFFER_SIZE, aligned_size - read_bytes)
|
154 |
+
read(fd, buffer, read_size)
|
155 |
+
read_buffers.put(buffer)
|
156 |
+
read_bytes += read_size
|
157 |
+
|
158 |
+
def pin_to_cuda():
|
159 |
+
total_duration_in_callback = 0
|
160 |
+
for batch in pack.batches:
|
161 |
+
buffer = read_buffers.get()
|
162 |
+
offset = 0
|
163 |
+
cuda_storages = []
|
164 |
+
for tensor, size, aligned_size in batch:
|
165 |
+
cuda_storages += [buffer[offset:offset+size].cuda(non_blocking=True)]
|
166 |
+
offset += aligned_size
|
167 |
+
torch.cuda.synchronize()
|
168 |
+
free_buffers.put(buffer)
|
169 |
+
batch_total_size = 0
|
170 |
+
for (tensor, size, _), cuda_storage in zip(batch, cuda_storages):
|
171 |
+
cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
|
172 |
+
cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
|
173 |
+
for fake in pack.fakes[tensor]:
|
174 |
+
fake.data = cuda_tensor
|
175 |
+
batch_total_size += size
|
176 |
+
t0 = time.perf_counter()
|
177 |
+
callback(batch_total_size)
|
178 |
+
total_duration_in_callback += time.perf_counter() - t0
|
179 |
+
for tensor, size, _ in pack.big_tensors:
|
180 |
+
cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda')
|
181 |
+
offset = 0
|
182 |
+
while offset < size:
|
183 |
+
buffer = read_buffers.get()
|
184 |
+
read_size = min(BUFFER_SIZE, size - offset)
|
185 |
+
cuda_storage[offset:offset+read_size] = buffer[:read_size]
|
186 |
+
offset += read_size
|
187 |
+
torch.cuda.synchronize() # Probably not needed
|
188 |
+
free_buffers.put(buffer)
|
189 |
+
t0 = time.perf_counter()
|
190 |
+
callback(read_size)
|
191 |
+
total_duration_in_callback += time.perf_counter() - t0
|
192 |
+
cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
|
193 |
+
cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
|
194 |
+
for fake in pack.fakes[tensor]:
|
195 |
+
fake.data = cuda_tensor
|
196 |
+
|
197 |
+
debug(f"{total_duration_in_callback=}")
|
198 |
+
|
199 |
+
with ThreadPoolExecutor(2) as e:
|
200 |
+
fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT)
|
201 |
+
try:
|
202 |
+
futures = [
|
203 |
+
e.submit(copy_context().run, disk_to_pin, fd),
|
204 |
+
e.submit(copy_context().run, pin_to_cuda),
|
205 |
+
]
|
206 |
+
for future in as_completed(futures):
|
207 |
+
future.result()
|
208 |
+
finally:
|
209 |
+
os.close(fd)
|
spaces/zero/torch/patching.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
# pyright: reportPrivateImportUsage=false
|
4 |
+
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import gc
|
8 |
+
import multiprocessing
|
9 |
+
import os
|
10 |
+
from collections import defaultdict
|
11 |
+
from concurrent.futures import ProcessPoolExecutor
|
12 |
+
from concurrent.futures import ThreadPoolExecutor
|
13 |
+
from contextlib import nullcontext
|
14 |
+
from contextvars import copy_context
|
15 |
+
from types import SimpleNamespace
|
16 |
+
from typing import Any
|
17 |
+
from typing import Callable
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch.overrides import TorchFunctionMode
|
21 |
+
from torch.overrides import resolve_name
|
22 |
+
from torch.utils._python_dispatch import TorchDispatchMode
|
23 |
+
from torch.utils._pytree import tree_map_only
|
24 |
+
from torch.utils.weak import WeakTensorKeyDictionary
|
25 |
+
|
26 |
+
from ...config import Config
|
27 |
+
from ...utils import malloc_trim
|
28 |
+
from ..tqdm import tqdm
|
29 |
+
from . import bitsandbytes
|
30 |
+
from .packing import ZeroGPUTensorPack
|
31 |
+
from .packing import pack_tensors
|
32 |
+
from .packing import pack_to_cuda
|
33 |
+
from .types import AliasId
|
34 |
+
|
35 |
+
|
36 |
+
# Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
|
37 |
+
CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
|
38 |
+
CUDA_TOTAL_MEMORY = 42144366592
|
39 |
+
CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
|
40 |
+
CUDA_DEVICE_CAPABILITY = (8, 0)
|
41 |
+
CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
|
42 |
+
|
43 |
+
OPS_INPUTS_CHECK_NO_RETURN = (
|
44 |
+
torch.Tensor.equal,
|
45 |
+
)
|
46 |
+
|
47 |
+
OPS_INPUT_CHECK_SELF_RETURN = (
|
48 |
+
torch.Tensor.set_, # probably never dispatched
|
49 |
+
torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
|
50 |
+
)
|
51 |
+
|
52 |
+
OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
|
53 |
+
|
54 |
+
_tensor_make_subclass = torch.Tensor._make_subclass
|
55 |
+
_asarray = torch.asarray
|
56 |
+
_cuda_init = torch._C._cuda_init
|
57 |
+
_cuda_exchange_device = torch.cuda._exchange_device
|
58 |
+
_cuda_available = torch.cuda.is_available
|
59 |
+
_cuda_device_count = torch.cuda.device_count
|
60 |
+
_cuda_current_device = torch.cuda.current_device
|
61 |
+
_cuda_mem_get_info = torch.cuda.mem_get_info
|
62 |
+
_cuda_get_device_capability = torch.cuda.get_device_capability
|
63 |
+
_cuda_get_device_properties = torch.cuda.get_device_properties
|
64 |
+
_cuda_get_device_name = torch.cuda.get_device_name
|
65 |
+
|
66 |
+
# PyTorch 2.3
|
67 |
+
_cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
|
68 |
+
|
69 |
+
|
70 |
+
cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
|
71 |
+
|
72 |
+
tensor_packs: list[ZeroGPUTensorPack] = []
|
73 |
+
|
74 |
+
class ZeroGPUTensor(torch.Tensor):
|
75 |
+
pass
|
76 |
+
|
77 |
+
def empty_fake(tensor: torch.Tensor):
|
78 |
+
fake = torch.empty_like(tensor, requires_grad=tensor.requires_grad)
|
79 |
+
if fake.__class__ != tensor.__class__:
|
80 |
+
fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
|
81 |
+
return fake
|
82 |
+
|
83 |
+
class ZeroGPUFunctionMode(TorchFunctionMode):
|
84 |
+
|
85 |
+
def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
|
86 |
+
|
87 |
+
kwargs = {} if kwargs is None else kwargs
|
88 |
+
|
89 |
+
if func == torch._C._nn._parse_to:
|
90 |
+
return func(*args, **kwargs)
|
91 |
+
|
92 |
+
# Redispatch: tensor.cuda() -> tensor.to(device='cuda')
|
93 |
+
if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
|
94 |
+
memory_format = kwargs.get('memory_format')
|
95 |
+
return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
|
96 |
+
'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
|
97 |
+
**({'memory_format': memory_format} if memory_format is not None else {}),
|
98 |
+
})
|
99 |
+
|
100 |
+
# Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
|
101 |
+
if func == torch.Tensor.to and len(args) > 1:
|
102 |
+
device, dtype, _, memory_format = torch._C._nn._parse_to(*args[1:], **kwargs)
|
103 |
+
return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
|
104 |
+
'device': device,
|
105 |
+
'dtype': dtype,
|
106 |
+
'memory_format': memory_format,
|
107 |
+
})
|
108 |
+
|
109 |
+
if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
|
110 |
+
self, target = args
|
111 |
+
if target in cuda_aliases:
|
112 |
+
if (target_original := cuda_aliases[target]) is None:
|
113 |
+
raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
|
114 |
+
original = empty_fake(self)
|
115 |
+
original.data = target_original
|
116 |
+
cuda_aliases[self] = original
|
117 |
+
elif self in cuda_aliases:
|
118 |
+
del cuda_aliases[self]
|
119 |
+
self.data = target
|
120 |
+
return
|
121 |
+
|
122 |
+
if func == torch.Tensor.device.__get__:
|
123 |
+
tensor, = args
|
124 |
+
if tensor in cuda_aliases:
|
125 |
+
return torch.device('cuda', index=0)
|
126 |
+
|
127 |
+
elif func == torch.Tensor.__repr__:
|
128 |
+
tensor, = args
|
129 |
+
if tensor in cuda_aliases:
|
130 |
+
if (original := cuda_aliases[tensor]) is None:
|
131 |
+
original = tensor.to('meta')
|
132 |
+
original_class = original.__class__
|
133 |
+
original.__class__ = ZeroGPUTensor
|
134 |
+
try:
|
135 |
+
return func(original, **kwargs)
|
136 |
+
finally:
|
137 |
+
original.__class__ = original_class
|
138 |
+
|
139 |
+
elif func == torch.Tensor.untyped_storage:
|
140 |
+
tensor, = args
|
141 |
+
if tensor in cuda_aliases:
|
142 |
+
if (original := cuda_aliases[tensor]) is None:
|
143 |
+
raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
|
144 |
+
res = func(original, **kwargs)
|
145 |
+
res._zerogpu = True
|
146 |
+
return res
|
147 |
+
|
148 |
+
cuda: bool | None = None
|
149 |
+
|
150 |
+
# Handle device kwarg
|
151 |
+
if (device := kwargs.get('device')) is not None:
|
152 |
+
device = torch.device(device)
|
153 |
+
if device.type == 'cuda':
|
154 |
+
kwargs['device'] = torch.device('cpu')
|
155 |
+
cuda = True
|
156 |
+
else:
|
157 |
+
cuda = False
|
158 |
+
|
159 |
+
# Swap fake inputs with original data
|
160 |
+
swapped = {}
|
161 |
+
inputs_are_cuda = set()
|
162 |
+
def swap(tensor: torch.Tensor):
|
163 |
+
nonlocal inputs_are_cuda
|
164 |
+
if tensor not in cuda_aliases:
|
165 |
+
inputs_are_cuda |= {False}
|
166 |
+
return tensor
|
167 |
+
if (original := cuda_aliases[tensor]) is None:
|
168 |
+
raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
|
169 |
+
swapped[original] = tensor
|
170 |
+
inputs_are_cuda |= {True}
|
171 |
+
return original
|
172 |
+
args_ = tree_map_only(torch.Tensor, swap, args)
|
173 |
+
kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
|
174 |
+
if inputs_are_cuda == {True}:
|
175 |
+
if cuda is not False:
|
176 |
+
cuda = True
|
177 |
+
|
178 |
+
res = func(*args_, **kwargs_)
|
179 |
+
|
180 |
+
# Re-generate swapped fakes in case of mutation
|
181 |
+
for original, fake in swapped.items():
|
182 |
+
fake.data = empty_fake(original)
|
183 |
+
|
184 |
+
# Special case for Tensor indexing where only 'self' matters
|
185 |
+
if func in {
|
186 |
+
torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
|
187 |
+
torch.Tensor.__getitem__, # PyTorch 2.4+
|
188 |
+
}:
|
189 |
+
self = args[0]
|
190 |
+
cuda = self in cuda_aliases
|
191 |
+
inputs_are_cuda = {cuda}
|
192 |
+
|
193 |
+
# Emulate device check
|
194 |
+
if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
|
195 |
+
self = None
|
196 |
+
if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
|
197 |
+
self = args_[0]
|
198 |
+
# Only raise if func does not return its first input (Tensor.copy_)
|
199 |
+
if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
|
200 |
+
if inputs_are_cuda == {True, False}:
|
201 |
+
raise RuntimeError(
|
202 |
+
"Expected all tensors to be on the same device, "
|
203 |
+
"but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
|
204 |
+
)
|
205 |
+
|
206 |
+
# Register output
|
207 |
+
def register(tensor: torch.Tensor):
|
208 |
+
if tensor in swapped and cuda is not False:
|
209 |
+
return swapped[tensor]
|
210 |
+
if cuda is not True:
|
211 |
+
return tensor
|
212 |
+
fake = empty_fake(tensor)
|
213 |
+
cuda_aliases[fake] = tensor
|
214 |
+
return fake
|
215 |
+
|
216 |
+
return tree_map_only(torch.Tensor, register, res)
|
217 |
+
|
218 |
+
# When enabling DispatchMode, some aten ops are dispatched to FunctionMode
|
219 |
+
# We are using it for aten.alias.default and aten.set_.source_Tensor
|
220 |
+
class DefaultDispatchMode(TorchDispatchMode):
|
221 |
+
def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
|
222 |
+
return func(*args, **(kwargs or {}))
|
223 |
+
|
224 |
+
|
225 |
+
function_mode = ZeroGPUFunctionMode()
|
226 |
+
dispatch_mode = DefaultDispatchMode()
|
227 |
+
|
228 |
+
|
229 |
+
def _untyped_storage_new_register(*args, **kwargs):
|
230 |
+
cuda = False
|
231 |
+
if (device := kwargs.get('device')) is not None and device.type == 'cuda':
|
232 |
+
cuda = True
|
233 |
+
del kwargs['device']
|
234 |
+
storage = torch._C.StorageBase.__new__(*args, **kwargs)
|
235 |
+
if cuda:
|
236 |
+
storage._zerogpu = True
|
237 |
+
return storage
|
238 |
+
|
239 |
+
@property
|
240 |
+
def _untyped_storage_device(self):
|
241 |
+
if hasattr(self, '_zerogpu'):
|
242 |
+
return torch.device('cuda', index=0)
|
243 |
+
return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue]
|
244 |
+
|
245 |
+
# Force dispatch
|
246 |
+
def _tensor_make_subclass_function_mode(*args, **kwargs):
|
247 |
+
with torch._C.DisableTorchFunction():
|
248 |
+
return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs)
|
249 |
+
def _asarray_function_mode(*args, **kwargs):
|
250 |
+
with torch._C.DisableTorchFunction():
|
251 |
+
return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs)
|
252 |
+
|
253 |
+
def _cuda_init_raise():
|
254 |
+
raise RuntimeError(
|
255 |
+
"CUDA must not be initialized in the main process "
|
256 |
+
"on Spaces with Stateless GPU environment.\n"
|
257 |
+
"You can look at this Stacktrace to find out "
|
258 |
+
"which part of your code triggered a CUDA init"
|
259 |
+
)
|
260 |
+
|
261 |
+
def _cuda_dummy_exchange_device(device):
|
262 |
+
assert device in {-1, 0}
|
263 |
+
return device
|
264 |
+
|
265 |
+
def patch():
|
266 |
+
function_mode.__enter__()
|
267 |
+
dispatch_mode.__enter__()
|
268 |
+
# TODO: only patch bellow methods on current Thread to be consistent with TorchModes
|
269 |
+
# (or hijack threading.Thread.__init__ to force Modes on all threads)
|
270 |
+
torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
|
271 |
+
torch.UntypedStorage.__new__ = _untyped_storage_new_register
|
272 |
+
torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue]
|
273 |
+
torch.asarray = _asarray_function_mode
|
274 |
+
torch._C._cuda_init = _cuda_init_raise
|
275 |
+
torch.cuda._exchange_device = _cuda_dummy_exchange_device
|
276 |
+
torch.cuda.is_available = lambda: True
|
277 |
+
torch.cuda.device_count = lambda: 1
|
278 |
+
torch.cuda.current_device = lambda: 0
|
279 |
+
torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
|
280 |
+
torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
|
281 |
+
torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
|
282 |
+
torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
|
283 |
+
# PyTorch 2.3
|
284 |
+
if _cuda_maybe_exchange_device is not None: # pragma: no cover
|
285 |
+
setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
|
286 |
+
bitsandbytes.patch()
|
287 |
+
|
288 |
+
def unpatch():
|
289 |
+
try:
|
290 |
+
dispatch_mode.__exit__(None, None, None)
|
291 |
+
function_mode.__exit__(None, None, None)
|
292 |
+
except RuntimeError:
|
293 |
+
pass # patch() and unpatch() called from != threads
|
294 |
+
torch.Tensor._make_subclass = _tensor_make_subclass
|
295 |
+
torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
|
296 |
+
torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
|
297 |
+
torch.asarray = _asarray
|
298 |
+
torch._C._cuda_init = _cuda_init
|
299 |
+
torch.cuda._exchange_device = _cuda_exchange_device
|
300 |
+
torch.cuda.is_available = _cuda_available
|
301 |
+
torch.cuda.device_count = _cuda_device_count
|
302 |
+
torch.cuda.current_device = _cuda_current_device
|
303 |
+
torch.cuda.mem_get_info = _cuda_mem_get_info
|
304 |
+
torch.cuda.get_device_capability = _cuda_get_device_capability
|
305 |
+
torch.cuda.get_device_properties = _cuda_get_device_properties
|
306 |
+
torch.cuda.get_device_name = _cuda_get_device_name
|
307 |
+
# PyTorch 2.3
|
308 |
+
if _cuda_maybe_exchange_device is not None: # pragma: no cover
|
309 |
+
setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)
|
310 |
+
bitsandbytes.unpatch()
|
311 |
+
|
312 |
+
|
313 |
+
def _total_unpacked_size():
|
314 |
+
tensors = [tensor for tensor in cuda_aliases.values() if tensor is not None]
|
315 |
+
deduped = {AliasId.from_tensor(tensor): tensor for tensor in tensors}
|
316 |
+
return sum([tensor.numel() * tensor.element_size() for tensor in deduped.values()])
|
317 |
+
|
318 |
+
|
319 |
+
def _pack(offload_dir: str):
|
320 |
+
# Pack to disk
|
321 |
+
originals: set[torch.Tensor] = set()
|
322 |
+
originals_dedup: dict[AliasId, torch.Tensor] = {}
|
323 |
+
fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
|
324 |
+
for fake, original in cuda_aliases.items():
|
325 |
+
# TODO filter-out sparse Tensors
|
326 |
+
if original is not None:
|
327 |
+
original_id = AliasId.from_tensor(original)
|
328 |
+
if original_id not in originals_dedup:
|
329 |
+
originals_dedup[original_id] = original
|
330 |
+
originals |= {original}
|
331 |
+
fakes[originals_dedup[original_id]] += [fake]
|
332 |
+
progress = tqdm(
|
333 |
+
total=_total_unpacked_size(),
|
334 |
+
unit='B',
|
335 |
+
unit_scale=True,
|
336 |
+
desc="ZeroGPU tensors packing",
|
337 |
+
) if tqdm is not None else nullcontext()
|
338 |
+
with progress as progress:
|
339 |
+
update = progress.update if progress is not None else lambda _: None
|
340 |
+
pack = pack_tensors(originals, fakes, offload_dir, callback=update)
|
341 |
+
tensor_packs.append(pack)
|
342 |
+
# Free memory
|
343 |
+
for fake_list in fakes.values():
|
344 |
+
for fake in fake_list:
|
345 |
+
cuda_aliases[fake] = None
|
346 |
+
|
347 |
+
def pack():
|
348 |
+
_pack(Config.zerogpu_offload_dir)
|
349 |
+
gc.collect()
|
350 |
+
malloc_trim()
|
351 |
+
|
352 |
+
def init(nvidia_uuid: str):
|
353 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
|
354 |
+
torch.Tensor([0]).cuda()
|
355 |
+
|
356 |
+
def size():
|
357 |
+
return _total_unpacked_size() + sum([pack.total_size for pack in tensor_packs])
|
358 |
+
|
359 |
+
def _move(callback: Callable[[int]] | None = None):
|
360 |
+
callback = callback if callback is not None else lambda _: None
|
361 |
+
# CPU -> CUDA
|
362 |
+
moved: dict[AliasId, torch.Tensor] = {}
|
363 |
+
for fake, original in cuda_aliases.items():
|
364 |
+
if original is not None:
|
365 |
+
original_id = AliasId.from_tensor(original)
|
366 |
+
if original_id not in moved:
|
367 |
+
moved[original_id] = original.cuda()
|
368 |
+
callback(fake.numel() * fake.element_size())
|
369 |
+
for fake, original in cuda_aliases.items():
|
370 |
+
if original is not None:
|
371 |
+
fake.data = moved[AliasId.from_tensor(original)]
|
372 |
+
# Disk -> CUDA
|
373 |
+
for tensor_pack in tensor_packs:
|
374 |
+
pack_to_cuda(tensor_pack, callback=callback)
|
375 |
+
bitsandbytes.move()
|
376 |
+
|
377 |
+
def move(callback: Callable[[int]] | None = None):
|
378 |
+
callback = callback if callback is not None else lambda _: None
|
379 |
+
with ThreadPoolExecutor(1) as e:
|
380 |
+
e.submit(copy_context().run, _move, callback=callback).result()
|
381 |
+
torch.cuda.synchronize()
|
382 |
+
|
383 |
+
def is_in_bad_fork():
|
384 |
+
with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
|
385 |
+
f = e.submit(torch.cuda._is_in_bad_fork)
|
386 |
+
return f.result()
|
spaces/zero/torch/patching_legacy.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
# pyright: reportPrivateImportUsage=false
|
4 |
+
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
import multiprocessing
|
8 |
+
import os
|
9 |
+
from concurrent.futures import ProcessPoolExecutor
|
10 |
+
from contextlib import suppress
|
11 |
+
from functools import partial
|
12 |
+
from types import SimpleNamespace
|
13 |
+
from typing import Any
|
14 |
+
from typing import Callable
|
15 |
+
from typing import Optional
|
16 |
+
from typing import Tuple
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch.utils.weak import WeakTensorKeyDictionary
|
20 |
+
|
21 |
+
from ...config import Config
|
22 |
+
from . import bitsandbytes
|
23 |
+
|
24 |
+
|
25 |
+
# Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
|
26 |
+
CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
|
27 |
+
CUDA_TOTAL_MEMORY = 42144366592
|
28 |
+
CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
|
29 |
+
CUDA_DEVICE_CAPABILITY = (8, 0)
|
30 |
+
CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
|
31 |
+
|
32 |
+
GENERIC_METHOD_NAMES = [
|
33 |
+
'arange',
|
34 |
+
'as_tensor',
|
35 |
+
'asarray',
|
36 |
+
'bartlett_window',
|
37 |
+
'blackman_window',
|
38 |
+
'empty',
|
39 |
+
'empty_like',
|
40 |
+
'empty_strided',
|
41 |
+
'eye',
|
42 |
+
'full',
|
43 |
+
'full_like',
|
44 |
+
'hamming_window',
|
45 |
+
'hann_window',
|
46 |
+
'kaiser_window',
|
47 |
+
'linspace',
|
48 |
+
'logspace',
|
49 |
+
'ones',
|
50 |
+
'ones_like',
|
51 |
+
'rand',
|
52 |
+
'rand_like',
|
53 |
+
'randint',
|
54 |
+
'randint_like',
|
55 |
+
'randn',
|
56 |
+
'randn_like',
|
57 |
+
'randperm',
|
58 |
+
'range',
|
59 |
+
'sparse_bsc_tensor',
|
60 |
+
'sparse_bsr_tensor',
|
61 |
+
'sparse_compressed_tensor',
|
62 |
+
'sparse_coo_tensor',
|
63 |
+
'sparse_csc_tensor',
|
64 |
+
'sparse_csr_tensor',
|
65 |
+
'tensor',
|
66 |
+
'tril_indices',
|
67 |
+
'triu_indices',
|
68 |
+
'zeros',
|
69 |
+
'zeros_like',
|
70 |
+
]
|
71 |
+
|
72 |
+
|
73 |
+
TO_CUDA = (torch.device('cuda'), None, False, None)
|
74 |
+
|
75 |
+
_tensor__deepcopy__ = torch.Tensor.__deepcopy__
|
76 |
+
_tensor_to = torch.Tensor.to
|
77 |
+
_tensor_cuda = torch.Tensor.cuda
|
78 |
+
_tensor_cpu = torch.Tensor.cpu
|
79 |
+
_torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
|
80 |
+
_cuda_init = torch._C._cuda_init
|
81 |
+
_cuda_available = torch.cuda.is_available
|
82 |
+
_cuda_device_count = torch.cuda.device_count
|
83 |
+
_cuda_current_device = torch.cuda.current_device
|
84 |
+
_cuda_mem_get_info = torch.cuda.mem_get_info
|
85 |
+
_cuda_get_device_capability = torch.cuda.get_device_capability
|
86 |
+
_cuda_get_device_properties = torch.cuda.get_device_properties
|
87 |
+
_cuda_get_device_name = torch.cuda.get_device_name
|
88 |
+
|
89 |
+
TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
|
90 |
+
|
91 |
+
to_ops: dict[torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
|
92 |
+
|
93 |
+
def _tensor_new_register(*args, **kwargs):
|
94 |
+
new_tensor: torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
|
95 |
+
if (base_tensor := new_tensor._base) is not None:
|
96 |
+
if base_tensor in to_ops:
|
97 |
+
to_ops[new_tensor] = to_ops[base_tensor]
|
98 |
+
return new_tensor
|
99 |
+
|
100 |
+
def _tensor_deepcopy_register(self: torch.Tensor, memo):
|
101 |
+
new_tensor = _tensor__deepcopy__(self, memo)
|
102 |
+
if isinstance(new_tensor, torch.Tensor):
|
103 |
+
if self in to_ops:
|
104 |
+
to_ops[new_tensor] = to_ops[self]
|
105 |
+
return new_tensor
|
106 |
+
|
107 |
+
@property
|
108 |
+
def _tensor_device_property(self: torch.Tensor):
|
109 |
+
if self in to_ops:
|
110 |
+
return torch.device(type='cuda', index=0)
|
111 |
+
del torch.Tensor.device
|
112 |
+
try:
|
113 |
+
return self.device
|
114 |
+
finally:
|
115 |
+
torch.Tensor.device = _tensor_device_property # type: ignore
|
116 |
+
|
117 |
+
@property
|
118 |
+
def _tensor_dtype_property(self: torch.Tensor):
|
119 |
+
if self in to_ops:
|
120 |
+
if (to_dtype := to_ops[self][1]) is not None:
|
121 |
+
return to_dtype
|
122 |
+
del torch.Tensor.dtype
|
123 |
+
try:
|
124 |
+
return self.dtype
|
125 |
+
finally:
|
126 |
+
torch.Tensor.dtype = _tensor_dtype_property # type: ignore
|
127 |
+
|
128 |
+
def _to_op_register(self: torch.Tensor, *args, **kwargs):
|
129 |
+
parsed = torch._C._nn._parse_to(*args, **kwargs)
|
130 |
+
device, dtype, *_ = parsed
|
131 |
+
try:
|
132 |
+
to_args = to_ops.pop(self)
|
133 |
+
except KeyError:
|
134 |
+
to_args = None
|
135 |
+
if device is None: # pyright: ignore [reportUnnecessaryComparison]
|
136 |
+
if to_args is not None:
|
137 |
+
to_ops[self] = (to_args[0], dtype, *to_args[2:])
|
138 |
+
return self
|
139 |
+
return _tensor_to(self, *args, **kwargs)
|
140 |
+
if device.type != 'cuda':
|
141 |
+
if to_args is not None:
|
142 |
+
if (to_dtype := to_args[1]) is not None:
|
143 |
+
kwargs = {'dtype': to_dtype, **kwargs}
|
144 |
+
return _tensor_to(self, *args, **kwargs)
|
145 |
+
to_ops[self] = parsed
|
146 |
+
return self
|
147 |
+
|
148 |
+
def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool:
|
149 |
+
if device is None:
|
150 |
+
return True
|
151 |
+
if isinstance(device, int):
|
152 |
+
return True
|
153 |
+
if isinstance(device, str):
|
154 |
+
device = torch.device(device)
|
155 |
+
return device.type == 'cuda'
|
156 |
+
|
157 |
+
def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs):
|
158 |
+
if not _cuda_op_arg_check(device):
|
159 |
+
# Let PyTorch handle the fail
|
160 |
+
return _tensor_cuda(self, device, **kwargs)
|
161 |
+
to_ops[self] = TO_CUDA
|
162 |
+
return self
|
163 |
+
|
164 |
+
def _cpu_op_remove(self: torch.Tensor, **kwargs):
|
165 |
+
try:
|
166 |
+
to_args = to_ops.pop(self)
|
167 |
+
except KeyError:
|
168 |
+
to_args = None
|
169 |
+
if to_args is not None:
|
170 |
+
if (to_dtype := to_args[1]) is not None:
|
171 |
+
return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
|
172 |
+
return _tensor_cpu(self, **kwargs)
|
173 |
+
|
174 |
+
def _cuda_init_raise():
|
175 |
+
raise RuntimeError(
|
176 |
+
"CUDA must not be initialized in the main process "
|
177 |
+
"on Spaces with Stateless GPU environment.\n"
|
178 |
+
"You can look at this Stacktrace to find out "
|
179 |
+
"which part of your code triggered a CUDA init"
|
180 |
+
)
|
181 |
+
|
182 |
+
def _generic_method_register(name: str, *args: Any, **kwargs: Any):
|
183 |
+
try:
|
184 |
+
device = torch.device(kwargs.get('device', "cpu"))
|
185 |
+
except Exception:
|
186 |
+
return _torch_generics[name](*args, **kwargs)
|
187 |
+
if device.type != 'cuda':
|
188 |
+
return _torch_generics[name](*args, **kwargs)
|
189 |
+
tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
|
190 |
+
to_ops[tensor] = TO_CUDA
|
191 |
+
return tensor
|
192 |
+
|
193 |
+
def patch():
|
194 |
+
torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
|
195 |
+
torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
|
196 |
+
torch.Tensor.to = _to_op_register # type: ignore
|
197 |
+
torch.Tensor.cuda = _cuda_op_register # type: ignore
|
198 |
+
torch.Tensor.cpu = _cpu_op_remove # type: ignore
|
199 |
+
if Config.zero_patch_torch_device:
|
200 |
+
torch.Tensor.device = _tensor_device_property # type: ignore
|
201 |
+
torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
|
202 |
+
for name in GENERIC_METHOD_NAMES:
|
203 |
+
setattr(torch, name, partial(_generic_method_register, name))
|
204 |
+
torch._C._cuda_init = _cuda_init_raise
|
205 |
+
torch.cuda.is_available = lambda: True
|
206 |
+
torch.cuda.device_count = lambda: 1
|
207 |
+
torch.cuda.current_device = lambda: 0
|
208 |
+
torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
|
209 |
+
torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
|
210 |
+
torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
|
211 |
+
torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
|
212 |
+
bitsandbytes.patch()
|
213 |
+
|
214 |
+
def unpatch():
|
215 |
+
torch.Tensor.__deepcopy__ = _tensor__deepcopy__
|
216 |
+
with suppress(AttributeError):
|
217 |
+
del torch.Tensor.__new__
|
218 |
+
torch.Tensor.to = _tensor_to
|
219 |
+
torch.Tensor.cuda = _tensor_cuda
|
220 |
+
torch.Tensor.cpu = _tensor_cpu
|
221 |
+
with suppress(AttributeError):
|
222 |
+
del torch.Tensor.device
|
223 |
+
with suppress(AttributeError):
|
224 |
+
del torch.Tensor.dtype
|
225 |
+
for name in GENERIC_METHOD_NAMES:
|
226 |
+
setattr(torch, name, _torch_generics[name])
|
227 |
+
torch._C._cuda_init = _cuda_init
|
228 |
+
torch.cuda.is_available = _cuda_available
|
229 |
+
torch.cuda.device_count = _cuda_device_count
|
230 |
+
torch.cuda.current_device = _cuda_current_device
|
231 |
+
torch.cuda.mem_get_info = _cuda_mem_get_info
|
232 |
+
torch.cuda.get_device_capability = _cuda_get_device_capability
|
233 |
+
torch.cuda.get_device_properties = _cuda_get_device_properties
|
234 |
+
torch.cuda.get_device_name = _cuda_get_device_name
|
235 |
+
bitsandbytes.unpatch()
|
236 |
+
|
237 |
+
def pack():
|
238 |
+
pass
|
239 |
+
|
240 |
+
def init(nvidia_uuid: str):
|
241 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
|
242 |
+
torch.Tensor([0]).cuda() # CUDA init
|
243 |
+
|
244 |
+
def size():
|
245 |
+
return 0
|
246 |
+
|
247 |
+
def move(callback: Callable[[int]] | None = None):
|
248 |
+
for op in to_ops.items():
|
249 |
+
tensor, parsed_args = op
|
250 |
+
_, dtype, _, memory_format = parsed_args
|
251 |
+
tensor.data = _tensor_to(tensor,
|
252 |
+
device='cuda',
|
253 |
+
dtype=dtype,
|
254 |
+
memory_format=memory_format,
|
255 |
+
) # type: ignore
|
256 |
+
bitsandbytes.move()
|
257 |
+
torch.cuda.synchronize()
|
258 |
+
|
259 |
+
def is_in_bad_fork():
|
260 |
+
with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
|
261 |
+
f = e.submit(torch.cuda._is_in_bad_fork)
|
262 |
+
return f.result()
|
263 |
+
|
264 |
+
def disable_cuda_intercept():
|
265 |
+
torch.Tensor.to = _tensor_to
|
266 |
+
torch.Tensor.cuda = _tensor_cuda
|
spaces/zero/torch/types.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
from typing import NamedTuple
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
class AliasId(NamedTuple):
|
11 |
+
data_ptr: int
|
12 |
+
dtype: torch.dtype
|
13 |
+
shape: tuple[int, ...]
|
14 |
+
stride: tuple[int, ...]
|
15 |
+
|
16 |
+
@classmethod
|
17 |
+
def from_tensor(cls, tensor: torch.Tensor):
|
18 |
+
return cls(
|
19 |
+
tensor.data_ptr(),
|
20 |
+
tensor.dtype,
|
21 |
+
tensor.shape,
|
22 |
+
tensor.stride(),
|
23 |
+
)
|
spaces/zero/tqdm.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
|
4 |
+
from multiprocessing.synchronize import RLock as MultiprocessingRLock
|
5 |
+
|
6 |
+
|
7 |
+
try:
|
8 |
+
from tqdm import tqdm as _tqdm
|
9 |
+
except ImportError: # pragma: no cover
|
10 |
+
_tqdm = None
|
11 |
+
|
12 |
+
|
13 |
+
def remove_tqdm_multiprocessing_lock():
|
14 |
+
if _tqdm is None: # pragma: no cover
|
15 |
+
return
|
16 |
+
tqdm_lock = _tqdm.get_lock()
|
17 |
+
assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
|
18 |
+
tqdm_lock.locks = [
|
19 |
+
lock for lock in tqdm_lock.locks
|
20 |
+
if not isinstance(lock, MultiprocessingRLock)
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
tqdm = _tqdm
|
spaces/zero/types.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from datetime import timedelta
|
8 |
+
from typing import Any
|
9 |
+
from typing import Dict
|
10 |
+
from typing import Tuple
|
11 |
+
from typing import TypedDict
|
12 |
+
from typing_extensions import Callable
|
13 |
+
from typing_extensions import Generic
|
14 |
+
from typing_extensions import ParamSpec
|
15 |
+
from typing_extensions import TypeAlias
|
16 |
+
from typing_extensions import TypeVar
|
17 |
+
|
18 |
+
|
19 |
+
Params = Tuple[Tuple[object, ...], Dict[str, Any]]
|
20 |
+
Res = TypeVar('Res')
|
21 |
+
Param = ParamSpec('Param')
|
22 |
+
|
23 |
+
class EmptyKwargs(TypedDict):
|
24 |
+
pass
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class OkResult(Generic[Res]):
|
28 |
+
value: Res
|
29 |
+
@dataclass
|
30 |
+
class ExceptionResult:
|
31 |
+
value: Exception
|
32 |
+
@dataclass
|
33 |
+
class AbortedResult:
|
34 |
+
pass
|
35 |
+
@dataclass
|
36 |
+
class EndResult:
|
37 |
+
pass
|
38 |
+
@dataclass
|
39 |
+
class GradioQueueEvent:
|
40 |
+
method_name: str
|
41 |
+
args: tuple[Any, ...]
|
42 |
+
kwargs: dict[str, Any]
|
43 |
+
|
44 |
+
RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
|
45 |
+
GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
|
46 |
+
YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
|
47 |
+
|
48 |
+
Duration: TypeAlias = "int | timedelta"
|
49 |
+
DynamicDuration: TypeAlias = "Duration | Callable[Param, Duration] | None"
|
spaces/zero/wrappers.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"""
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import multiprocessing
|
6 |
+
import os
|
7 |
+
import signal
|
8 |
+
import traceback
|
9 |
+
import warnings
|
10 |
+
from concurrent.futures import ThreadPoolExecutor
|
11 |
+
from contextlib import nullcontext
|
12 |
+
from contextvars import copy_context
|
13 |
+
from datetime import timedelta
|
14 |
+
from functools import partial
|
15 |
+
from functools import wraps
|
16 |
+
from multiprocessing.context import ForkProcess
|
17 |
+
from pickle import PicklingError
|
18 |
+
from queue import Empty
|
19 |
+
from queue import Queue as ThreadQueue
|
20 |
+
from threading import Thread
|
21 |
+
from typing import TYPE_CHECKING
|
22 |
+
from typing import Callable
|
23 |
+
from typing import Generator
|
24 |
+
from typing import Generic
|
25 |
+
from typing_extensions import assert_never
|
26 |
+
|
27 |
+
import psutil
|
28 |
+
|
29 |
+
from ..config import Config
|
30 |
+
from ..utils import debug
|
31 |
+
from ..utils import drop_params
|
32 |
+
from ..utils import gradio_request_var
|
33 |
+
from ..utils import SimpleQueue as Queue
|
34 |
+
from . import client
|
35 |
+
from . import torch
|
36 |
+
from .api import AllowToken
|
37 |
+
from .api import NvidiaIndex
|
38 |
+
from .api import NvidiaUUID
|
39 |
+
from .gradio import GradioPartialContext
|
40 |
+
from .gradio import get_server_port
|
41 |
+
from .gradio import patch_gradio_queue
|
42 |
+
from .gradio import try_process_queue_event
|
43 |
+
from .tqdm import remove_tqdm_multiprocessing_lock
|
44 |
+
from .tqdm import tqdm
|
45 |
+
from .types import * # TODO: Please don't do that
|
46 |
+
|
47 |
+
|
48 |
+
GENERATOR_GLOBAL_TIMEOUT = 20 * 60
|
49 |
+
|
50 |
+
SPAWN_PROGRESS_CLEANUP = 0.1
|
51 |
+
SPAWN_PROGRESS_INIT = 0.1
|
52 |
+
|
53 |
+
|
54 |
+
Process = multiprocessing.get_context('fork').Process
|
55 |
+
forked = False
|
56 |
+
|
57 |
+
|
58 |
+
class Worker(Generic[Res]):
|
59 |
+
process: ForkProcess
|
60 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]]
|
61 |
+
res_queue: Queue[Res | None]
|
62 |
+
_sentinel: Thread
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
target: Callable[[
|
67 |
+
Queue[tuple[Params, GradioPartialContext]],
|
68 |
+
Queue[Res | None],
|
69 |
+
AllowToken,
|
70 |
+
NvidiaUUID,
|
71 |
+
list[int],
|
72 |
+
], None],
|
73 |
+
allow_token: str,
|
74 |
+
nvidia_uuid: str,
|
75 |
+
):
|
76 |
+
self._sentinel = Thread(target=self._close_on_exit, daemon=True)
|
77 |
+
self.arg_queue = Queue()
|
78 |
+
self.res_queue = Queue()
|
79 |
+
debug(f"{self.arg_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
|
80 |
+
debug(f"{self.res_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
|
81 |
+
if (server_port := get_server_port()) is not None:
|
82 |
+
fds = [c.fd for c in psutil.Process().connections() if c.laddr.port == server_port]
|
83 |
+
debug(f"{fds=}")
|
84 |
+
else:
|
85 |
+
warnings.warn("Using a ZeroGPU function outside of Gradio caching or request might block the app")
|
86 |
+
fds = []
|
87 |
+
args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
|
88 |
+
if TYPE_CHECKING:
|
89 |
+
target(*args)
|
90 |
+
self.process = Process(
|
91 |
+
target=target,
|
92 |
+
args=args,
|
93 |
+
daemon=True,
|
94 |
+
)
|
95 |
+
self.process.start()
|
96 |
+
self._sentinel.start()
|
97 |
+
|
98 |
+
def _close_on_exit(self):
|
99 |
+
self.process.join()
|
100 |
+
self.arg_queue.close()
|
101 |
+
self.res_queue.wlock_release()
|
102 |
+
self.res_queue.put(None)
|
103 |
+
|
104 |
+
|
105 |
+
def worker_init(
|
106 |
+
res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
|
107 |
+
allow_token: str,
|
108 |
+
nvidia_uuid: str,
|
109 |
+
fds: list[int],
|
110 |
+
) -> None | ExceptionResult:
|
111 |
+
# Immediately close file descriptors
|
112 |
+
for fd in fds:
|
113 |
+
try:
|
114 |
+
os.close(fd)
|
115 |
+
except Exception as e: # pragma: no cover
|
116 |
+
if isinstance(e, OSError) and e.errno == 9:
|
117 |
+
continue
|
118 |
+
traceback.print_exc()
|
119 |
+
return ExceptionResult(e)
|
120 |
+
progress = nullcontext()
|
121 |
+
if tqdm is not None and Config.zero_gpu_v2:
|
122 |
+
progress = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w'))
|
123 |
+
try: # Unrecoverable init part
|
124 |
+
patch_gradio_queue(res_queue)
|
125 |
+
with progress as progress:
|
126 |
+
current_progress = 0 # Gradio does not support float progress updates
|
127 |
+
def update(n: float):
|
128 |
+
nonlocal current_progress
|
129 |
+
current_progress += n
|
130 |
+
if progress is not None:
|
131 |
+
progress.update(round(current_progress * 100) - progress.n)
|
132 |
+
client.allow(allow_token)
|
133 |
+
update(SPAWN_PROGRESS_CLEANUP)
|
134 |
+
torch.unpatch()
|
135 |
+
torch.init(nvidia_uuid)
|
136 |
+
update(SPAWN_PROGRESS_INIT)
|
137 |
+
callback = None
|
138 |
+
if (transfer_size := torch.size()) > 0:
|
139 |
+
remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
|
140 |
+
callback = lambda n: update(n * remaining / transfer_size)
|
141 |
+
torch.move(callback=callback)
|
142 |
+
except Exception as e: # pragma: no cover
|
143 |
+
traceback.print_exc()
|
144 |
+
return ExceptionResult(e)
|
145 |
+
try:
|
146 |
+
remove_tqdm_multiprocessing_lock()
|
147 |
+
except Exception: # pragma: no cover
|
148 |
+
print("Error while trying to remove tqdm mp_lock:")
|
149 |
+
traceback.print_exc()
|
150 |
+
|
151 |
+
|
152 |
+
def process_duration(duration: Duration | None):
|
153 |
+
if duration is None or isinstance(duration, timedelta):
|
154 |
+
return duration
|
155 |
+
return timedelta(seconds=duration)
|
156 |
+
|
157 |
+
|
158 |
+
def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs):
|
159 |
+
if not callable(duration):
|
160 |
+
return duration
|
161 |
+
return duration(*args, **kwargs)
|
162 |
+
|
163 |
+
|
164 |
+
def regular_function_wrapper(
|
165 |
+
task: Callable[Param, Res],
|
166 |
+
duration: DynamicDuration[Param],
|
167 |
+
) -> Callable[Param, Res]:
|
168 |
+
|
169 |
+
import gradio as gr
|
170 |
+
|
171 |
+
request_var = gradio_request_var()
|
172 |
+
workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
|
173 |
+
task_id = id(task)
|
174 |
+
|
175 |
+
@wraps(task)
|
176 |
+
def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
|
177 |
+
|
178 |
+
if forked:
|
179 |
+
return task(*args, **kwargs)
|
180 |
+
|
181 |
+
request = request_var.get()
|
182 |
+
duration_ = static_duration(duration, *args, **kwargs)
|
183 |
+
duration_ = process_duration(duration_)
|
184 |
+
schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
|
185 |
+
allow_token = schedule_response.allowToken
|
186 |
+
nvidia_index = schedule_response.nvidiaIndex
|
187 |
+
nvidia_uuid = schedule_response.nvidiaUUID
|
188 |
+
release = partial(client.release, allow_token)
|
189 |
+
|
190 |
+
try:
|
191 |
+
worker = workers.pop(nvidia_index)
|
192 |
+
except KeyError:
|
193 |
+
worker = None
|
194 |
+
|
195 |
+
if worker is not None and worker.process.is_alive() and schedule_response.idle:
|
196 |
+
assert worker.arg_queue.empty()
|
197 |
+
assert worker.res_queue.empty()
|
198 |
+
else:
|
199 |
+
worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
|
200 |
+
|
201 |
+
try:
|
202 |
+
worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
|
203 |
+
except PicklingError: # TODO: detailed serialization diagnostic
|
204 |
+
release(fail=True)
|
205 |
+
raise
|
206 |
+
|
207 |
+
while True:
|
208 |
+
res = worker.res_queue.get()
|
209 |
+
if res is None:
|
210 |
+
release(fail=True, allow_404=True)
|
211 |
+
raise gr.Error("GPU task aborted")
|
212 |
+
if isinstance(res, ExceptionResult):
|
213 |
+
release(fail=True)
|
214 |
+
raise res.value
|
215 |
+
if isinstance(res, OkResult):
|
216 |
+
release()
|
217 |
+
workers[nvidia_index] = worker
|
218 |
+
return res.value
|
219 |
+
if isinstance(res, GradioQueueEvent):
|
220 |
+
try_process_queue_event(res.method_name, *res.args, **res.kwargs)
|
221 |
+
continue
|
222 |
+
assert_never(res)
|
223 |
+
|
224 |
+
|
225 |
+
def thread_wrapper(
|
226 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]],
|
227 |
+
res_queue: Queue[RegularResQueueResult[Res] | None],
|
228 |
+
allow_token: str,
|
229 |
+
nvidia_uuid: str,
|
230 |
+
fds: list[int],
|
231 |
+
):
|
232 |
+
global forked
|
233 |
+
forked = True
|
234 |
+
signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
|
235 |
+
initialized = False
|
236 |
+
while True:
|
237 |
+
try:
|
238 |
+
(args, kwargs), gradio_context = arg_queue.get()
|
239 |
+
except OSError:
|
240 |
+
break
|
241 |
+
if not initialized:
|
242 |
+
if (res := worker_init(
|
243 |
+
res_queue=res_queue,
|
244 |
+
allow_token=allow_token,
|
245 |
+
nvidia_uuid=nvidia_uuid,
|
246 |
+
fds=fds,
|
247 |
+
)) is not None:
|
248 |
+
res_queue.put(res)
|
249 |
+
return
|
250 |
+
initialized = True
|
251 |
+
GradioPartialContext.apply(gradio_context)
|
252 |
+
context = copy_context()
|
253 |
+
with ThreadPoolExecutor() as executor:
|
254 |
+
future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
|
255 |
+
try:
|
256 |
+
res = future.result()
|
257 |
+
except Exception as e:
|
258 |
+
traceback.print_exc()
|
259 |
+
res = ExceptionResult(e)
|
260 |
+
else:
|
261 |
+
res = OkResult(res)
|
262 |
+
try:
|
263 |
+
res_queue.put(res)
|
264 |
+
except PicklingError as e:
|
265 |
+
res_queue.put(ExceptionResult(e))
|
266 |
+
|
267 |
+
# https://github.com/python/cpython/issues/91002
|
268 |
+
if not hasattr(task, '__annotations__'):
|
269 |
+
gradio_handler.__annotations__ = {}
|
270 |
+
|
271 |
+
return gradio_handler
|
272 |
+
|
273 |
+
|
274 |
+
def generator_function_wrapper(
|
275 |
+
task: Callable[Param, Generator[Res, None, None]],
|
276 |
+
duration: DynamicDuration[Param],
|
277 |
+
) -> Callable[Param, Generator[Res, None, None]]:
|
278 |
+
|
279 |
+
import gradio as gr
|
280 |
+
|
281 |
+
request_var = gradio_request_var()
|
282 |
+
workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
|
283 |
+
task_id = id(task)
|
284 |
+
|
285 |
+
@wraps(task)
|
286 |
+
def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
|
287 |
+
|
288 |
+
if forked:
|
289 |
+
yield from task(*args, **kwargs)
|
290 |
+
return
|
291 |
+
|
292 |
+
request = request_var.get()
|
293 |
+
duration_ = static_duration(duration, *args, **kwargs)
|
294 |
+
duration_ = process_duration(duration_)
|
295 |
+
schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
|
296 |
+
allow_token = schedule_response.allowToken
|
297 |
+
nvidia_index = schedule_response.nvidiaIndex
|
298 |
+
nvidia_uuid = schedule_response.nvidiaUUID
|
299 |
+
release = partial(client.release, allow_token)
|
300 |
+
|
301 |
+
try:
|
302 |
+
worker = workers.pop(nvidia_index)
|
303 |
+
except KeyError:
|
304 |
+
worker = None
|
305 |
+
|
306 |
+
if worker is not None and worker.process.is_alive() and schedule_response.idle:
|
307 |
+
assert worker.arg_queue.empty()
|
308 |
+
assert worker.res_queue.empty()
|
309 |
+
else:
|
310 |
+
worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
|
311 |
+
|
312 |
+
try:
|
313 |
+
worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
|
314 |
+
except PicklingError: # TODO: detailed serialization diagnostic
|
315 |
+
release(fail=True)
|
316 |
+
raise
|
317 |
+
|
318 |
+
yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
|
319 |
+
def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
|
320 |
+
while True:
|
321 |
+
res = worker.res_queue.get()
|
322 |
+
if res is None:
|
323 |
+
release(fail=True, allow_404=True)
|
324 |
+
yield_queue.put(AbortedResult())
|
325 |
+
return
|
326 |
+
if isinstance(res, ExceptionResult):
|
327 |
+
release(fail=True)
|
328 |
+
yield_queue.put(ExceptionResult(res.value))
|
329 |
+
return
|
330 |
+
if isinstance(res, EndResult):
|
331 |
+
release()
|
332 |
+
workers[nvidia_index] = worker
|
333 |
+
yield_queue.put(EndResult())
|
334 |
+
return
|
335 |
+
if isinstance(res, OkResult):
|
336 |
+
yield_queue.put(OkResult(res.value))
|
337 |
+
continue
|
338 |
+
if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
|
339 |
+
try_process_queue_event(res.method_name, *res.args, **res.kwargs)
|
340 |
+
continue
|
341 |
+
debug(f"fill_yield_queue: assert_never({res=})")
|
342 |
+
assert_never(res)
|
343 |
+
from typing_extensions import assert_never
|
344 |
+
with ThreadPoolExecutor() as e:
|
345 |
+
f = e.submit(copy_context().run, fill_yield_queue, worker)
|
346 |
+
f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
|
347 |
+
while True:
|
348 |
+
try:
|
349 |
+
res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
|
350 |
+
except Empty: # pragma: no cover
|
351 |
+
debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
|
352 |
+
raise
|
353 |
+
if isinstance(res, AbortedResult):
|
354 |
+
raise gr.Error("GPU task aborted")
|
355 |
+
if isinstance(res, ExceptionResult):
|
356 |
+
raise res.value
|
357 |
+
if isinstance(res, EndResult):
|
358 |
+
break
|
359 |
+
if isinstance(res, OkResult):
|
360 |
+
yield res.value
|
361 |
+
continue
|
362 |
+
debug(f"gradio_handler: assert_never({res=})")
|
363 |
+
assert_never(res)
|
364 |
+
|
365 |
+
|
366 |
+
def thread_wrapper(
|
367 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]],
|
368 |
+
res_queue: Queue[GeneratorResQueueResult[Res] | None],
|
369 |
+
allow_token: str,
|
370 |
+
nvidia_uuid: str,
|
371 |
+
fds: list[int],
|
372 |
+
):
|
373 |
+
global forked
|
374 |
+
forked = True
|
375 |
+
signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
|
376 |
+
initialized = False
|
377 |
+
while True:
|
378 |
+
try:
|
379 |
+
(args, kwargs), gradio_context = arg_queue.get()
|
380 |
+
except OSError:
|
381 |
+
break
|
382 |
+
if not initialized:
|
383 |
+
if (res := worker_init(
|
384 |
+
res_queue=res_queue,
|
385 |
+
allow_token=allow_token,
|
386 |
+
nvidia_uuid=nvidia_uuid,
|
387 |
+
fds=fds,
|
388 |
+
)) is not None:
|
389 |
+
res_queue.put(res)
|
390 |
+
return
|
391 |
+
initialized = True
|
392 |
+
def iterate():
|
393 |
+
gen = task(*args, **kwargs) # type: ignore
|
394 |
+
while True:
|
395 |
+
try:
|
396 |
+
res = next(gen)
|
397 |
+
except StopIteration:
|
398 |
+
break
|
399 |
+
except Exception as e:
|
400 |
+
res_queue.put(ExceptionResult(e))
|
401 |
+
break
|
402 |
+
try:
|
403 |
+
res_queue.put(OkResult(res))
|
404 |
+
except PicklingError as e:
|
405 |
+
res_queue.put(ExceptionResult(e))
|
406 |
+
break
|
407 |
+
else:
|
408 |
+
continue
|
409 |
+
GradioPartialContext.apply(gradio_context)
|
410 |
+
with ThreadPoolExecutor() as executor:
|
411 |
+
executor.submit(copy_context().run, iterate)
|
412 |
+
res_queue.put(EndResult())
|
413 |
+
|
414 |
+
# https://github.com/python/cpython/issues/91002
|
415 |
+
if not hasattr(task, '__annotations__'):
|
416 |
+
gradio_handler.__annotations__ = {}
|
417 |
+
|
418 |
+
return gradio_handler
|