Spaces:
Running
on
Zero
Running
on
Zero
adamelliotfields
commited on
Async generate wrapper
Browse files- app.py +9 -9
- cli.py +7 -5
- lib/__init__.py +2 -2
- lib/config.py +1 -1
- lib/inference.py +35 -7
- lib/loader.py +10 -7
- requirements.txt +1 -0
app.py
CHANGED
@@ -4,7 +4,7 @@ import random
|
|
4 |
|
5 |
import gradio as gr
|
6 |
|
7 |
-
from lib import Config, generate
|
8 |
|
9 |
# the CSS `content` attribute expects a string so we need to wrap the number in quotes
|
10 |
refresh_seed_js = """
|
@@ -79,7 +79,7 @@ def image_select_fn(images, image, i):
|
|
79 |
return gr.Image(images[i][0]) if i > -1 else None
|
80 |
|
81 |
|
82 |
-
def generate_fn(*args):
|
83 |
if len(args) > 0:
|
84 |
prompt = args[0]
|
85 |
else:
|
@@ -87,7 +87,7 @@ def generate_fn(*args):
|
|
87 |
if prompt is None or prompt.strip() == "":
|
88 |
raise gr.Error("You must enter a prompt")
|
89 |
try:
|
90 |
-
images = generate
|
91 |
except RuntimeError:
|
92 |
raise gr.Error("RuntimeError: Please try again")
|
93 |
return images
|
@@ -194,25 +194,25 @@ with gr.Blocks(
|
|
194 |
width = gr.Slider(
|
195 |
value=Config.WIDTH,
|
196 |
label="Width",
|
197 |
-
minimum=
|
198 |
maximum=768,
|
199 |
-
step=
|
200 |
)
|
201 |
height = gr.Slider(
|
202 |
value=Config.HEIGHT,
|
203 |
label="Height",
|
204 |
-
minimum=
|
205 |
maximum=768,
|
206 |
-
step=
|
207 |
)
|
208 |
aspect_ratio = gr.Dropdown(
|
209 |
choices=[
|
210 |
("Custom", None),
|
|
|
211 |
("7:9 (448x576)", "448,576"),
|
212 |
-
("3:4 (432x576)", "432,576"),
|
213 |
("1:1 (512x512)", "512,512"),
|
214 |
-
("4:3 (576x432)", "576,432"),
|
215 |
("9:7 (576x448)", "576,448"),
|
|
|
216 |
],
|
217 |
value="448,576",
|
218 |
filterable=False,
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
|
7 |
+
from lib import Config, async_call, generate
|
8 |
|
9 |
# the CSS `content` attribute expects a string so we need to wrap the number in quotes
|
10 |
refresh_seed_js = """
|
|
|
79 |
return gr.Image(images[i][0]) if i > -1 else None
|
80 |
|
81 |
|
82 |
+
async def generate_fn(*args):
|
83 |
if len(args) > 0:
|
84 |
prompt = args[0]
|
85 |
else:
|
|
|
87 |
if prompt is None or prompt.strip() == "":
|
88 |
raise gr.Error("You must enter a prompt")
|
89 |
try:
|
90 |
+
images = await async_call(generate, *args, Info=gr.Info, Error=gr.Error)
|
91 |
except RuntimeError:
|
92 |
raise gr.Error("RuntimeError: Please try again")
|
93 |
return images
|
|
|
194 |
width = gr.Slider(
|
195 |
value=Config.WIDTH,
|
196 |
label="Width",
|
197 |
+
minimum=256,
|
198 |
maximum=768,
|
199 |
+
step=32,
|
200 |
)
|
201 |
height = gr.Slider(
|
202 |
value=Config.HEIGHT,
|
203 |
label="Height",
|
204 |
+
minimum=256,
|
205 |
maximum=768,
|
206 |
+
step=32,
|
207 |
)
|
208 |
aspect_ratio = gr.Dropdown(
|
209 |
choices=[
|
210 |
("Custom", None),
|
211 |
+
("4:7 (384x672)", "384,672"),
|
212 |
("7:9 (448x576)", "448,576"),
|
|
|
213 |
("1:1 (512x512)", "512,512"),
|
|
|
214 |
("9:7 (576x448)", "576,448"),
|
215 |
+
("7:4 (672x384)", "672,384"),
|
216 |
],
|
217 |
value="448,576",
|
218 |
filterable=False,
|
cli.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
# CLI
|
2 |
# usage: python cli.py 'colorful calico cat artstation'
|
3 |
import argparse
|
|
|
4 |
|
5 |
-
from lib import Config, generate
|
6 |
|
7 |
|
8 |
def save_images(images, filename="image.png"):
|
@@ -11,7 +12,7 @@ def save_images(images, filename="image.png"):
|
|
11 |
img.save(f"{name}.{ext}" if len(images) == 1 else f"{name}_{i}.{ext}")
|
12 |
|
13 |
|
14 |
-
def main():
|
15 |
# fmt: off
|
16 |
parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
|
17 |
parser.add_argument("prompt", type=str, metavar="PROMPT")
|
@@ -42,7 +43,8 @@ def main():
|
|
42 |
# fmt: on
|
43 |
|
44 |
args = parser.parse_args()
|
45 |
-
images =
|
|
|
46 |
args.prompt,
|
47 |
args.negative,
|
48 |
args.image,
|
@@ -68,8 +70,8 @@ def main():
|
|
68 |
args.deepcache,
|
69 |
args.scale,
|
70 |
)
|
71 |
-
save_images
|
72 |
|
73 |
|
74 |
if __name__ == "__main__":
|
75 |
-
main()
|
|
|
1 |
# CLI
|
2 |
# usage: python cli.py 'colorful calico cat artstation'
|
3 |
import argparse
|
4 |
+
import asyncio
|
5 |
|
6 |
+
from lib import Config, async_call, generate
|
7 |
|
8 |
|
9 |
def save_images(images, filename="image.png"):
|
|
|
12 |
img.save(f"{name}.{ext}" if len(images) == 1 else f"{name}_{i}.{ext}")
|
13 |
|
14 |
|
15 |
+
async def main():
|
16 |
# fmt: off
|
17 |
parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
|
18 |
parser.add_argument("prompt", type=str, metavar="PROMPT")
|
|
|
43 |
# fmt: on
|
44 |
|
45 |
args = parser.parse_args()
|
46 |
+
images = await async_call(
|
47 |
+
generate,
|
48 |
args.prompt,
|
49 |
args.negative,
|
50 |
args.image,
|
|
|
70 |
args.deepcache,
|
71 |
args.scale,
|
72 |
)
|
73 |
+
await async_call(save_images, images, args.filename)
|
74 |
|
75 |
|
76 |
if __name__ == "__main__":
|
77 |
+
asyncio.run(main())
|
lib/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from .config import Config
|
2 |
-
from .inference import generate
|
3 |
from .loader import Loader
|
4 |
from .upscaler import RealESRGAN
|
5 |
|
6 |
-
__all__ = ["Config", "Loader", "RealESRGAN", "generate"]
|
|
|
1 |
from .config import Config
|
2 |
+
from .inference import async_call, generate
|
3 |
from .loader import Loader
|
4 |
from .upscaler import RealESRGAN
|
5 |
|
6 |
+
__all__ = ["Config", "Loader", "RealESRGAN", "async_call", "generate"]
|
lib/config.py
CHANGED
@@ -41,7 +41,7 @@ Config = SimpleNamespace(
|
|
41 |
GUIDANCE_SCALE=6,
|
42 |
INFERENCE_STEPS=35,
|
43 |
DENOISING_STRENGTH=0.6,
|
44 |
-
DEEPCACHE_INTERVAL=
|
45 |
SCALE=1,
|
46 |
SCALES=[1, 2, 4],
|
47 |
)
|
|
|
41 |
GUIDANCE_SCALE=6,
|
42 |
INFERENCE_STEPS=35,
|
43 |
DENOISING_STRENGTH=0.6,
|
44 |
+
DEEPCACHE_INTERVAL=1,
|
45 |
SCALE=1,
|
46 |
SCALES=[1, 2, 4],
|
47 |
)
|
lib/inference.py
CHANGED
@@ -1,26 +1,48 @@
|
|
|
|
|
|
1 |
import json
|
2 |
import os
|
3 |
import re
|
4 |
import time
|
5 |
from datetime import datetime
|
6 |
from itertools import product
|
7 |
-
from typing import Callable
|
8 |
|
|
|
9 |
import numpy as np
|
10 |
import spaces
|
11 |
import torch
|
|
|
12 |
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
|
13 |
from compel.prompt_parser import PromptParser
|
14 |
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
|
15 |
from PIL import Image
|
|
|
16 |
|
17 |
from .loader import Loader
|
18 |
|
19 |
__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
|
20 |
__import__("transformers").logging.set_verbosity_error()
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
with open("./data/styles.json") as f:
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
# parse prompts with arrays
|
@@ -43,10 +65,10 @@ def parse_prompt(prompt: str) -> list[str]:
|
|
43 |
|
44 |
|
45 |
def apply_style(prompt, style_id, negative=False):
|
46 |
-
global
|
47 |
if not style_id or style_id == "None":
|
48 |
return prompt
|
49 |
-
for style in
|
50 |
if style["id"] == style_id:
|
51 |
if negative:
|
52 |
return prompt + " . " + style["negative_prompt"]
|
@@ -55,7 +77,7 @@ def apply_style(prompt, style_id, negative=False):
|
|
55 |
return prompt
|
56 |
|
57 |
|
58 |
-
def prepare_image(input, size=
|
59 |
image = None
|
60 |
if isinstance(input, Image.Image):
|
61 |
image = input
|
@@ -65,7 +87,11 @@ def prepare_image(input, size=(512, 512)):
|
|
65 |
if os.path.isfile(input):
|
66 |
image = Image.open(input)
|
67 |
if image is not None:
|
68 |
-
|
|
|
|
|
|
|
|
|
69 |
else:
|
70 |
raise ValueError("Invalid image prompt")
|
71 |
|
@@ -213,7 +239,9 @@ def generate(
|
|
213 |
kwargs["image"] = prepare_image(image_prompt, (width, height))
|
214 |
|
215 |
if IP_ADAPTER:
|
216 |
-
|
|
|
|
|
217 |
|
218 |
try:
|
219 |
image = pipe(**kwargs).images[0]
|
|
|
1 |
+
import functools
|
2 |
+
import inspect
|
3 |
import json
|
4 |
import os
|
5 |
import re
|
6 |
import time
|
7 |
from datetime import datetime
|
8 |
from itertools import product
|
9 |
+
from typing import Callable, TypeVar
|
10 |
|
11 |
+
import anyio
|
12 |
import numpy as np
|
13 |
import spaces
|
14 |
import torch
|
15 |
+
from anyio import Semaphore
|
16 |
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
|
17 |
from compel.prompt_parser import PromptParser
|
18 |
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
|
19 |
from PIL import Image
|
20 |
+
from typing_extensions import ParamSpec
|
21 |
|
22 |
from .loader import Loader
|
23 |
|
24 |
__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
|
25 |
__import__("transformers").logging.set_verbosity_error()
|
26 |
|
27 |
+
T = TypeVar("T")
|
28 |
+
P = ParamSpec("P")
|
29 |
+
|
30 |
+
MAX_CONCURRENT_THREADS = 1
|
31 |
+
MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
|
32 |
+
|
33 |
with open("./data/styles.json") as f:
|
34 |
+
STYLES = json.load(f)
|
35 |
+
|
36 |
+
|
37 |
+
# like the original but supports args and kwargs instead of a dict
|
38 |
+
# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
|
39 |
+
async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
40 |
+
async with MAX_THREADS_GUARD:
|
41 |
+
sig = inspect.signature(fn)
|
42 |
+
bound_args = sig.bind(*args, **kwargs)
|
43 |
+
bound_args.apply_defaults()
|
44 |
+
partial_fn = functools.partial(fn, **bound_args.arguments)
|
45 |
+
return await anyio.to_thread.run_sync(partial_fn)
|
46 |
|
47 |
|
48 |
# parse prompts with arrays
|
|
|
65 |
|
66 |
|
67 |
def apply_style(prompt, style_id, negative=False):
|
68 |
+
global STYLES
|
69 |
if not style_id or style_id == "None":
|
70 |
return prompt
|
71 |
+
for style in STYLES:
|
72 |
if style["id"] == style_id:
|
73 |
if negative:
|
74 |
return prompt + " . " + style["negative_prompt"]
|
|
|
77 |
return prompt
|
78 |
|
79 |
|
80 |
+
def prepare_image(input, size=None):
|
81 |
image = None
|
82 |
if isinstance(input, Image.Image):
|
83 |
image = input
|
|
|
87 |
if os.path.isfile(input):
|
88 |
image = Image.open(input)
|
89 |
if image is not None:
|
90 |
+
image = image.convert("RGB")
|
91 |
+
if size is not None:
|
92 |
+
image = image.resize(size, Image.Resampling.LANCZOS)
|
93 |
+
if image is not None:
|
94 |
+
return image
|
95 |
else:
|
96 |
raise ValueError("Invalid image prompt")
|
97 |
|
|
|
239 |
kwargs["image"] = prepare_image(image_prompt, (width, height))
|
240 |
|
241 |
if IP_ADAPTER:
|
242 |
+
# don't resize full-face images
|
243 |
+
size = None if ip_face else (width, height)
|
244 |
+
kwargs["ip_adapter_image"] = prepare_image(ip_image, size)
|
245 |
|
246 |
try:
|
247 |
image = pipe(**kwargs).images[0]
|
lib/loader.py
CHANGED
@@ -104,31 +104,33 @@ class Loader:
|
|
104 |
print("Switching to Tiny VAE...")
|
105 |
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
106 |
pretrained_model_name_or_path="madebyollin/taesd",
|
107 |
-
|
|
|
108 |
return
|
109 |
|
110 |
if is_tiny and not taesd:
|
111 |
print("Switching to KL VAE...")
|
112 |
model = AutoencoderKL.from_pretrained(
|
113 |
pretrained_model_name_or_path=model_name,
|
|
|
114 |
subfolder="vae",
|
115 |
variant=variant,
|
116 |
-
).to(self.pipe.device
|
117 |
self.pipe.vae = torch.compile(
|
118 |
mode="reduce-overhead",
|
119 |
fullgraph=True,
|
120 |
model=model,
|
121 |
)
|
122 |
|
123 |
-
def _load_pipeline(self, kind, model, device,
|
124 |
pipelines = {
|
125 |
"txt2img": StableDiffusionPipeline,
|
126 |
"img2img": StableDiffusionImg2ImgPipeline,
|
127 |
}
|
128 |
if self.pipe is None:
|
129 |
-
self.pipe = pipelines[kind].from_pretrained(model, **kwargs).to(device
|
130 |
if not isinstance(self.pipe, pipelines[kind]):
|
131 |
-
self.pipe = pipelines[kind].from_pipe(self.pipe).to(device
|
132 |
self.ip_adapter = None
|
133 |
|
134 |
def load(
|
@@ -186,13 +188,14 @@ class Loader:
|
|
186 |
"scheduler": schedulers[scheduler](**scheduler_kwargs),
|
187 |
"requires_safety_checker": False,
|
188 |
"safety_checker": None,
|
|
|
189 |
"variant": variant,
|
190 |
}
|
191 |
|
192 |
if self.pipe is None:
|
193 |
print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
|
194 |
|
195 |
-
self._load_pipeline(kind, model_lower, device,
|
196 |
model_name = self.pipe.config._name_or_path
|
197 |
same_model = model_name.lower() == model_lower
|
198 |
same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
|
@@ -210,7 +213,7 @@ class Loader:
|
|
210 |
self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
|
211 |
else:
|
212 |
self.pipe = None
|
213 |
-
self._load_pipeline(kind, model_lower, device,
|
214 |
|
215 |
self._load_ip_adapter(ip_adapter)
|
216 |
self._load_vae(taesd, model_lower, variant)
|
|
|
104 |
print("Switching to Tiny VAE...")
|
105 |
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
106 |
pretrained_model_name_or_path="madebyollin/taesd",
|
107 |
+
torch_dtype=self.pipe.dtype,
|
108 |
+
).to(self.pipe.device)
|
109 |
return
|
110 |
|
111 |
if is_tiny and not taesd:
|
112 |
print("Switching to KL VAE...")
|
113 |
model = AutoencoderKL.from_pretrained(
|
114 |
pretrained_model_name_or_path=model_name,
|
115 |
+
torch_dtype=self.pipe.dtype,
|
116 |
subfolder="vae",
|
117 |
variant=variant,
|
118 |
+
).to(self.pipe.device)
|
119 |
self.pipe.vae = torch.compile(
|
120 |
mode="reduce-overhead",
|
121 |
fullgraph=True,
|
122 |
model=model,
|
123 |
)
|
124 |
|
125 |
+
def _load_pipeline(self, kind, model, device, **kwargs):
|
126 |
pipelines = {
|
127 |
"txt2img": StableDiffusionPipeline,
|
128 |
"img2img": StableDiffusionImg2ImgPipeline,
|
129 |
}
|
130 |
if self.pipe is None:
|
131 |
+
self.pipe = pipelines[kind].from_pretrained(model, **kwargs).to(device)
|
132 |
if not isinstance(self.pipe, pipelines[kind]):
|
133 |
+
self.pipe = pipelines[kind].from_pipe(self.pipe).to(device)
|
134 |
self.ip_adapter = None
|
135 |
|
136 |
def load(
|
|
|
188 |
"scheduler": schedulers[scheduler](**scheduler_kwargs),
|
189 |
"requires_safety_checker": False,
|
190 |
"safety_checker": None,
|
191 |
+
"torch_dtype": dtype,
|
192 |
"variant": variant,
|
193 |
}
|
194 |
|
195 |
if self.pipe is None:
|
196 |
print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
|
197 |
|
198 |
+
self._load_pipeline(kind, model_lower, device, **pipe_kwargs)
|
199 |
model_name = self.pipe.config._name_or_path
|
200 |
same_model = model_name.lower() == model_lower
|
201 |
same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
|
|
|
213 |
self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
|
214 |
else:
|
215 |
self.pipe = None
|
216 |
+
self._load_pipeline(kind, model_lower, device, **pipe_kwargs)
|
217 |
|
218 |
self._load_ip_adapter(ip_adapter)
|
219 |
self._load_vae(taesd, model_lower, variant)
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
accelerate
|
2 |
einops==0.8.0
|
3 |
compel==2.0.3
|
|
|
1 |
+
anyio==4.4.0
|
2 |
accelerate
|
3 |
einops==0.8.0
|
4 |
compel==2.0.3
|