# Explaining the SDXL latent space

TL;DR

or check out the interactive demonstration

# Table of Contents

A short background story

The 4 channels of the SDXL latents

The 8-bit pixel space has 3 channels

The SDXL latent representation of an image has 4 channels

Direct conversion of SDXL latents to RGB with a linear approximation

A probable reason why the SDXL color range is biased towards yellow

What needs correcting?

Let's take an example output from SDXL

A complete demonstration

Increasing color range / removing color bias

Long prompts at high guidance scales becoming possible

## A short background story

*Special thanks to: Ollin Boer Bohan, Haoming, Cristina Segalin and Birchlabs for helping with information, discussion and knowledge!*

I was creating correction filters for the SDXL inference process, for a UI I'm building for diffusion models.

With many years of experience in image correction, I wanted the fundamental capability to improve the actual output from SDXL. There were many techniques I wanted available in the UX, so I set out to implement them myself. I noticed that SDXL output is almost always either noisy in regular patterns or overly smooth. The colors always needed white balancing, with a biased and restricted color range, simply because of how SD models work.

Making corrections in post-processing, after the image has been generated and converted to 8-bit RGB, made very little sense if it was possible to improve the information and color range before the actual output.

The most important thing when creating filters and correction tools is to understand the data you are working with.

This led me to an experimental exploration of the SDXL latents with the intention of understanding them.
The tensor, which the diffusion models based on the SDXL architecture work with, looks like this:

```
[batch_size, 4 channels, height (y), width (x)]
```

My first question was simply "**What exactly are these 4 channels?**".
Most of the answers I received were along the lines of "It's not something that a human can understand."

But it is most definitely understandable. It's even very easy to understand and useful to know.

## The 4 channels of the SDXL latents

For a 1024×1024px image generated by SDXL, the latents tensor is 128×128px, where every pixel in the latent space represents 64 (8×8) pixels in the pixel space. If we generate and decode the latents into a standard 8-bit jpg image, then...

### The 8-bit pixel space has 3 channels

Red (R), Green (G) and Blue (B), each with 256 possible values ranging between 0-255. So, to store the full information of 64 pixels, we need to be able to store 64×256 = 16,384 values, per channel, in every latent pixel.
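The numbers in this section can be reproduced with a few lines of arithmetic. This is only a sketch: the helper name `latent_shape` and its parameters are hypothetical, while the downscale factor of 8 and the 4 latent channels come from the description above.

```python
def latent_shape(batch_size, height, width, downscale=8, channels=4):
    """Return the latent tensor shape [batch, channels, height, width]."""
    if height % downscale or width % downscale:
        raise ValueError("image size should be divisible by the downscale factor")
    return (batch_size, channels, height // downscale, width // downscale)


shape = latent_shape(1, 1024, 1024)
pixels_per_latent = 8 * 8                      # each latent pixel covers an 8x8 block
values_per_channel = pixels_per_latent * 256   # 64 pixels, 256 levels each

print(shape)               # (1, 4, 128, 128)
print(pixels_per_latent)   # 64
print(values_per_channel)  # 16384
```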

### The SDXL latent representation of an image has 4 channels

*Click the heading for an interactive demo!*

**0:** Luminance

**1:** Cyan/Red => equivalent to rgb(0, 255, 255)/rgb(255, 0, 0)

**2:** Lime/Medium Purple => equivalent to rgb(127, 255, 0)/rgb(127, 0, 255)

**3:** Pattern/structure.

If each value can range between -4 and 4 at the point of decoding, then in a 16-bit (half-precision) floating-point format, each latent pixel can contain 16,384 distinct values for each of the 4 channels.

### Direct conversion of SDXL latents to RGB with a linear approximation

With this understanding, we can create an approximation function which directly converts the latents to RGB:

```
import torch
from PIL import Image

def latents_to_rgb(latents):
    # Rows are R, G, B; columns are the 4 latent channels
    weights = (
        (60, -60, 25, -70),
        (60,  -5, 15, -50),
        (60,  10, -5, -35)
    )
    # Transpose to [latent_channel, rgb_channel] for the einsum below
    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions
    return Image.fromarray(image_array)
```
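To see what the linear approximation does to a single latent pixel, here is a small pure-Python sketch using the same weights and biases as `latents_to_rgb`. The helper name `latent_pixel_to_rgb` is hypothetical and only meant as an illustration, not as a replacement for the tensor version.

```python
# Same weights and biases as latents_to_rgb above, applied to one latent pixel.
WEIGHTS = (
    (60, -60, 25, -70),  # red
    (60, -5, 15, -50),   # green
    (60, 10, -5, -35),   # blue
)
BIASES = (150, 140, 130)


def latent_pixel_to_rgb(latent):
    """Map one 4-channel latent value to a clamped 8-bit RGB triple."""
    rgb = []
    for row, bias in zip(WEIGHTS, BIASES):
        value = sum(w * x for w, x in zip(row, latent)) + bias
        rgb.append(int(min(max(value, 0), 255)))
    return tuple(rgb)


print(latent_pixel_to_rgb((0, 0, 0, 0)))   # (150, 140, 130): the bias alone
print(latent_pixel_to_rgb((1, 0, 0, 0)))   # (210, 200, 190): brighter
print(latent_pixel_to_rgb((0, 1, 0, 0)))   # (90, 135, 140): less red, towards cyan
print(latent_pixel_to_rgb((0, -1, 0, 0)))  # (210, 145, 120): more red
```

An all-zero latent already maps to a warm gray, channel 0 moves all three RGB values together, and channel 1 mostly trades red against the rest.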

Here we have the latents_to_rgb result and a regular decoded output, resized for comparison:

### A probable reason why the SDXL color range is biased towards yellow

Relatively few things in nature are blue or white. These colors are most prominent in the sky, during enjoyable conditions. So the model, knowing reality through images, thinks in luminance (channel 0), cyan/red (channel 1) and lime/medium purple (channel 2), where red and green are primary and blue is secondary. This is why SDXL generations are very often biased towards yellow (red + green).

During inference, the values in the tensor will begin at `min < -30` and `max > 30`, and the min/max boundary at time of decoding is around `-4` to `4`. At higher `guidance_scale` the values will have a higher difference between `min` and `max`.
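One way to observe this range yourself is to record the min/max of the latents at every step, using the same `callback_on_step_end` hook as the callback example further down. This is a minimal sketch: the names `latent_ranges` and `log_latent_range` are hypothetical, and it only assumes that the latents object exposes `.min()` and `.max()`, as a PyTorch tensor does.

```python
latent_ranges = []  # (step_index, timestep, min, max), filled during inference


def log_latent_range(pipe, step_index, timestep, callback_kwargs):
    """Record the min/max of the latents after each denoising step."""
    latents = callback_kwargs["latents"]
    latent_ranges.append(
        (step_index, int(timestep), float(latents.min()), float(latents.max()))
    )
    return callback_kwargs  # the pipeline expects the kwargs dict back
```

Passing `callback_on_step_end=log_latent_range` and `callback_on_step_end_inputs=["latents"]` to the pipeline call should fill `latent_ranges` with one entry per step, which can then be printed or plotted.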

One key in understanding the boundary is to look at what happens in the decoding process:

```
decoded = vae.decode(latents / vae.config.scaling_factor).sample  # (SDXL scaling_factor = 0.13025)
decoded = decoded.div(2).add(0.5).clamp(0, 1)  # The dynamics outside of 0 to 1 at this point will be lost
```

If the values at this point are outside of the range 0 to 1, some information will be lost in the clamp. So if we can make corrections during denoising to serve the VAE what it expects, we may get better results.
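The loss in the clamp can be illustrated on single values. The sketch below mirrors `decoded.div(2).add(0.5).clamp(0, 1)` in plain Python; the helper name `to_unit_range` is hypothetical.

```python
def to_unit_range(decoded_value):
    """Mirror decoded.div(2).add(0.5).clamp(0, 1) for a single value."""
    return min(max(decoded_value / 2 + 0.5, 0.0), 1.0)


# Three different decoder outputs at or above 1.0 collapse to the same pixel value.
for value in (1.0, 1.4, 2.0):
    print(value, "->", to_unit_range(value))  # each prints 1.0

# Values inside [-1, 1] keep their differences.
print(to_unit_range(-0.5), to_unit_range(0.5))  # 0.25 0.75
```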

## What needs correcting?

How do you sharpen a blurry image, white balance, improve detail, increase contrast or increase the color range? The best way is to begin with a sharp image, which is correctly white balanced with great contrast, crisp details and a high range.

It's far easier to blur a sharp image, shift the color balance, reduce contrast, introduce nonsensical details and limit the color range than to do the reverse.

SDXL has a very prominent tendency towards color bias and towards putting values outside of the actual boundaries (left image). This is easily solved by centering the values and getting them within the boundaries (right image):

```
def center_tensor(input_tensor, per_channel_shift=1, full_tensor_shift=1, channels=[0, 1, 2, 3]):
    for channel in channels:
        input_tensor[0, channel] -= input_tensor[0, channel].mean() * per_channel_shift
    return input_tensor - input_tensor.mean() * full_tensor_shift
```

## Let's take an example output from SDXL

```
seed: 77777777
guidance_scale: 20 # A high guidance scale can be fixed too
steps with base: 23
steps with refiner: 10
prompt: Cinematic.Beautiful smile action woman in detailed white mecha gundam armor with red details,green details,blue details,colorful,star wars universe,lush garden,flowers,volumetric lighting,perfect eyes,perfect teeth,blue sky,bright,intricate details,extreme detail of environment,infinite focus,well lit,interesting clothes,radial gradient fade,directional particle lighting,wow
negative_prompt: helmet, bokeh, painting, artwork, blocky, blur, ugly, old, boring, photoshopped, tired, wrinkles, scar, gray hair, big forehead, crosseyed, dumb, stupid, cockeyed, disfigured, crooked, blurry, unrealistic, grayscale, bad anatomy, unnatural irises, no pupils, blurry eyes, dark eyes, extra limbs, deformed, disfigured eyes, out of frame, no irises, assymetrical face, broken fingers, extra fingers, disfigured hands
```

**Notice** that I've purposely chosen a high guidance scale.

How can we fix this image? It's half painting, half photograph. The color range is biased towards yellow. To the right is a fixed generation with the exact same settings.

Even with a sensible `guidance_scale` of 7.5, we can still conclude that the fixed output is better, without nonsensical details and with correct white balance.

There are many things we can do in the latent space to generally improve a generation, and some very simple things we can do to target specific errors:

### Outlier removal

This will control the amount of nonsensical details, by pruning values that are the farthest from the mean of the distribution. It also helps in generating at higher guidance_scale.

```
import torch

# Shrinking towards the mean (will also remove outliers)
def soft_clamp_tensor(input_tensor, threshold=3.5, boundary=4):
    # Nothing to do if all values are already within -4 to 4
    if max(abs(input_tensor.max()), abs(input_tensor.min())) < 4:
        return input_tensor
    channel_dim = 1

    # Remap values above the threshold into the range threshold..boundary
    max_vals = input_tensor.max(channel_dim, keepdim=True)[0]
    max_replace = ((input_tensor - threshold) / (max_vals - threshold)) * (boundary - threshold) + threshold
    over_mask = (input_tensor > threshold)

    # Remap values below -threshold into the range -boundary..-threshold
    min_vals = input_tensor.min(channel_dim, keepdim=True)[0]
    min_replace = ((input_tensor + threshold) / (min_vals + threshold)) * (-boundary + threshold) - threshold
    under_mask = (input_tensor < -threshold)

    return torch.where(over_mask, max_replace, torch.where(under_mask, min_replace, input_tensor))
```
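The remapping in `soft_clamp_tensor` is easier to follow on single numbers. The sketch below is a scalar version of the same two formulas; the helper name `soft_clamp_value` is hypothetical, and `max_val` / `min_val` stand in for the per-position maxima and minima that the tensor version computes.

```python
def soft_clamp_value(x, max_val, min_val, threshold=3.5, boundary=4.0):
    """Scalar version of the remapping used in soft_clamp_tensor."""
    if x > threshold:
        # [threshold, max_val] is squeezed linearly into [threshold, boundary]
        return (x - threshold) / (max_val - threshold) * (boundary - threshold) + threshold
    if x < -threshold:
        # [min_val, -threshold] is squeezed linearly into [-boundary, -threshold]
        return (x + threshold) / (min_val + threshold) * (-boundary + threshold) - threshold
    return x  # values inside the threshold are left untouched


for x in (6.0, 4.75, 2.0, -6.0):
    print(x, "->", soft_clamp_value(x, max_val=6.0, min_val=-6.0))
# 6.0 -> 4.0, 4.75 -> 3.75, 2.0 -> 2.0, -6.0 -> -4.0
```

The largest value lands exactly on the boundary, values between the threshold and the maximum keep their relative order, and everything inside the threshold is unchanged.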

### Color balancing and increased range

I have two main methods of achieving this. The first is to shrink towards the mean while normalizing the values (which will also remove outliers), and the second is to correct the values when they get biased towards some color. This also helps in generating at higher `guidance_scale`.

```
# Center tensor (balance colors)
def center_tensor(input_tensor, channel_shift=1, full_shift=1, channels=[0, 1, 2, 3]):
    for channel in channels:
        input_tensor[0, channel] -= input_tensor[0, channel].mean() * channel_shift
    return input_tensor - input_tensor.mean() * full_shift
```
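As a toy illustration of the per-channel step in `center_tensor`, the sketch below centers a plain list of values standing in for one latent channel. The helper name `center_channel` is hypothetical, and the full-tensor shift is left out to keep the example short.

```python
from statistics import mean


def center_channel(values, shift=1.0):
    """Subtract a fraction of the channel mean from every value."""
    offset = mean(values) * shift
    return [v - offset for v in values]


biased = [0.5, 1.5, 2.5]                        # a channel whose mean has drifted to 1.5
print(center_channel(biased))                   # [-1.0, 0.0, 1.0]: fully centered
print(mean(center_channel(biased, shift=0.5)))  # 0.75: half of the bias remains
```

A shift below 1 only removes part of the bias, which is presumably why the callback example uses values like 0.8 and 0.6 rather than centering completely at every step.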

### Tensor maximizing

This is basically done by multiplying the tensors by a very small amount like `1e-5` for a few steps and by making sure that the final tensor uses the full possible range (closer to -4/4) before converting to RGB. Remember, in the pixel space, it's easier to reduce contrast, saturation and sharpness with intact dynamics than to increase them.

```
# Maximize/normalize tensor
def maximize_tensor(input_tensor, boundary=4, channels=[0, 1, 2]):
    min_val = input_tensor.min()
    max_val = input_tensor.max()
    normalization_factor = boundary / max(abs(min_val), abs(max_val))
    input_tensor[0, channels] *= normalization_factor
    return input_tensor
```
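The normalization factor in `maximize_tensor` can be checked by hand on a few numbers. This is a pure-Python sketch with a hypothetical helper name, `normalization_factor`, applied to a short list instead of a tensor.

```python
def normalization_factor(values, boundary=4.0):
    """Factor that scales the largest absolute value exactly to the boundary."""
    return boundary / max(abs(min(values)), abs(max(values)))


latent_values = [-1.5, 0.25, 2.0]
factor = normalization_factor(latent_values)

print(factor)                               # 2.0
print([v * factor for v in latent_values])  # [-3.0, 0.5, 4.0]
```

Note that `maximize_tensor` computes the factor from the whole tensor but, by default, only applies it to channels 0 to 2, leaving the pattern/structure channel as it is.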

### Callback implementation example

```
def callback(pipe, step_index, timestep, cbk):
    # Early steps: softly clamp outliers
    if timestep > 950:
        threshold = max(cbk["latents"].max(), abs(cbk["latents"].min())) * 0.998
        cbk["latents"] = soft_clamp_tensor(cbk["latents"], threshold*0.998, threshold)
    # Early to middle steps: balance colors
    if timestep > 700:
        cbk["latents"] = center_tensor(cbk["latents"], 0.8, 0.8)
    # Final steps: balance colors again and use the full range
    if timestep > 1 and timestep < 100:
        cbk["latents"] = center_tensor(cbk["latents"], 0.6, 1.0)
        cbk["latents"] = maximize_tensor(cbk["latents"])
    return cbk

# `base` is the SDXL base pipeline; `prompt` and `guidance_scale` are set as in the example above
image = base(
    prompt,
    guidance_scale=guidance_scale,
    callback_on_step_end=callback,
    callback_on_step_end_inputs=["latents"]
).images[0]
```

This simple implementation of the three methods is used in the last set of images, with the women in the garden.

## A complete demonstration

*Click the heading or this link for an interactive demo!*

This demonstration uses a more advanced implementation of the techniques: it detects outliers using Z-score, shifts towards the mean dynamically and applies a strength to each technique.
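For readers unfamiliar with Z-score outlier detection, here is a generic sketch of the idea on a plain list of values. It is not the implementation used in the demo: the helper name `zscore_outliers` and the threshold values are assumptions chosen for illustration.

```python
from statistics import mean, pstdev


def zscore_outliers(values, z_threshold=3.0):
    """Return the indices of values more than z_threshold std devs from the mean."""
    mu = mean(values)
    sigma = pstdev(values)
    if sigma == 0:
        return []  # all values identical, nothing stands out
    return [i for i, v in enumerate(values) if abs(v - mu) / sigma > z_threshold]


# Small sample, so a slightly lower threshold than the usual 3.0 is used here.
channel = [0.1, -0.2, 0.3, 0.0, -0.1, 0.2, -0.3, 0.1, 0.0, 12.0]
print(zscore_outliers(channel, z_threshold=2.5))  # [9]
```

Compared with a fixed boundary such as -4/4, a Z-score adapts to the spread of the values at the current step, which matters because the range of the latents shrinks a lot during denoising.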

#### Original SDXL (too yellow) and slight modification (white balanced)

#### Medium modification and hard modification (both with all 3 techniques applied)

## Increasing color range / removing color bias

For the below, SDXL has limited the color range to red and green in the regular output, because there is nothing in the prompt suggesting that there is such a thing as blue. This is a rather good generation, but the color range has become restricted.

If you give someone a palette of black, red, green and yellow and then tell them to paint a clear blue sky, the natural response is to ask you to supply blue and white.

To include blue in the generation, we can simply realign the color space when it gets restricted and SDXL will appropriately include the full color spectrum in the generation.

## Long prompts at high guidance scales becoming possible

Here is a typical scenario, where the increased color range makes the whole prompt possible.

This example applies the simple, hard modification shown earlier, to illustrate the difference more clearly.

**prompt:** Photograph of woman in red dress in a luxury garden surrounded with **blue**, yellow, purple and flowers in **many colors**, high class, award-winning photography, Portra 400, full format. **blue sky**, intricate details even to the smallest particle, extreme detail of the environment, sharp portrait, **well lit**, **interesting** outfit, beautiful shadows, **bright**, photoquality, ultra realistic, masterpiece

#### Here are some more comparisons on the same concept

*Keep in mind that these all just use the same static modifications.*