import math import os from typing import List, Optional, Union import numpy as np import torch from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig from PIL import Image from torch import autocast from sgm.util import append_dims class WatermarkEmbedder: def __init__(self, watermark): self.watermark = watermark self.num_bits = len(WATERMARK_BITS) self.encoder = WatermarkEncoder() self.encoder.set_watermark("bits", self.watermark) def __call__(self, image: torch.Tensor) -> torch.Tensor: """ Adds a predefined watermark to the input image Args: image: ([N,] B, RGB, H, W) in range [0, 1] Returns: same as input but watermarked """ squeeze = len(image.shape) == 4 if squeeze: image = image[None, ...] n = image.shape[0] image_np = rearrange( (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" ).numpy()[:, :, :, ::-1] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] # watermarking libary expects input as cv2 BGR format for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") image = torch.from_numpy( rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) ).to(image.device) image = torch.clamp(image / 255, min=0.0, max=1.0) if squeeze: image = image[0] return image # A fixed 48-bit message that was choosen at random # WATERMARK_MESSAGE = 0xB3EC907BB19E WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] embed_watermark = WatermarkEmbedder(WATERMARK_BITS) def get_unique_embedder_keys_from_conditioner(conditioner): return list({x.input_key for x in conditioner.embedders}) def perform_save_locally(save_path, samples): os.makedirs(os.path.join(save_path), exist_ok=True) base_count = len(os.listdir(os.path.join(save_path))) samples = embed_watermark(samples) for sample in samples: sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(sample.astype(np.uint8)).save( os.path.join(save_path, f"{base_count:09}.png") ) base_count += 1 class Img2ImgDiscretizationWrapper: """ wraps a discretizer, and prunes the sigmas params: strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) """ def __init__(self, discretization, strength: float = 1.0): self.discretization = discretization self.strength = strength assert 0.0 <= self.strength <= 1.0 def __call__(self, *args, **kwargs): # sigmas start large first, and decrease then sigmas = self.discretization(*args, **kwargs) print(f"sigmas after discretization, before pruning img2img: ", sigmas) sigmas = torch.flip(sigmas, (0,)) sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] print("prune index:", max(int(self.strength * len(sigmas)), 1)) sigmas = torch.flip(sigmas, (0,)) print(f"sigmas after pruning: ", sigmas) return sigmas def do_sample( model, sampler, value_dict, num_samples, H, W, C, F, force_uc_zero_embeddings: Optional[List] = None, batch2model_input: Optional[List] = None, return_latents=False, filter=None, device="cuda", ): if force_uc_zero_embeddings is None: force_uc_zero_embeddings = [] if batch2model_input is None: batch2model_input = [] with torch.no_grad(): with autocast(device) as precision_scope: with model.ema_scope(): num_samples = [num_samples] batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples, ) for key in batch: if isinstance(batch[key], torch.Tensor): print(key, batch[key].shape) elif isinstance(batch[key], list): print(key, [len(l) for l in batch[key]]) else: print(key, batch[key]) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings, ) for k in c: if not k == "crossattn": c[k], uc[k] = map( lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) ) additional_model_inputs = {} for k in batch2model_input: additional_model_inputs[k] = batch[k] shape = (math.prod(num_samples), C, H // F, W // F) randn = torch.randn(shape).to(device) def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) samples_z = sampler(denoiser, randn, cond=c, uc=uc) samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) if return_latents: return samples, samples_z return samples def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): # Hardcoded demo setups; might undergo some changes in the future batch = {} batch_uc = {} for key in keys: if key == "txt": batch["txt"] = ( np.repeat([value_dict["prompt"]], repeats=math.prod(N)) .reshape(N) .tolist() ) batch_uc["txt"] = ( np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) .reshape(N) .tolist() ) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) .to(device) .repeat(*N, 1) ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( torch.tensor( [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] ) .to(device) .repeat(*N, 1) ) elif key == "aesthetic_score": batch["aesthetic_score"] = ( torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) ) batch_uc["aesthetic_score"] = ( torch.tensor([value_dict["negative_aesthetic_score"]]) .to(device) .repeat(*N, 1) ) elif key == "target_size_as_tuple": batch["target_size_as_tuple"] = ( torch.tensor([value_dict["target_height"], value_dict["target_width"]]) .to(device) .repeat(*N, 1) ) else: batch[key] = value_dict[key] for key in batch.keys(): if key not in batch_uc and isinstance(batch[key], torch.Tensor): batch_uc[key] = torch.clone(batch[key]) return batch, batch_uc def get_input_image_tensor(image: Image.Image, device="cuda"): w, h = image.size print(f"loaded input image of size ({w}, {h})") width, height = map( lambda x: x - x % 64, (w, h) ) # resize to integer multiple of 64 image = image.resize((width, height)) image_array = np.array(image.convert("RGB")) image_array = image_array[None].transpose(0, 3, 1, 2) image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 return image_tensor.to(device) def do_img2img( img, model, sampler, value_dict, num_samples, force_uc_zero_embeddings=[], additional_kwargs={}, offset_noise_level: float = 0.0, return_latents=False, skip_encode=False, filter=None, device="cuda", ): with torch.no_grad(): with autocast(device) as precision_scope: with model.ema_scope(): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, [num_samples], ) c, uc = model.conditioner.get_unconditional_conditioning( batch, batch_uc=batch_uc, force_uc_zero_embeddings=force_uc_zero_embeddings, ) for k in c: c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) for k in additional_kwargs: c[k] = uc[k] = additional_kwargs[k] if skip_encode: z = img else: z = model.encode_first_stage(img) noise = torch.randn_like(z) sigmas = sampler.discretization(sampler.num_steps) sigma = sigmas[0].to(z.device) if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( torch.randn(z.shape[0], device=z.device), z.ndim ) noised_z = z + noise * append_dims(sigma, z.ndim) noised_z = noised_z / torch.sqrt( 1.0 + sigmas[0] ** 2.0 ) # Note: hardcoded to DDPM-like scaling. need to generalize later. def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) if filter is not None: samples = filter(samples) if return_latents: return samples, samples_z return samples