patrickvonplaten commited on
Commit
7302472
1 Parent(s): 459a0bd

[Safety Checker] Add Safety Checker Module

Browse files

Former-commit-id: d0c714ae4afa1c011269a956d6f260f84f77025e

Files changed (1) hide show
  1. scripts/txt2img.py +24 -1
scripts/txt2img.py CHANGED
@@ -16,12 +16,29 @@ from ldm.util import instantiate_from_config
16
  from ldm.models.diffusion.ddim import DDIMSampler
17
  from ldm.models.diffusion.plms import PLMSSampler
18
 
 
 
 
 
 
19
 
20
  def chunk(it, size):
21
  it = iter(it)
22
  return iter(lambda: tuple(islice(it, size)), ())
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def load_model_from_config(config, ckpt, verbose=False):
26
  print(f"Loading model from {ckpt}")
27
  pl_sd = torch.load(ckpt, map_location="cpu")
@@ -220,7 +237,9 @@ def main():
220
  if opt.fixed_code:
221
  start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
222
 
 
223
  precision_scope = autocast if opt.precision=="autocast" else nullcontext
 
224
  with torch.no_grad():
225
  with precision_scope("cuda"):
226
  with model.ema_scope():
@@ -269,7 +288,11 @@ def main():
269
  Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
270
  grid_count += 1
271
 
272
- toc = time.time()
 
 
 
 
273
 
274
  print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
275
  f" \nEnjoy.")
 
16
  from ldm.models.diffusion.ddim import DDIMSampler
17
  from ldm.models.diffusion.plms import PLMSSampler
18
 
19
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
20
+ from transformers import AutoFeatureExtractor
21
+
22
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
23
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-v-1-3", use_auth_token=True)
24
 
25
  def chunk(it, size):
26
  it = iter(it)
27
  return iter(lambda: tuple(islice(it, size)), ())
28
 
29
 
30
+ def numpy_to_pil(images):
31
+ """
32
+ Convert a numpy image or a batch of images to a PIL image.
33
+ """
34
+ if images.ndim == 3:
35
+ images = images[None, ...]
36
+ images = (images * 255).round().astype("uint8")
37
+ pil_images = [Image.fromarray(image) for image in images]
38
+
39
+ return pil_images
40
+
41
+
42
  def load_model_from_config(config, ckpt, verbose=False):
43
  print(f"Loading model from {ckpt}")
44
  pl_sd = torch.load(ckpt, map_location="cpu")
 
237
  if opt.fixed_code:
238
  start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
239
 
240
+ print("start code", start_code.abs().sum())
241
  precision_scope = autocast if opt.precision=="autocast" else nullcontext
242
+ precision_scope = nullcontext
243
  with torch.no_grad():
244
  with precision_scope("cuda"):
245
  with model.ema_scope():
 
288
  Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
289
  grid_count += 1
290
 
291
+ image = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
292
+
293
+ # run safety checker
294
+ safety_checker_input = pipe.feature_extractor(numpy_to_pil(image), return_tensors="pt")
295
+ image, has_nsfw_concept = pipe.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
296
 
297
  print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
298
  f" \nEnjoy.")