Accessing the latent representation of the image?

#164
by Kelmeilia

I have been trying to access the latent representation in a text-to-image pipeline; I use AutoPipelineForText2Image on CUDA. My goal is to blend the latent representations of two Stable Diffusion images and see what comes out of the average.
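For the blending step itself, here is a minimal sketch of two ways to combine two latent tensors of the same shape. The first is a plain weighted average, where `t=0.5` gives the mean. The second is spherical interpolation (slerp), which people often use for noise-like latents, because the plain average of two independent Gaussian tensors has a noticeably smaller norm.

The function names are made up for illustration. The helpers only use `*`, `+` and `.sum()`, so they should work on NumPy arrays and PyTorch tensors alike. The random arrays stand in for real latents with shape `(1, 4, 64, 64)`, which is what a 512x512 image gives with the SD VAE's 8x downsampling.

```python
import math

import numpy as np


def lerp_latents(a, b, t=0.5):
    """Linear interpolation between two latents; t=0.5 is the plain average."""
    return a * (1.0 - t) + b * t


def slerp_latents(a, b, t=0.5, eps=1e-7):
    """Spherical interpolation, treating each latent as one flattened vector."""
    dot = float((a * b).sum())
    norm_a = float((a * a).sum()) ** 0.5
    norm_b = float((b * b).sum()) ** 0.5
    # Clamp to [-1, 1] so acos never sees rounding noise outside its domain
    cos_omega = min(1.0, max(-1.0, dot / (norm_a * norm_b + eps)))
    omega = math.acos(cos_omega)
    if abs(math.sin(omega)) < eps:
        # Nearly parallel vectors: slerp degenerates, fall back to lerp
        return lerp_latents(a, b, t)
    w_a = math.sin((1.0 - t) * omega) / math.sin(omega)
    w_b = math.sin(t * omega) / math.sin(omega)
    return a * w_a + b * w_b


# Hypothetical stand-ins for two latents of a 512x512 image
rng = np.random.default_rng(0)
latents_a = rng.standard_normal((1, 4, 64, 64)).astype(np.float32)
latents_b = rng.standard_normal((1, 4, 64, 64)).astype(np.float32)

average = lerp_latents(latents_a, latents_b)   # simple mean of the two
blended = slerp_latents(latents_a, latents_b)  # keeps the norm closer to the inputs'
```

With real latents on the same device and dtype, `lerp_latents(latents_1, latents_2)` is the simple average described above, and varying `t` gives a sweep between the two images.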

However, I am having real difficulties with this. I have tried many approaches (short of switching to the Keras framework) and none of them seem to work. I understand there is an AutoPipelineForText2Image.encode_image() method that I could use for the generated images, but I can't get it to work; I just get obscure internal errors, such as incompatible dtypes. I also have trouble finding documentation on encoding/decoding, so my coding so far has been very much fumbling in the dark.
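On the encoding side, my understanding is that the Stable Diffusion img2img pipeline does not use `encode_image()` for initial images. As far as I can tell, that method belongs to the IP-Adapter image-encoder path. Instead, the pipeline encodes with the VAE and multiplies the result by `vae.config.scaling_factor`.

Below is a minimal sketch of that step. The function name and the `sample` flag are hypothetical. `vae` is assumed to be a diffusers `AutoencoderKL` such as `pipe.vae`.

```python
def encode_to_latents(vae, pixels, sample=False):
    """Encode images into the scaled latent space the UNet works in.

    `vae` is assumed to be a diffusers AutoencoderKL such as `pipe.vae`, and
    `pixels` a float tensor of shape (N, 3, H, W) with values in [-1, 1].
    """
    # Put the input on the VAE's device and in its dtype; mismatches here are a
    # common source of dtype/device errors.
    pixels = pixels.to(device=vae.device, dtype=vae.dtype)
    dist = vae.encode(pixels).latent_dist
    # sample=True draws from the latent distribution (as img2img does);
    # the mean gives a deterministic encoding.
    latents = dist.sample() if sample else dist.mean
    # Scale into the range the UNet expects (0.18215 for SD 1.x/2.x VAEs)
    return latents * vae.config.scaling_factor
```

Hypothetical usage, inside `torch.no_grad()`: `pixels = pipe.image_processor.preprocess(init_image)` should turn a PIL image into a tensor in [-1, 1]. `encode_to_latents(pipe.vae, pixels)` should then give latents shaped `(1, 4, H/8, W/8)`.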

Could someone show me how to access the latent representation of SD images, or point me to some docs that could help?

Sorry, my previous post was slightly inaccurate: accessing the latents of a txt2img pipeline seems pretty straightforward nowadays, but I am trying to access the latents of an img2img pipeline. Ultimately, my goal is to blend SD latents derived from initial images.

So far I have come up with this:

import numpy as np
import torch
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler
from PIL import Image

# Init pipeline, use cuda
model_name = "stabilityai/stable-diffusion-2-base"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_name, safety_checker=None).to("cuda")

# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

init_image = Image.open("sample_image.jpg").convert("RGB")  # ensure 3-channel RGB input

# The prompt
prompt = "japanese wood painting"

# Disable gradient computation for inference
with torch.no_grad():
    # Generate latent representations
    latent_result = pipe(prompt=prompt,
                         image=init_image,
                         strength=0.05,
                         guidance_scale=10,
                         output_type="latent")

    # Extract the latent representation
    latents = latent_result.images

    # pipe.decode_latents is deprecated, so decode with the VAE directly
    latent_image_tensor = pipe.vae.decode(latents).sample  # Get the tensor from DecoderOutput

    # Rescale the tensor values from [-1, 1] to [0, 1]
    latent_image_tensor = (latent_image_tensor / 2 + 0.5).clamp(0, 1)

    # Move the tensor to the CPU and change the dimensions to (height, width, channels)
    latent_image_numpy = latent_image_tensor.cpu().permute(0, 2, 3, 1).numpy()

    # Convert to PIL image
    latent_image = Image.fromarray((latent_image_numpy[0] * 255).astype(np.uint8))

    # Generate final image from latent representation
    final_image = pipe(prompt=prompt,
                       image=latent_image,
                       strength=0.3,
                       guidance_scale=10,
                       output_type="pil").images[0]
    
width, height = init_image.size
combined_image = Image.new('RGB', (width * 3, height))

combined_image.paste(init_image, (0, 0))
combined_image.paste(latent_image, (width, 0))
combined_image.paste(final_image, (width * 2, 0))

combined_image.show()

However, the colours are way off:

[Attached image: huggingiiin.png]

I must be doing something wrong with the tensor conversions, but what? Could someone help me?
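For reference, here is a NumPy sketch of the same post-processing chain the script applies after decoding: [-1, 1] to [0, 1], then NCHW to NHWC, then scaling to 0-255. The function name is made up. The one deliberate difference is rounding instead of truncating.

As far as I can tell, the conversion in the script is already equivalent to this, differing by at most one level per channel. So it is unlikely to explain strongly shifted colours, and the values coming out of the decoder are the more likely suspect.

```python
import numpy as np


def to_uint8_images(decoded):
    """Convert decoder output (N, C, H, W, roughly in [-1, 1]) to uint8 (N, H, W, C)."""
    images = np.clip(decoded / 2 + 0.5, 0.0, 1.0)   # [-1, 1] -> [0, 1], clamp overshoot
    images = images.transpose(0, 2, 3, 1)            # NCHW -> NHWC, as PIL expects
    return (images * 255).round().astype(np.uint8)   # round instead of truncating
```

With a PyTorch tensor, something like `Image.fromarray(to_uint8_images(decoded.float().cpu().numpy())[0])` would then give a PIL image. The `.float()` call avoids handing float16 data to NumPy.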

I have continued to try and fix this, but to no avail...

However, I strongly suspect that the problem lies in the output of the VAE decoder, i.e. this line:

latent_image_tensor = pipe.vae.decode(latents).sample

However, tweaking the colours, contrast or anything similar doesn't seem to help. Am I using the VAE decoder correctly? Are there alternatives or parameters that I could use? I haven't found comprehensive documentation for the VAE decoder and its parameters, but I am notoriously bad at finding docs.
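One thing worth checking, offered as an assumption rather than a confirmed diagnosis, concerns how the Stable Diffusion pipelines handle scaling. They keep their latents multiplied by `vae.config.scaling_factor`, and with `output_type="latent"` they return them in that scaled form. When they decode the latents themselves, they divide by the factor first, and the deprecated `decode_latents` did the same. The line quoted above passes `latents` to `pipe.vae.decode` without that division.

Here is a minimal sketch of a decode helper that undoes the scaling. The function name is hypothetical, and `vae` is assumed to be `pipe.vae`.

```python
def decode_from_latents(vae, latents):
    """Decode latents returned with output_type="latent" into images in [-1, 1].

    Assumes `latents` are still multiplied by vae.config.scaling_factor, which is
    how the Stable Diffusion pipelines return them for output_type="latent".
    """
    latents = latents.to(device=vae.device, dtype=vae.dtype)
    # Undo the scaling first; otherwise the decoder sees inputs about
    # 1 / 0.18215 (roughly 5.5) times too large for SD 1.x/2.x VAEs.
    return vae.decode(latents / vae.config.scaling_factor).sample
```

In the script above, `latent_image_tensor = decode_from_latents(pipe.vae, latents)` would replace the `pipe.vae.decode(latents).sample` line, and the rest of the conversion can stay as is. Alternatively, `pipe.image_processor.postprocess(latent_image_tensor, output_type="pil")[0]` should do the [-1, 1] to PIL conversion in one call.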

I found this though:
https://github.com/huggingface/diffusers/issues/2871

It seems to describe a problem with output_type="latent", and I guess the issue should already be fixed. However, I am not good enough with Python to draw conclusions from the source code; could someone with more experience check this?
