StableDiffusionModel / image_generator.py
bala1802's picture
Upload 6 files
1975737 verified
raw
history blame
No virus
3.38 kB
import torch
from tqdm.auto import tqdm
from diffusers import LMSDiscreteScheduler
import config
def construct_text_embeddings(pipe, prompt):
    """Encode ``prompt`` plus an equal number of empty prompts.

    Args:
        pipe: Object exposing ``tokenizer`` and ``text_encoder`` attributes
            (presumably a diffusers Stable Diffusion pipeline -- confirm
            against the caller).
        prompt: A single prompt string or a sequence of prompt strings. A
            single string is repeated ``config.BATCH_SIZE`` times so the
            prompt half has one row per latent built by
            ``initialize_latent``.

    Returns:
        A tensor with the empty-prompt embeddings first and the prompt
        embeddings second, concatenated along the first dimension.
    """
    # A bare string tokenizes to one row. Repeat it so the embeddings line up
    # with the ``config.BATCH_SIZE`` latents that ``run_prediction`` duplicates
    # and later splits in half with ``chunk(2)``.
    if isinstance(prompt, str):
        prompts = [prompt] * config.BATCH_SIZE
    else:
        prompts = list(prompt)

    tokenizer = pipe.tokenizer
    text_input = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    # Size the empty-prompt half from what was actually tokenized so both
    # halves always have the same number of rows and the same padded length.
    num_prompts = text_input.input_ids.shape[0]
    uncond_input = tokenizer(
        [""] * num_prompts,
        padding="max_length",
        max_length=text_input.input_ids.shape[-1],
        return_tensors="pt",
    )

    # Encoding is inference only; no autograd graph is needed here.
    with torch.no_grad():
        # ``[0]`` takes the first element of the encoder output
        # (presumably the last hidden state -- confirm against the encoder).
        text_input_embeddings = pipe.text_encoder(
            text_input.input_ids.to(config.DEVICE)
        )[0]
        uncond_embeddings = pipe.text_encoder(
            uncond_input.input_ids.to(config.DEVICE)
        )[0]

    # Order matters: ``run_prediction`` unpacks the UNet output as
    # (unconditional, text) after ``chunk(2)``.
    return torch.cat([uncond_embeddings, text_input_embeddings])
def initialize_latent(seed_number, pipe, scheduler):
    """Create the seeded starting noise for the sampling loop.

    Args:
        seed_number: Seed passed to ``torch.manual_seed``. Note that this
            reseeds torch's default generator as a side effect.
        pipe: Object whose ``unet.config.in_channels`` gives the number of
            latent channels.
        scheduler: Object exposing ``init_noise_sigma``, used to scale the
            noise.

    Returns:
        A float16 tensor on ``config.DEVICE`` with shape
        ``(BATCH_SIZE, in_channels, HEIGHT // 8, WIDTH // 8)``, multiplied by
        ``scheduler.init_noise_sigma``.
    """
    rng = torch.manual_seed(seed_number)

    # NOTE(review): the factor 8 is presumably the VAE's spatial
    # downsampling ratio -- confirm against the model config.
    latent_shape = (
        config.BATCH_SIZE,
        pipe.unet.config.in_channels,
        config.HEIGHT // 8,
        config.WIDTH // 8,
    )

    # Sample with the seeded generator first, then cast and move, so the
    # draw itself does not depend on the target device.
    noise = torch.randn(latent_shape, generator=rng)
    noise = noise.to(torch.float16)
    noise = noise.to(config.DEVICE)

    return noise * scheduler.init_noise_sigma
def _apply_loss_guidance(pipe, latent, noise_pred, sigma, loss_function):
    """Return ``latent`` nudged against the gradient of ``loss_function``."""
    # ``torch.autograd.grad`` needs a recorded graph. Enable grad explicitly
    # so guidance still works when the caller runs generation under
    # ``torch.no_grad()``.
    with torch.enable_grad():
        latent = latent.detach().requires_grad_()

        # ``noise_pred`` was computed without grad, so the gradient reaches
        # ``latent`` only through this subtraction.
        latent_x0 = latent - sigma * noise_pred

        # NOTE(review): 1 / 0.18215 presumably undoes the VAE latent scaling,
        # and ``/ 2 + 0.5`` assumes the decoder emits values in (-1, 1) --
        # confirm against the VAE in use.
        denoised_images = pipe.vae.decode((1 / 0.18215) * latent_x0).sample / 2 + 0.5  # range(0,1)

        loss = loss_function(denoised_images) * config.LOSS_SCALE
        print(f"loss {loss}")

        cond_grad = torch.autograd.grad(loss, latent)[0]

    # Detach so the returned latent carries no autograd history forward.
    return latent.detach() - cond_grad * sigma**2


def run_prediction(pipe, text_embeddings, scheduler, latent, loss_function=None):
    """Run the denoising loop over ``scheduler.timesteps``.

    Args:
        pipe: Object exposing ``unet`` and, when guidance is used, ``vae``.
        text_embeddings: Embeddings passed to the UNet as
            ``encoder_hidden_states``. The unconditional rows must come
            first and the text rows second, as
            ``construct_text_embeddings`` returns them.
        scheduler: Scheduler exposing ``timesteps``, ``sigmas``,
            ``scale_model_input`` and ``step``.
        latent: Starting latent tensor.
        loss_function: Optional callable taking the decoded images and
            returning a scalar loss. When truthy, the latent is adjusted by
            the loss gradient on every fifth step.

    Returns:
        The latent after the last scheduler step.
    """
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Duplicate the latent so one UNet call yields both the
        # unconditional and the text-conditioned prediction.
        latent_model_input = torch.cat([latent] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = pipe.unet(
                latent_model_input.to(torch.float16),
                t,
                encoder_hidden_states=text_embeddings,
            )["sample"]

        # Move from the unconditional prediction toward the text one,
        # scaled by ``config.GUIDANCE_SCALE``.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + config.GUIDANCE_SCALE * (
            noise_pred_text - noise_pred_uncond
        )

        # Extra loss-based guidance runs on every fifth step only.
        if loss_function and i % 5 == 0:
            latent = _apply_loss_guidance(
                pipe, latent, noise_pred, sigma, loss_function
            )

        latent = scheduler.step(noise_pred, t, latent).prev_sample

    return latent
def generate_images(pipe, seed_number, prompt, loss_function=None):
    """Run one full sampling pass and return the final latent.

    Builds a fresh ``LMSDiscreteScheduler``, encodes ``prompt``, creates the
    seeded starting latent and hands everything to ``run_prediction``.

    Args:
        pipe: Pipeline-like object forwarded to the helper functions.
        seed_number: Seed for the starting latent.
        prompt: Prompt forwarded to ``construct_text_embeddings``.
        loss_function: Optional guidance loss forwarded to
            ``run_prediction``.

    Returns:
        The latent returned by ``run_prediction``. It is not decoded to
        images here.
    """
    scheduler_settings = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
    }
    scheduler = LMSDiscreteScheduler(**scheduler_settings)
    scheduler.set_timesteps(config.NUM_INFERENCE_STEPS)
    # Timesteps are converted to float32 before being fed to the UNet.
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)

    # Embeddings are built before the latent, so the text encoder runs
    # before ``torch.manual_seed`` is called.
    prompt_embeddings = construct_text_embeddings(pipe=pipe, prompt=prompt)
    starting_latent = initialize_latent(
        seed_number=seed_number, pipe=pipe, scheduler=scheduler
    )

    return run_prediction(
        pipe=pipe,
        text_embeddings=prompt_embeddings,
        scheduler=scheduler,
        latent=starting_latent,
        loss_function=loss_function,
    )