import numpy as np
import torch
from diffusers import DiffusionPipeline, LMSDiscreteScheduler
from PIL import Image
from tqdm.auto import tqdm

# configurations
height, width = 512, 512
guidance_scale = 8
custom_loss_scale = 200
batch_size = 1
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
pipe = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.float32,
).to(torch_device)

# SD-concept placeholder tokens; assumed to be the tokens registered by the
# concept repos loaded below (the angle-bracketed tokens were stripped from
# the original source)
sdconcepts = ['<morino-hon>', '<space-style>', '<tesla-bot>', '<midjourney-style>', '<hanfu-anime>']

# Load SD concepts (textual inversion embeddings)
pipe.load_textual_inversion("sd-concepts-library/morino-hon-style")
pipe.load_textual_inversion("sd-concepts-library/space-style")
pipe.load_textual_inversion("sd-concepts-library/tesla-bot")
pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
pipe.load_textual_inversion("sd-concepts-library/hanfu-anime-style")

# define seeds, one per concept
seed_list = [1, 2, 3, 4, 5]


def custom_loss(images):
    # Gradient loss: mean absolute difference between neighbouring pixels
    # along the width (x) and height (y) axes
    gradient_x = torch.abs(images[:, :, :, :-1] - images[:, :, :, 1:]).mean()
    gradient_y = torch.abs(images[:, :, :-1, :] - images[:, :, 1:, :]).mean()
    error = gradient_x + gradient_y

    # Variational loss (alternative):
    # diff_x = torch.abs(images[:, :, :, :-1] - images[:, :, :, 1:])
    # diff_y = torch.abs(images[:, :, :-1, :] - images[:, :, 1:, :])
    # error = diff_x.mean() + diff_y.mean()
    return error


def latents_to_pil(latents):
    # batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents  # undo the SD latent scaling factor
    with torch.no_grad():
        image = pipe.vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)  # map from [-1, 1] to [0, 1]
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images


def generate_latents(prompts, num_inference_steps, seed_nums, loss_apply=False):
    generator = torch.manual_seed(seed_nums)

    # scheduler
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)  # MPS does not support float64 timesteps

    # text embeddings of the prompt
    text_input = pipe.tokenizer(
        prompts,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0]

    # unconditional (empty-prompt) embeddings for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])  # (2, 77, 768)

    # random initial latent, scaled to the scheduler's starting noise level
    latents = torch.randn(
        (batch_size, pipe.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # duplicate the latents so one UNet pass covers both the conditional
        # and unconditional branches
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if loss_apply and i % 5 == 0:
            latents = latents.detach().requires_grad_()
            # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample  # this line does not work
            latents_x0 = latents - sigma * noise_pred  # predict the fully denoised latent

            # use the VAE to decode the predicted denoised image
            denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)
            loss = custom_loss(denoised_images) * custom_loss_scale
            print(f"Custom gradient loss {loss}")

            # nudge the latents against the loss gradient
            cond_grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - cond_grad * sigma**2

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents


# Function to convert PIL images to NumPy arrays
def pil_to_np(image):
    return np.array(image)


def generate_gradio_images(prompt, num_inference_steps, loss_flag=False):
    # generate one image per (seed, SD concept) pair, optionally applying the loss
    latents_list = []
    for seed_no, sd in zip(seed_list, sdconcepts):
        prompts = [f"{prompt} {sd}"]
        latents = generate_latents(prompts, num_inference_steps, seed_no, loss_apply=loss_flag)
        latents_list.append(latents)
    # stack all latents and decode them in one batch
    latents_list = torch.vstack(latents_list)
    images = latents_to_pil(latents_list)
    return images
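

# --- Minimal usage sketch (an assumption, not part of the original script) ---
# Generates the five concept variants for a single prompt with the custom
# gradient loss applied, then saves each to disk. The prompt, step count,
# and filenames below are illustrative only.
if __name__ == "__main__":
    images = generate_gradio_images("a serene mountain lake", num_inference_steps=30, loss_flag=True)
    for idx, (img, concept) in enumerate(zip(images, sdconcepts)):
        img.save(f"output_{idx}_{concept.strip('<>')}.png")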