🚀 Stable Diffusion with Transformers (Advanced Training)

This project demonstrates how to train a Stable Diffusion-like model using an image dataset with advanced Transformer-based denoising.
The implementation leverages PyTorch + Hugging Face Diffusers + Transformers.


📌 Overview

Stable Diffusion is a Latent Diffusion Model (LDM). Training and image generation rest on the following steps:

  1. Encoding images into a latent space using a VAE (Variational Autoencoder).
  2. Adding Gaussian noise to the latents across multiple time steps.
  3. Training a denoising Transformer/UNet to remove noise step by step.
  4. Using a text encoder (CLIP) for prompt conditioning.
  5. Decoding the cleaned latents back to an image.
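
To make step 1 concrete, here is a minimal sketch of how much the VAE shrinks an image. It assumes the Stable Diffusion v1 VAE layout (8x spatial downsampling, 4 latent channels); other checkpoints may differ, and `latent_shape` and `compression_ratio` are hypothetical helper names used only for illustration.

```python
def latent_shape(height, width, latent_channels=4, downsample_factor=8):
    """Shape (C, H, W) of the latent for one image.

    Assumes an SD v1 style VAE: 8x spatial downsampling, 4 latent channels.
    """
    if height % downsample_factor or width % downsample_factor:
        raise ValueError("height and width must be multiples of the downsample factor")
    return (latent_channels, height // downsample_factor, width // downsample_factor)


def compression_ratio(height, width, image_channels=3, **kwargs):
    """How many image values map onto one latent value."""
    c, h, w = latent_shape(height, width, **kwargs)
    return (image_channels * height * width) / (c * h * w)


print(latent_shape(256, 256))       # (4, 32, 32)
print(compression_ratio(256, 256))  # 48.0
```

Under these assumptions the denoiser works on 48 times fewer values than it would in pixel space, which is the main reason diffusion is run on latents.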

🔬 Architecture

graph TD;
    A[Input Image] -->|VAE Encoder| B[Latent Space];
    B -->|Add Noise| C[Noisy Latents];
    C -->|Transformer / UNet Denoiser| D[Clean Latents];
    D -->|VAE Decoder| E[Output Image];
    F[Text Prompt] -->|CLIP Encoder| C;
  • VAE → Compresses images into a latent space and decodes latents back to images

  • Transformer/UNet → Learns to denoise the latents

  • Text Encoder → Embeds the prompt so the denoiser can be conditioned on it

  • Noise Scheduler → Controls forward & reverse diffusion
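
The noise scheduler's forward pass can be sketched in a few lines of plain Python. This is an illustrative sketch, not the Diffusers implementation: it assumes the linear beta schedule from the original DDPM setup (1e-4 to 0.02 over 1000 steps), and the function names are hypothetical.

```python
import math
import random


def linear_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Per-step noise variances on a linear schedule (assumed DDPM defaults)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product of (1 - beta_s) for s <= t."""
    alpha_bars, product = [], 1.0
    for beta in betas:
        product *= 1.0 - beta
        alpha_bars.append(product)
    return alpha_bars


def add_noise(latent, noise, t, alpha_bars):
    """Forward diffusion: sqrt(alpha_bar_t) * latent + sqrt(1 - alpha_bar_t) * noise."""
    signal_scale = math.sqrt(alpha_bars[t])
    noise_scale = math.sqrt(1.0 - alpha_bars[t])
    return [signal_scale * z + noise_scale * e for z, e in zip(latent, noise)]


alpha_bars = cumulative_alphas(linear_betas())
latent = [0.5, -0.25, 1.0]
noise = [random.gauss(0.0, 1.0) for _ in latent]
slightly_noisy = add_noise(latent, noise, 0, alpha_bars)   # almost the original latent
mostly_noise = add_noise(latent, noise, 999, alpha_bars)   # almost pure noise
```

At `t = 0` nearly all of the signal survives, while at the last step the latent is almost entirely replaced by noise. The reverse process learned by the denoiser walks this schedule backwards.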

📂 Dataset

  • Images should be resized (256x256) and normalized to [-1, 1].

  • Optional: Provide text captions for conditioning.

  • Example:

data/
 ├── class1/
 │   ├── img1.png
 │   └── img2.jpg
 ├── class2/
 │   ├── img3.png
 │   └── img4.jpg
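
If no caption file is available, one simple option is to derive a caption from each class folder name. The sketch below assumes the class-per-folder layout shown above; `build_caption_index` and the caption template are hypothetical and not part of this project.

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def build_caption_index(root, template="a photo of {label}"):
    """Return sorted (image_path, caption) pairs for a class-per-folder layout."""
    pairs = []
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue
        # Folder name becomes the label, e.g. "golden_retriever" -> "golden retriever"
        caption = template.format(label=class_dir.name.replace("_", " "))
        for image_path in sorted(class_dir.iterdir()):
            if image_path.suffix.lower() in IMAGE_EXTENSIONS:
                pairs.append((image_path, caption))
    return pairs
```

Calling `build_caption_index("data")` on the layout above would pair `img1.png` and `img2.jpg` with "a photo of class1", and so on. Real captions usually condition the model far better than folder names, so treat this as a fallback.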

⚙️ Training Algorithm

The training process for Stable Diffusion with Transformers follows these steps:

  1. Encode Images → Pass input images through a VAE Encoder to obtain latent representations.
  2. Sample Noise & Timestep → Randomly sample Gaussian noise and a timestep t.
  3. Add Noise → Corrupt the latent vectors with noise according to timestep t.
  4. Text Conditioning → Encode text prompts using CLIP (or another Transformer text encoder).
  5. Noise Prediction → Feed the noisy latents + text embeddings into the Transformer/UNet to predict the added noise.
  6. Compute Loss → Calculate the Mean Squared Error (MSE) between predicted noise and true noise.
  7. Backpropagation → Update model weights using gradient descent.

flowchart TD
    A[Image] -->|VAE Encoder| B[Latent Space]
    B -->|Add Noise + t| C[Noisy Latents]
    D[Text Prompt] -->|CLIP Encoder| C
    C -->|Transformer / UNet| E[Predicted Noise]
    E -->|MSE Loss| F[Training Update]
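
Written as equations, steps 3 and 6 take the standard noise-prediction form used for latent diffusion. The symbols are notation introduced here, not taken from this project: $z_0$ is the scaled VAE latent, $c$ the text embedding, $\bar{\alpha}_t$ the cumulative product of $(1 - \beta_s)$ up to step $t$, and $\epsilon_\theta$ the Transformer/UNet denoiser.

```latex
z_t = \sqrt{\bar{\alpha}_t}\, z_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)

\mathcal{L} = \mathbb{E}_{z_0,\, c,\, \epsilon,\, t}
\left[ \left\lVert \epsilon - \epsilon_\theta(z_t, t, c) \right\rVert_2^2 \right]
```

The first line is the noising step applied to the latents, and the second is the MSE between true and predicted noise that the training loop minimizes.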

🧑‍💻 Example Training Code

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import CLIPTextModel, CLIPTokenizer

# Dataset: resize to 256x256 and scale pixels to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = datasets.ImageFolder("path_to_images", transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Components. The SD v1-4 UNet cross-attends to 768-dim text embeddings,
# so the text encoder must be CLIP ViT-L/14 (ViT-B/32 outputs 512 dims).
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
scheduler = DDPMScheduler(num_train_timesteps=1000)

device = "cuda" if torch.cuda.is_available() else "cpu"
vae, unet, text_encoder = vae.to(device), unet.to(device), text_encoder.to(device)

# Only the UNet is trained; the VAE and text encoder stay frozen
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.train()

optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# Training Loop
for epoch in range(10):
    for images, _ in dataloader:
        images = images.to(device)
        batch_size = images.shape[0]

        with torch.no_grad():
            # Encode to latents and apply the VAE scaling factor (0.18215)
            latents = vae.encode(images).latent_dist.sample() * vae.config.scaling_factor

            # Placeholder caption, repeated once per image in the batch
            text_inputs = tokenizer(
                ["a photo"] * batch_size,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            text_embeds = text_encoder(text_inputs.input_ids).last_hidden_state

        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (batch_size,), device=device
        )
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeds).sample
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the last batch in the epoch
    print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

💾 Saving & Inference

Save trained UNet

torch.save(unet.state_dict(), "unet_trained.pth")


# Inference pipeline
# 1. Sample random latent
# 2. Iteratively denoise with scheduler + trained UNet
# 3. Decode with VAE → image
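
The three inference steps could look roughly like the function below. It is a sketch under assumptions, not this project's pipeline: `sample_images` is a hypothetical name, the components are expected to follow the Diffusers interfaces used in the training code (`UNet2DConditionModel`, `AutoencoderKL`, `DDPMScheduler`), the 8x downsampling factor matches the SD v1 VAE, and classifier-free guidance is left out for brevity.

```python
import torch


@torch.no_grad()
def sample_images(unet, vae, scheduler, text_embeds,
                  height=256, width=256, num_inference_steps=50):
    """Generate one image per row of `text_embeds`, with values in [0, 1]."""
    device = text_embeds.device
    batch_size = text_embeds.shape[0]

    # 1. Sample random latents (assumes the VAE downsamples by 8x)
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        device=device,
    ) * scheduler.init_noise_sigma

    # 2. Iteratively denoise with the scheduler and the trained UNet
    scheduler.set_timesteps(num_inference_steps)
    for t in scheduler.timesteps:
        model_input = scheduler.scale_model_input(latents, t)
        noise_pred = unet(model_input, t, encoder_hidden_states=text_embeds).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # 3. Undo the latent scaling, decode with the VAE, map [-1, 1] to [0, 1]
    images = vae.decode(latents / vae.config.scaling_factor).sample
    return (images / 2 + 0.5).clamp(0, 1)
```

Before sampling, the saved weights would be restored with something like `unet.load_state_dict(torch.load("unet_trained.pth"))` followed by `unet.eval()`, and `text_embeds` would be produced by the same tokenizer and text encoder used during training.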

📖 References

  • Stable Diffusion Paper

  • Hugging Face Diffusers

  • Diffusion Transformer (DiT)

✅ Future Work

  • Replace UNet with a pure Transformer (DiT).

  • Use larger text encoders (T5/DeBERTa).

  • Train with custom captioned datasets.
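
As a pointer for the first item: a DiT-style denoiser does not convolve over the latent grid but splits it into patches and treats each patch as a token. The sketch below shows only that reshaping step and its inverse; `patchify` and `unpatchify` are hypothetical helpers, and a real DiT also needs a linear embedding, positional encodings and timestep conditioning.

```python
import torch


def patchify(latents, patch_size):
    """(B, C, H, W) latents -> (B, num_patches, patch_size * patch_size * C) tokens."""
    b, c, h, w = latents.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("latent height and width must be divisible by patch_size")
    x = latents.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, H/p, W/p, p, p, C)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def unpatchify(tokens, patch_size, channels, height, width):
    """Inverse of patchify: tokens -> (B, C, H, W) latents."""
    b, p = tokens.shape[0], patch_size
    x = tokens.reshape(b, height // p, width // p, p, p, channels)
    x = x.permute(0, 5, 1, 3, 2, 4)  # (B, C, H/p, p, W/p, p)
    return x.reshape(b, channels, height, width)


latents = torch.randn(2, 4, 32, 32)       # latents of two 256x256 images
tokens = patchify(latents, patch_size=2)  # shape (2, 256, 16)
restored = unpatchify(tokens, 2, 4, 32, 32)
```

With a patch size of 2, a 32x32 latent becomes a sequence of 256 tokens, which a standard Transformer encoder can process before the output tokens are folded back into a noise prediction.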

