Spaces: Running on A10G
Dimitri committed

Commit 92cf4eb • 1 Parent(s): 61d0d14

fix demo
app.py CHANGED
@@ -10,13 +10,33 @@ from fabric.generator import AttentionBasedGenerator
 model_name = ""
 model_ckpt = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_7_pruned.safetensors"
 
-
-
-
-
-
-
-
+class GeneratorWrapper:
+    def __init__(self, model_name=None, model_ckpt=None):
+        self.model_name = model_name if model_name else None
+        self.model_ckpt = model_ckpt if model_ckpt else None
+        self.dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.reload()
+
+    def generate(self, *args, **kwargs):
+        return self.generator.generate(*args, **kwargs)
+
+    def to(self, device):
+        return self.generator.to(device)
+
+    def reload(self):
+        if hasattr(self, "generator"):
+            del self.generator
+        if self.device == "cuda":
+            torch.cuda.empty_cache()
+        self.generator = AttentionBasedGenerator(
+            model_name=self.model_name,
+            model_ckpt=self.model_ckpt,
+            torch_dtype=self.dtype,
+        ).to(self.device)
+
+generator = GeneratorWrapper(model_name, model_ckpt)
 
 
 css = """
@@ -96,33 +116,44 @@ def generate_fn(
         liked = []
         disliked = disliked[-max_feedback_imgs:]
         # else: keep all feedback images
-
-        images = generator.generate(
-            prompt,
-            negative_prompt,
-            liked,
-            disliked,
-            denoising_steps,
-            guidance_scale,
-            feedback_start,
-            feedback_end,
-            min_weight,
-            max_weight,
-            neg_scale,
-            seed,
-            n_images,
-        )
+
+        generate_kwargs = {
+            "prompt": prompt,
+            "negative_prompt": neg_prompt,
+            "liked": liked,
+            "disliked": disliked,
+            "denoising_steps": denoising_steps,
+            "guidance_scale": guidance_scale,
+            "feedback_start": feedback_start,
+            "feedback_end": feedback_end,
+            "min_weight": min_weight,
+            "max_weight": max_weight,
+            "neg_scale": neg_scale,
+            "seed": seed,
+            "n_images": batch_size,
+        }
+
+        try:
+            images = generator.generate(**generate_kwargs)
+        except RuntimeError as err:
+            if 'out of memory' in str(err):
+                generator.reload()
+            raise
         return [(img, f"Image {i+1}") for i, img in enumerate(images)], images
     except Exception as err:
         raise gr.Error(str(err))
 
 
 def add_img_from_list(i, curr_imgs, all_imgs):
+    if all_imgs is None:
+        all_imgs = []
     if i >= 0 and i < len(curr_imgs):
         all_imgs.append(curr_imgs[i])
     return all_imgs, all_imgs # return (gallery, state)
 
 def add_img(img, all_imgs):
+    if all_imgs is None:
+        all_imgs = []
     all_imgs.append(img)
     return None, all_imgs, all_imgs
 
@@ -148,7 +179,7 @@ with gr.Blocks(css=css) as demo:
         with gr.Column():
             denoising_steps = gr.Slider(1, 100, value=20, step=1, label="Sampling steps")
             guidance_scale = gr.Slider(0.0, 30.0, value=6, step=0.25, label="CFG scale")
-            batch_size = gr.Slider(1, 10, value=4, step=1, label="Batch size")
+            batch_size = gr.Slider(1, 10, value=4, step=1, label="Batch size", interactive=False)
             seed = gr.Number(-1, minimum=-1, precision=0, label="Seed")
             max_feedback_imgs = gr.Slider(0, 20, value=6, step=1, label="Max. feedback images", info="Maximum number of liked/disliked images to be used. If exceeded, only the most recent images will be used as feedback. (NOTE: large number of feedback imgs => high VRAM requirements)")
             feedback_enabled = gr.Checkbox(True, label="Enable feedback", interactive=True)
@@ -222,8 +253,8 @@ with gr.Blocks(css=css) as demo:
     liked_img_input.upload(add_img, [liked_img_input, liked_imgs], [liked_img_input, like_gallery, liked_imgs], queue=False)
     disliked_img_input.upload(add_img, [disliked_img_input, disliked_imgs], [disliked_img_input, dislike_gallery, disliked_imgs], queue=False)
 
-    clear_liked_btn.click(lambda: [
-    clear_disliked_btn.click(lambda: [
+    clear_liked_btn.click(lambda: [[], []], None, [liked_imgs, like_gallery], queue=False)
+    clear_disliked_btn.click(lambda: [[], []], None, [disliked_imgs, dislike_gallery], queue=False)
 
-demo.queue(
-demo.launch()
+demo.queue(1)
+demo.launch(debug=True)
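The substance of this fix is a reload-on-OOM pattern: the demo wraps the FABRIC generator in a small class that can delete the current pipeline, empty the CUDA cache, and rebuild it, while generate_fn catches CUDA out-of-memory RuntimeErrors, triggers that rebuild, and re-raises so the Gradio UI still reports the failure. Below is a minimal, self-contained sketch of the same pattern; DummyGenerator and build_fn are hypothetical stand-ins for fabric's AttentionBasedGenerator and are not part of this commit.

import torch
import torch.nn as nn


class DummyGenerator(nn.Module):
    """Hypothetical stand-in for fabric's AttentionBasedGenerator."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.no_grad()
    def generate(self, x):
        return self.proj(x)


class GeneratorWrapper:
    """Owns the model and can rebuild it after a CUDA out-of-memory error."""

    def __init__(self, build_fn):
        self.build_fn = build_fn
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.reload()

    def reload(self):
        # Drop the old instance before emptying the cache so its memory is actually released.
        if hasattr(self, "generator"):
            del self.generator
        if self.device == "cuda":
            torch.cuda.empty_cache()
        self.generator = self.build_fn().to(self.device)

    def generate(self, *args, **kwargs):
        return self.generator.generate(*args, **kwargs)


wrapper = GeneratorWrapper(DummyGenerator)

# Call-site handling, mirroring generate_fn in the commit: rebuild on OOM so the
# next request starts from a clean GPU, then re-raise so the error is still reported.
try:
    out = wrapper.generate(torch.randn(2, 8).to(wrapper.device))
except RuntimeError as err:
    if "out of memory" in str(err):
        wrapper.reload()
    raise

Keeping the wrapper a thin delegate and handling the OOM case at the call site, as the commit does, lets each caller decide whether to retry after the rebuild or simply surface the error.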