Spaces: Running on Zero
adamelliotfields committed
Commit 29c923f • 1 Parent(s): dec8492
Add utils
Browse files:
- lib/__init__.py +13 -2
- lib/utils.py +78 -0
- requirements.txt +2 -0
lib/__init__.py CHANGED
@@ -1,6 +1,17 @@
 from .config import Config
-from .inference import …
+from .inference import generate
 from .loader import Loader
 from .upscaler import RealESRGAN
+from .utils import async_call, download_civit_file, download_repo_files, load_json, read_file
 
-__all__ = […
+__all__ = [
+    "Config",
+    "Loader",
+    "RealESRGAN",
+    "async_call",
+    "download_civit_file",
+    "download_repo_files",
+    "generate",
+    "load_json",
+    "read_file",
+]
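For reference, a hypothetical sketch of how the expanded package API might be consumed elsewhere in the Space (the repo id, patterns, and path below are illustrative, not taken from this commit):

    from lib import download_repo_files, load_json

    # hypothetical values; the real repo ids and patterns live elsewhere in the Space
    download_repo_files("some-org/some-model", allow_patterns=["*.json", "*.safetensors"])
    settings = load_json("data/settings.json")  # hypothetical path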
lib/utils.py ADDED
@@ -0,0 +1,78 @@
+import functools
+import inspect
+import json
+import os
+from typing import Callable, TypeVar
+
+import anyio
+import httpx
+from anyio import Semaphore
+from huggingface_hub._snapshot_download import snapshot_download
+from typing_extensions import ParamSpec
+
+T = TypeVar("T")
+P = ParamSpec("P")
+
+MAX_CONCURRENT_THREADS = 1
+MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
+
+
+@functools.lru_cache()
+def load_json(path: str) -> dict:
+    with open(path, "r", encoding="utf-8") as file:
+        return json.load(file)
+
+
+@functools.lru_cache()
+def read_file(path: str) -> str:
+    with open(path, "r", encoding="utf-8") as file:
+        return file.read()
+
+
+def download_repo_files(repo_id, allow_patterns, token=None):
+    return snapshot_download(
+        repo_id=repo_id,
+        repo_type="model",
+        revision="main",
+        token=token,
+        allow_patterns=allow_patterns,
+        ignore_patterns=None,
+    )
+
+
+def download_civit_file(lora_id, version_id, file_path=".", token=None):
+    base_url = "https://civitai.com/api/download/models"
+    file = f"{file_path}/{lora_id}.{version_id}.safetensors"
+
+    if os.path.exists(file):
+        return
+
+    try:
+        params = {"token": token}
+        response = httpx.get(
+            f"{base_url}/{version_id}",
+            timeout=None,
+            params=params,
+            follow_redirects=True,
+        )
+
+        response.raise_for_status()
+        os.makedirs(file_path, exist_ok=True)
+
+        with open(file, "wb") as f:
+            f.write(response.content)
+    except httpx.HTTPStatusError as e:
+        print(f"HTTPError: {e.response.status_code} {e.response.text}")
+    except httpx.RequestError as e:
+        print(f"RequestError: {e}")
+
+
+# like the original but supports args and kwargs instead of a dict
+# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
+async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
+    async with MAX_THREADS_GUARD:
+        sig = inspect.signature(fn)
+        bound_args = sig.bind(*args, **kwargs)
+        bound_args.apply_defaults()
+        partial_fn = functools.partial(fn, **bound_args.arguments)
+        return await anyio.to_thread.run_sync(partial_fn)
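A minimal usage sketch for async_call (not part of the commit): slow_add is a hypothetical blocking function, and the call is awaited from an anyio event loop, so at most MAX_CONCURRENT_THREADS (= 1) worker thread runs at a time:

    import time

    import anyio

    from lib.utils import async_call

    def slow_add(a: int, b: int, delay: float = 0.1) -> int:
        time.sleep(delay)  # simulate blocking work
        return a + b

    async def main() -> None:
        # args/kwargs are bound to slow_add's signature, then the call
        # runs in a worker thread guarded by the module-level semaphore
        result = await async_call(slow_add, 2, 3, delay=0.05)
        print(result)  # 5

    anyio.run(main)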
requirements.txt CHANGED
@@ -4,7 +4,9 @@ einops==0.8.0
 compel==2.0.3
 deepcache==0.1.1
 diffusers==0.30.2
+h2
 hf-transfer
+httpx
 gradio==4.41.0
 numpy==1.26.4
 ruff==0.5.7
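The two new requirements pair up: httpx is the HTTP client used by lib/utils.py, and h2 is the optional dependency httpx needs to speak HTTP/2 (presumably why it is added here; the commit's own httpx.get call does not enable it). A minimal sketch, assuming HTTP/2 is wanted:

    import httpx

    # http2=True raises ImportError unless the `h2` package is installed
    with httpx.Client(http2=True, follow_redirects=True) as client:
        response = client.get("https://huggingface.co")
        print(response.http_version)  # "HTTP/2" when the server negotiates it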