adamelliotfields commited on
Commit
29c923f
1 Parent(s): dec8492
Files changed (3) hide show
  1. lib/__init__.py +13 -2
  2. lib/utils.py +78 -0
  3. requirements.txt +2 -0
lib/__init__.py CHANGED
@@ -1,6 +1,17 @@
1
  from .config import Config
2
- from .inference import async_call, generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
 
5
 
6
- __all__ = ["Config", "Loader", "RealESRGAN", "async_call", "generate"]
 
 
 
 
 
 
 
 
 
 
 
1
  from .config import Config
2
+ from .inference import generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
5
+ from .utils import async_call, download_civit_file, download_repo_files, load_json, read_file
6
 
7
+ __all__ = [
8
+ "Config",
9
+ "Loader",
10
+ "RealESRGAN",
11
+ "async_call",
12
+ "download_civit_file",
13
+ "download_repo_files",
14
+ "generate",
15
+ "load_json",
16
+ "read_file",
17
+ ]
lib/utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import inspect
3
+ import json
4
+ import os
5
+ from typing import Callable, TypeVar
6
+
7
+ import anyio
8
+ import httpx
9
+ from anyio import Semaphore
10
+ from huggingface_hub._snapshot_download import snapshot_download
11
+ from typing_extensions import ParamSpec
12
+
13
+ T = TypeVar("T")
14
+ P = ParamSpec("P")
15
+
16
+ MAX_CONCURRENT_THREADS = 1
17
+ MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
18
+
19
+
20
+ @functools.lru_cache()
21
+ def load_json(path: str) -> dict:
22
+ with open(path, "r", encoding="utf-8") as file:
23
+ return json.load(file)
24
+
25
+
26
+ @functools.lru_cache()
27
+ def read_file(path: str) -> str:
28
+ with open(path, "r", encoding="utf-8") as file:
29
+ return file.read()
30
+
31
+
32
+ def download_repo_files(repo_id, allow_patterns, token=None):
33
+ return snapshot_download(
34
+ repo_id=repo_id,
35
+ repo_type="model",
36
+ revision="main",
37
+ token=token,
38
+ allow_patterns=allow_patterns,
39
+ ignore_patterns=None,
40
+ )
41
+
42
+
43
+ def download_civit_file(lora_id, version_id, file_path=".", token=None):
44
+ base_url = "https://civitai.com/api/download/models"
45
+ file = f"{file_path}/{lora_id}.{version_id}.safetensors"
46
+
47
+ if os.path.exists(file):
48
+ return
49
+
50
+ try:
51
+ params = {"token": token}
52
+ response = httpx.get(
53
+ f"{base_url}/{version_id}",
54
+ timeout=None,
55
+ params=params,
56
+ follow_redirects=True,
57
+ )
58
+
59
+ response.raise_for_status()
60
+ os.makedirs(file_path, exist_ok=True)
61
+
62
+ with open(file, "wb") as f:
63
+ f.write(response.content)
64
+ except httpx.HTTPStatusError as e:
65
+ print(f"HTTPError: {e.response.status_code} {e.response.text}")
66
+ except httpx.RequestError as e:
67
+ print(f"RequestError: {e}")
68
+
69
+
70
+ # like the original but supports args and kwargs instead of a dict
71
+ # https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
72
+ async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
73
+ async with MAX_THREADS_GUARD:
74
+ sig = inspect.signature(fn)
75
+ bound_args = sig.bind(*args, **kwargs)
76
+ bound_args.apply_defaults()
77
+ partial_fn = functools.partial(fn, **bound_args.arguments)
78
+ return await anyio.to_thread.run_sync(partial_fn)
requirements.txt CHANGED
@@ -4,7 +4,9 @@ einops==0.8.0
4
  compel==2.0.3
5
  deepcache==0.1.1
6
  diffusers==0.30.2
 
7
  hf-transfer
 
8
  gradio==4.41.0
9
  numpy==1.26.4
10
  ruff==0.5.7
 
4
  compel==2.0.3
5
  deepcache==0.1.1
6
  diffusers==0.30.2
7
+ h2
8
  hf-transfer
9
+ httpx
10
  gradio==4.41.0
11
  numpy==1.26.4
12
  ruff==0.5.7