Can gradients flow through this pipeline

#19
by gchen019 - opened

In my implementation, I'm trying to optimise the prompt embeddings through a training loop so that the reconstructed image latent from the inpainting model is as similar as possible to the original image latent. But I'm getting the error "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". My assumption is that the gradient isn't flowing through the pipeline's '__call__' function; output_latent.requires_grad is also False. I'm not sure whether there's something wrong with my implementation or something in the pipe that's preventing gradients from flowing. Please help, thank you.

def emb_opt(image, mask, emb, pooled_emb, pipe, true_latent, seed=42, num_inference_steps = 20, lr = 2e-3, epochs = 20, dim = 1024):

    emb = emb.to('cuda')
    emb.requires_grad_(True)

    optimizer = torch.optim.Adam([emb], lr = lr)
    criteria = torch.nn.MSELoss()

    for i in range(epochs):

        print(f'epoch: ({i+1}/{epochs})')

        strength = 0.999

        optimizer.zero_grad()

        output_latent = pipe(generator = torch.Generator(device="cuda").manual_seed(seed),
                             prompt_embeds = emb,
                             pooled_prompt_embeds = pooled_emb,
                             image = image,
                             mask_image = mask,
                             strength = strength,
                             output_type = "latent",
                             num_inference_steps = num_inference_steps).images[0].to(dtype=torch.float32)

        print(output_latent.requires_grad)

        loss = criteria(output_latent, true_latent)

        print(f'loss: {loss.item()}')

        loss.backward()

        print(f'emb.grad: {emb.grad}')

        optimizer.step()

        if (i+1)%10==0:
            output.clear()

    return emb
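
For reference, the same error can be reproduced without the pipeline at all. This is a minimal sketch, assuming the forward pass runs under torch.no_grad() (diffusers pipelines typically decorate '__call__' this way, but that is worth checking in the source of the installed version). The function frozen_forward and the tensor shapes are hypothetical stand-ins, not part of the pipeline.

```python
import torch

# Hypothetical stand-ins for the prompt embedding and a model weight.
emb = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(8, 8)


@torch.no_grad()
def frozen_forward(x):
    # Stand-in for a pipeline call: autograd records nothing in here,
    # even though the input requires grad.
    return x @ weight


out = frozen_forward(emb)
print(out.requires_grad)  # the output is detached from emb

loss = torch.nn.functional.mse_loss(out, torch.zeros_like(out))
try:
    loss.backward()
    error_message = None
except RuntimeError as err:
    # Same message as in the post above.
    error_message = str(err)
print(error_message)
```

If the output of the real pipeline behaves like out here, the training loop itself is fine and the graph is being cut inside the call.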
🧨Diffusers org

Could you open an issue on the diffusers GitHub repository and post the code you're using for generation, along with information about your environment from diffusers-cli env?

gchen019 changed discussion title from Non deterministic Generation to Can gradients flow through this pipeline

Hi, before I open an issue, could you kindly check whether it's an issue with my implementation or with the pipeline?
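
One way to check this locally is sketched below. It assumes the pipeline's '__call__' is wrapped with the torch.no_grad() decorator, which keeps the undecorated function reachable as __wrapped__ (the decorator uses functools.wraps). ToyPipeline is a hypothetical stand-in, not the inpainting pipeline, and this has not been verified against the real pipeline, where other steps may also break the graph and backpropagating through all denoising steps may need a lot of memory.

```python
import torch


class ToyPipeline:
    """Hypothetical stand-in for a pipeline whose __call__ runs under no_grad."""

    def __init__(self):
        self.weight = torch.randn(8, 8)

    @torch.no_grad()
    def __call__(self, prompt_embeds):
        return prompt_embeds @ self.weight


pipe = ToyPipeline()
emb = torch.randn(4, 8, requires_grad=True)

# Normal call: the decorator disables gradient tracking.
blocked = pipe(prompt_embeds=emb)
print(blocked.requires_grad)

# The original function is kept on __wrapped__; it is unbound,
# so the pipeline instance has to be passed explicitly.
raw_call = pipe.__call__.__wrapped__
tracked = raw_call(pipe, prompt_embeds=emb)
print(tracked.requires_grad)

loss = tracked.pow(2).mean()
loss.backward()
print(emb.grad.shape)
```

If requires_grad flips to True with the unwrapped call on the real pipeline, the decorator is what stops the gradients rather than the training loop.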
