Spaces: Running on Zero
add NSFW checker and GPU mode
Browse files
- app.py +44 -28
- data/nsfw.jpg +0 -0
- utils/pipeline.py +15 -1
app.py
CHANGED
@@ -61,7 +61,13 @@ class GlobalText:
|
|
61 |
self.pipeline = None
|
62 |
self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
63 |
self.lora_model_state_dict = {}
|
64 |
-
self.device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def init_source_image_path(self, source_path):
|
67 |
self.source_paths = sorted(glob(os.path.join(source_path, '*')))
|
@@ -83,9 +89,9 @@ class GlobalText:
|
|
83 |
|
84 |
self.scheduler = 'LCM'
|
85 |
scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
86 |
-
self.pipeline = ZePoPipeline.from_pretrained(model_path,scheduler=scheduler,torch_dtype=torch.float16,)
|
87 |
-
|
88 |
-
|
89 |
time_end = datetime.now()
|
90 |
print(f'Load {model_path} successful in {time_end-time_start}')
|
91 |
return gr.Dropdown()
|
@@ -171,7 +177,7 @@ class GlobalText:
|
|
171 |
de_bug=de_bug,)
|
172 |
|
173 |
time_begin = datetime.now()
|
174 |
-
|
175 |
negative_prompt=negative_prompt_textbox,
|
176 |
image=source,
|
177 |
style=style,
|
@@ -183,7 +189,16 @@ class GlobalText:
|
|
183 |
fix_step_index=co_feat_step,
|
184 |
de_bug = de_bug,
|
185 |
callback = None
|
186 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
time_end = datetime.now()
|
188 |
print('generate one image with time {}'.format(time_end-time_begin))
|
189 |
|
@@ -191,18 +206,19 @@ class GlobalText:
|
|
191 |
|
192 |
|
193 |
save_file_path = os.path.join(self.savedir, save_file_name)
|
194 |
-
|
195 |
save_image(torch.tensor(generate_image).permute(0, 3, 1, 2), save_file_path, nrow=3, padding=0)
|
196 |
save_image(torch.tensor(generate_image[2:]).permute(0, 3, 1, 2), os.path.join(self.savedir_sample, save_file_name), nrow=3, padding=0)
|
197 |
self.init_results_image_path()
|
198 |
-
return [
|
199 |
-
generate_image[0],
|
200 |
-
generate_image[1],
|
201 |
-
generate_image[2],
|
202 |
-
self.init_results_image_path()
|
203 |
-
]
|
204 |
-
|
205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
global_text = GlobalText()
|
207 |
|
208 |
|
@@ -309,23 +325,23 @@ def ui():
|
|
309 |
|
310 |
style_gallery_index.change(fn=update_style_list, inputs=[style_gallery_index], outputs=[style_image_gallery])
|
311 |
|
312 |
-
with gr.Tab("Results Gallery"):
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
|
320 |
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
|
328 |
-
|
329 |
|
330 |
|
331 |
|
|
|
61 |
self.pipeline = None
|
62 |
self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
63 |
self.lora_model_state_dict = {}
|
64 |
+
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
65 |
+
|
66 |
+
self.nsfw_image = Image.open('./data/nsfw.jpg') # to float in [0,1]
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
|
72 |
def init_source_image_path(self, source_path):
|
73 |
self.source_paths = sorted(glob(os.path.join(source_path, '*')))
|
|
|
89 |
|
90 |
self.scheduler = 'LCM'
|
91 |
scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
92 |
+
self.pipeline = ZePoPipeline.from_pretrained(model_path,scheduler=scheduler,torch_dtype=torch.float16,).to('cuda')
|
93 |
+
if is_xformers:
|
94 |
+
self.pipeline.enable_xformers_memory_efficient_attention()
|
95 |
time_end = datetime.now()
|
96 |
print(f'Load {model_path} successful in {time_end-time_start}')
|
97 |
return gr.Dropdown()
|
|
|
177 |
de_bug=de_bug,)
|
178 |
|
179 |
time_begin = datetime.now()
|
180 |
+
results = model(prompt=prompts,
|
181 |
negative_prompt=negative_prompt_textbox,
|
182 |
image=source,
|
183 |
style=style,
|
|
|
189 |
fix_step_index=co_feat_step,
|
190 |
de_bug = de_bug,
|
191 |
callback = None
|
192 |
+
)
|
193 |
+
generate_image = results.images
|
194 |
+
|
195 |
+
|
196 |
+
for idx, has_nsfw_concept in enumerate(results.nsfw_content_detected):
|
197 |
+
if has_nsfw_concept:
|
198 |
+
generate_image[idx] = np.array(self.nsfw_image.resize((height_slider,width_slider))).astype(np.float32) / 255.0
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
time_end = datetime.now()
|
203 |
print('generate one image with time {}'.format(time_end-time_begin))
|
204 |
|
|
|
206 |
|
207 |
|
208 |
save_file_path = os.path.join(self.savedir, save_file_name)
|
209 |
+
|
210 |
save_image(torch.tensor(generate_image).permute(0, 3, 1, 2), save_file_path, nrow=3, padding=0)
|
211 |
save_image(torch.tensor(generate_image[2:]).permute(0, 3, 1, 2), os.path.join(self.savedir_sample, save_file_name), nrow=3, padding=0)
|
212 |
self.init_results_image_path()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
+
return [
|
215 |
+
generate_image[0],
|
216 |
+
generate_image[1],
|
217 |
+
generate_image[2],
|
218 |
+
self.init_results_image_path()
|
219 |
+
]
|
220 |
+
|
221 |
+
|
222 |
global_text = GlobalText()
|
223 |
|
224 |
|
|
|
325 |
|
326 |
style_gallery_index.change(fn=update_style_list, inputs=[style_gallery_index], outputs=[style_image_gallery])
|
327 |
|
328 |
+
# with gr.Tab("Results Gallery"):
|
329 |
+
# with gr.Row():
|
330 |
+
# refresh_results_list_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
331 |
+
# results_gallery_index = gr.Slider(label="Index", value=0, minimum=0, maximum=50, step=1)
|
332 |
+
# num_gallery_images = 12
|
333 |
+
# results_image_gallery = gr.Gallery(value=[], columns=4, label="style Image List")
|
334 |
+
# refresh_results_list_button.click(fn=global_text.init_results_image_path, inputs=[], outputs=[results_image_gallery])
|
335 |
|
336 |
|
337 |
+
# def update_results_list(index):
|
338 |
+
# if int(index) < 0:
|
339 |
+
# index = 0
|
340 |
+
# if int(index) > global_text.max_results_index:
|
341 |
+
# index = global_text.max_results_index
|
342 |
+
# return global_text.results_paths[int(index)*num_gallery_images:(int(index)+1)*num_gallery_images]
|
343 |
|
344 |
+
# results_gallery_index.change(fn=update_results_list, inputs=[results_gallery_index], outputs=[style_image_gallery])
|
345 |
|
346 |
|
347 |
|
data/nsfw.jpg
ADDED
utils/pipeline.py
CHANGED
@@ -157,6 +157,20 @@ class ZePoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMix
|
|
157 |
extra_step_kwargs["generator"] = generator
|
158 |
return extra_step_kwargs
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
162 |
def decode_latents(self, latents):
|
@@ -416,7 +430,7 @@ class ZePoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMix
|
|
416 |
# 9. Post-processing
|
417 |
if not output_type == "latent":
|
418 |
image = self.vae.decode(pred_x0 / self.vae.config.scaling_factor, return_dict=False)[0]
|
419 |
-
has_nsfw_concept =
|
420 |
else:
|
421 |
image = pred_x0
|
422 |
has_nsfw_concept = None
|
|
|
157 |
extra_step_kwargs["generator"] = generator
|
158 |
return extra_step_kwargs
|
159 |
|
160 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
161 |
+
def run_safety_checker(self, image, device, dtype):
|
162 |
+
if self.safety_checker is None:
|
163 |
+
has_nsfw_concept = None
|
164 |
+
else:
|
165 |
+
if torch.is_tensor(image):
|
166 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
167 |
+
else:
|
168 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
169 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
170 |
+
image, has_nsfw_concept = self.safety_checker(
|
171 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
172 |
+
)
|
173 |
+
return image, has_nsfw_concept
|
174 |
|
175 |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
176 |
def decode_latents(self, latents):
|
|
|
430 |
# 9. Post-processing
|
431 |
if not output_type == "latent":
|
432 |
image = self.vae.decode(pred_x0 / self.vae.config.scaling_factor, return_dict=False)[0]
|
433 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
434 |
else:
|
435 |
image = pred_x0
|
436 |
has_nsfw_concept = None
|