Spaces:
Sleeping
Sleeping
import torch | |
from tqdm.auto import tqdm | |
from diffusers import LMSDiscreteScheduler | |
import config | |
def construct_text_embeddings(pipe, prompt):
    """Encode ``prompt`` plus matching empty prompts for classifier-free guidance.

    Args:
        pipe: Object exposing ``tokenizer`` and ``text_encoder`` (presumably a
            diffusers Stable Diffusion pipeline -- confirm against callers).
        prompt: Prompt string or list of prompt strings, passed directly to
            ``pipe.tokenizer``.

    Returns:
        The unconditional (empty-prompt) embeddings followed by the prompt
        embeddings, concatenated along dim 0. Both halves have the same batch
        size, so a model output batched the same way can be split back into
        ``(uncond, text)`` with ``chunk(2)``.
    """
    text_input = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Size the unconditional batch from the tokenized prompt rather than a
    # fixed config value: the [uncond, text] halves must be equal in size, or
    # a later chunk(2) split lands on the wrong boundary.
    batch_size = text_input.input_ids.shape[0]
    uncond_input = pipe.tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=text_input.input_ids.shape[-1],
        return_tensors="pt",
    )
    with torch.no_grad():
        text_input_embeddings = pipe.text_encoder(
            text_input.input_ids.to(config.DEVICE)
        )[0]
        uncond_embeddings = pipe.text_encoder(
            uncond_input.input_ids.to(config.DEVICE)
        )[0]
    return torch.cat([uncond_embeddings, text_input_embeddings])
def initialize_latent(seed_number, pipe, scheduler):
    """Draw the seeded starting latent for the denoising loop.

    The noise is sampled with the generator returned by
    ``torch.manual_seed(seed_number)``, which also reseeds torch's global RNG.
    The noise is then cast to float16, moved to ``config.DEVICE``, and
    multiplied by ``scheduler.init_noise_sigma``.

    Args:
        seed_number: Seed used for the noise.
        pipe: Pipeline whose ``unet.config.in_channels`` sets the channel count.
        scheduler: Scheduler providing ``init_noise_sigma``.

    Returns:
        A tensor of shape
        ``(config.BATCH_SIZE, in_channels, config.HEIGHT // 8, config.WIDTH // 8)``
        (float16, assuming ``init_noise_sigma`` is a scalar -- confirm).
    """
    rng = torch.manual_seed(seed_number)
    shape = (
        config.BATCH_SIZE,
        pipe.unet.config.in_channels,
        config.HEIGHT // 8,
        config.WIDTH // 8,
    )
    noise = torch.randn(shape, generator=rng).to(torch.float16).to(config.DEVICE)
    return noise * scheduler.init_noise_sigma
def run_prediction(pipe, text_embeddings, scheduler, latent, loss_function=None,
                   *, guidance_interval=5):
    """Run the classifier-free-guided denoising loop, optionally loss-guided.

    Args:
        pipe: Object exposing ``unet`` and ``vae``.
        text_embeddings: ``[uncond, text]`` embeddings concatenated on dim 0,
            e.g. as built by ``construct_text_embeddings``.
        scheduler: Scheduler whose timesteps are already set. It must expose
            ``sigmas`` indexable by step position, ``scale_model_input`` and
            ``step``.
        latent: Starting latent.
        loss_function: Optional callable applied to the decoded images, which
            are shifted toward [0, 1] by ``/ 2 + 0.5``. Its result is multiplied
            by ``config.LOSS_SCALE`` and differentiated w.r.t. the latent. The
            latent is then moved against that gradient, scaled by ``sigma**2``.
        guidance_interval: Loss guidance runs on steps where
            ``i % guidance_interval == 0``. Defaults to 5. Keyword-only.

    Returns:
        The latent after the final ``scheduler.step``.

    Raises:
        ValueError: If ``loss_function`` is given and ``guidance_interval`` < 1.
    """
    if loss_function is not None and guidance_interval < 1:
        raise ValueError(
            f"guidance_interval must be >= 1, got {guidance_interval}"
        )

    timesteps = scheduler.timesteps
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        # Duplicate the latent so a single UNet call covers both the
        # unconditional and the text-conditioned halves of text_embeddings.
        latent_model_input = torch.cat([latent] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = pipe.unet(
                latent_model_input.to(torch.float16),
                t,
                encoder_hidden_states=text_embeddings,
            )["sample"]
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + config.GUIDANCE_SCALE * (
            noise_pred_text - noise_pred_uncond
        )

        # Compare against None: a callable object whose __len__/__bool__ is
        # falsy would otherwise silently disable guidance.
        if loss_function is not None and i % guidance_interval == 0:
            latent = latent.detach().requires_grad_()
            # Estimate of the denoised latent. noise_pred was computed without
            # grad, so the gradient flows only through `latent` itself.
            latent_x0 = latent - sigma * noise_pred
            # 1/0.18215 presumably undoes the SD VAE latent scaling factor
            # (confirm against the VAE config); /2 + 0.5 maps to range(0,1).
            denoised_images = (
                pipe.vae.decode((1 / 0.18215) * latent_x0).sample / 2 + 0.5
            )
            loss = loss_function(denoised_images) * config.LOSS_SCALE
            print(f"loss {loss}")
            cond_grad = torch.autograd.grad(loss, latent)[0]
            latent = latent.detach() - cond_grad * sigma**2
        latent = scheduler.step(noise_pred, t, latent).prev_sample
    return latent
def generate_images(pipe, seed_number, prompt, loss_function=None):
    """Produce final latents for ``prompt`` using a freshly built LMS scheduler.

    Despite the name, the return value is the latent produced by
    ``run_prediction``; nothing in this function decodes it to pixels.

    Args:
        pipe: Pipeline passed through to the embedding, init and sampling helpers.
        seed_number: Seed for the initial noise.
        prompt: Prompt (or prompts) to condition on.
        loss_function: Optional guidance loss forwarded to ``run_prediction``.

    Returns:
        The denoised latent after ``config.NUM_INFERENCE_STEPS`` steps.
    """
    lms_settings = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
    }
    scheduler = LMSDiscreteScheduler(**lms_settings)
    scheduler.set_timesteps(config.NUM_INFERENCE_STEPS)
    # Store timesteps as float32 explicitly.
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)

    embeddings = construct_text_embeddings(pipe=pipe, prompt=prompt)
    start_latent = initialize_latent(
        seed_number=seed_number, pipe=pipe, scheduler=scheduler
    )
    return run_prediction(
        pipe=pipe,
        text_embeddings=embeddings,
        scheduler=scheduler,
        latent=start_latent,
        loss_function=loss_function,
    )