"""
Adapter model that transforms LLM hidden states into SDXL UNET conditions.

This module defines a neural network that converts hidden states from an LLM (or other
text encoder) into the conditioning tensors expected by SDXL's UNET, standing in for the
CLIP text encoders. The adapter uses a multi-stage transformer approach to process and
compress the input sequence while preserving semantic information.
"""
import torch
import torch.nn as nn
from typing import Dict, Optional
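
# Quick orientation (a minimal sketch under the default configuration defined below;
# `adapter`, `llm_hidden_states` and `attention_mask` are illustrative names, and the
# shapes hold only for the defaults):
#   adapter = LLMoSDXLAdapter()
#   out = adapter(llm_hidden_states, attention_mask)  # llm_hidden_states: (B, seq_len, 2304)
#   out["text_embeds"]         # (B, 308, 2048)  -- sequence conditioning for the SDXL UNET
#   out["pooled_text_embeds"]  # (B, 1280)       -- pooled conditioning for the SDXL UNET
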
class TransformerBlock(nn.Module):
"""A standard transformer block with layer normalization, attention, and MLP layers."""
def __init__(self, dim: int, num_heads: int = 16, mlp_ratio: float = 4.0, dropout: float = 0.0):
"""
Initialize the transformer block.
Args:
dim: Dimension of the input features
num_heads: Number of attention heads
mlp_ratio: Ratio to determine the hidden dimension of the MLP
dropout: Dropout probability for regularization
"""
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(
embed_dim=dim,
num_heads=num_heads,
batch_first=True,
dropout=dropout
)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(dim * mlp_ratio), dim)
)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Forward pass through the transformer block.
Args:
x: Input tensor of shape (batch_size, sequence_length, dim)
mask: Attention mask where 1 indicates valid tokens and 0 indicates padding
Returns:
Output tensor of same shape as input
"""
# Self-attention with layer normalization
normed_x = self.norm1(x)
if mask is not None:
# Convert attention mask to key_padding_mask format
# True means ignore this position, so we invert our mask
key_padding_mask = ~mask.bool()
else:
key_padding_mask = None
attn_output, _ = self.attn(
query=normed_x,
key=normed_x,
value=normed_x,
key_padding_mask=key_padding_mask,
need_weights=False
)
x = x + attn_output
# MLP with residual connection
x = x + self.mlp(self.norm2(x))
return x
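

# Shape sanity check for TransformerBlock (a minimal sketch, not part of the original
# file; `block`, `x` and `mask` are illustrative names):
#   block = TransformerBlock(dim=2048, num_heads=16)
#   x = torch.randn(2, 16, 2048)
#   mask = torch.ones(2, 16)      # 1 = attend to this token, 0 = padding
#   block(x, mask).shape          # -> torch.Size([2, 16, 2048]), same shape as the input

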
class LLMoSDXLAdapter(nn.Module):
"""An adapter model that transforms LLM hidden states into SDXL UNET conditions.
This class implements a multi-stage transformer architecture that efficiently converts
hidden states from language models (like T5-Gemma) into conditioning vectors suitable
for Stable Diffusion XL's UNET. The adapter processes the full sequence through wide
transformer blocks, compresses it via cross-attention, and then refines the compressed
representation through narrow transformer blocks.
"""
def __init__(
self,
llm_dim: int = 2304,
sdxl_seq_dim: int = 2048,
sdxl_pooled_dim: int = 1280,
max_input_len: int = 512,
target_seq_len: int = 308,
n_wide_blocks: int = 3,
n_narrow_blocks: int = 3,
num_heads: int = 16,
dropout: float = 0.05
):
"""
Initialize the adapter with specified dimensions and hyperparameters.
Args:
llm_dim: Dimension of input LLM hidden states
sdxl_seq_dim: Target dimension for SDXL sequence embeddings
sdxl_pooled_dim: Target dimension for SDXL pooled embeddings
max_input_len: Maximum length of input sequences
target_seq_len: Desired length after compression (e.g., 77*4=308)
n_wide_blocks: Number of transformer blocks for full-sequence processing
n_narrow_blocks: Number of transformer blocks for compressed-sequence processing
num_heads: Number of attention heads in all multi-head attention layers
dropout: Dropout probability for regularization throughout the network
"""
super().__init__()
self.max_input_len = max_input_len
self.target_seq_len = target_seq_len
self.num_heads = num_heads
# Projection from LLM dimension to SDXL sequence dimension
self.seq_projection = nn.Linear(llm_dim, sdxl_seq_dim)
# Learnable positional embeddings for input and output sequences
self.input_position_embeddings = nn.Parameter(
torch.randn(1, max_input_len, sdxl_seq_dim)
)
self.output_position_embeddings = nn.Parameter(
torch.randn(1, target_seq_len, sdxl_seq_dim)
)
# Wide transformer blocks for processing the full-length sequence (512 tokens)
self.wide_attention_blocks = nn.ModuleList([
TransformerBlock(sdxl_seq_dim, num_heads=num_heads, dropout=dropout)
for _ in range(n_wide_blocks)
])
        # Compression mechanism using cross-attention with learnable queries.
        # This reduces the sequence length from max_input_len to target_seq_len
        # (512 -> 308 with the defaults) while preserving information.
self.compression_queries = nn.Parameter(
torch.randn(1, target_seq_len, sdxl_seq_dim)
)
self.compression_attention = nn.MultiheadAttention(
embed_dim=sdxl_seq_dim,
num_heads=num_heads,
batch_first=True,
dropout=dropout
)
# Normalization layer after compression for stability
self.compression_norm = nn.LayerNorm(sdxl_seq_dim)
# Gating mechanism to control information flow during compression
self.compression_gate = nn.Sequential(
nn.Linear(sdxl_seq_dim * 2, sdxl_seq_dim),
nn.Sigmoid()
)
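        # The gate g = sigmoid(Linear([queries; attended])) lies in (0, 1); forward()
        # blends the two streams as: compressed = g * attended + (1 - g) * queries.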
# Narrow transformer blocks for processing the compressed sequence (308 tokens)
self.narrow_attention_blocks = nn.ModuleList([
TransformerBlock(sdxl_seq_dim, num_heads=num_heads, dropout=dropout)
for _ in range(n_narrow_blocks)
])
# Attention-based pooling using a learnable [CLS]-like token
self.pooling_token = nn.Parameter(torch.randn(1, 1, sdxl_seq_dim))
self.pooling_attention = nn.MultiheadAttention(
embed_dim=sdxl_seq_dim,
num_heads=num_heads,
batch_first=True,
dropout=dropout
)
# Final projection from sequence dimension to pooled embedding dimension
self.pooled_projection = nn.Sequential(
nn.Linear(sdxl_seq_dim, sdxl_seq_dim),
nn.LayerNorm(sdxl_seq_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(sdxl_seq_dim, sdxl_pooled_dim)
)
def forward(self, llm_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
"""
Forward pass of the adapter model.
Args:
llm_hidden_states: Hidden states from LLM of shape (batch_size, seq_len, llm_dim)
attention_mask: Mask indicating valid tokens (1) vs padding (0)
Returns:
Dictionary containing:
- text_embeds: Compressed sequence embeddings [batch_size, target_seq_len, sdxl_seq_dim]
- pooled_text_embeds: Pooled text embeddings [batch_size, sdxl_pooled_dim]
- attention_mask: Output mask for the compressed sequence
- attention_orig: Original input attention mask
"""
batch_size, seq_len, _ = llm_hidden_states.shape
# Project to SDXL dimensionality
hidden_states = self.seq_projection(llm_hidden_states)
# Pad or truncate to maximum input length
if seq_len > self.max_input_len:
hidden_states = hidden_states[:, :self.max_input_len, :]
if attention_mask is not None:
attention_mask = attention_mask[:, :self.max_input_len]
else:
if seq_len < self.max_input_len:
hidden_states = pad_to_length(hidden_states, self.max_input_len, dim=1)
if attention_mask is not None:
attention_mask = pad_to_length(attention_mask, self.max_input_len, dim=1, value=0)
else:
attention_mask = torch.ones(batch_size, self.max_input_len, device=hidden_states.device)
attention_mask[:, seq_len:] = 0
# Add positional embeddings
hidden_states = hidden_states + self.input_position_embeddings
# Stage 1: Process full-length sequence through wide transformer blocks
for block in self.wide_attention_blocks:
hidden_states = block(hidden_states, attention_mask)
# Stage 2: Compress sequence from max_input_len to target_seq_len using cross-attention
queries = self.compression_queries.expand(batch_size, -1, -1)
# Prepare key padding mask (True means ignore this position)
if attention_mask is not None:
key_padding_mask = ~attention_mask.bool()
else:
key_padding_mask = None
compressed_sequence, compression_weights = self.compression_attention(
query=queries,
key=hidden_states,
value=hidden_states,
key_padding_mask=key_padding_mask,
need_weights=True,
average_attn_weights=False
)
# Apply gating mechanism to control information flow
gate_input = torch.cat([queries, compressed_sequence], dim=-1)
gate = self.compression_gate(gate_input)
compressed_sequence = gate * compressed_sequence + (1 - gate) * queries
# Normalize after compression
compressed_sequence = self.compression_norm(compressed_sequence)
# Add positional embeddings for the compressed sequence
compressed_sequence = compressed_sequence + self.output_position_embeddings
# Stage 3: Process compressed sequence through narrow transformer blocks
output_mask = torch.ones(batch_size, self.target_seq_len, device=hidden_states.device)
for block in self.narrow_attention_blocks:
compressed_sequence = block(compressed_sequence, output_mask)
# Stage 4: Generate pooled embeddings using attention-based pooling
pooling_token = self.pooling_token.expand(batch_size, -1, -1)
pooled_output, pooling_weights = self.pooling_attention(
query=pooling_token,
key=compressed_sequence,
value=compressed_sequence,
need_weights=True
)
# Project to final pooled dimension
pooled_embeds = self.pooled_projection(pooled_output.squeeze(1))
return {
"text_embeds": compressed_sequence,
"pooled_text_embeds": pooled_embeds,
"attention_mask": output_mask,
"attention_orig": attention_mask,
}
def pad_to_length(tensor: torch.Tensor, target_length: int, dim: int = 1, value: float = 0) -> torch.Tensor:
"""
Pad a tensor to a target length along a specified dimension.
Args:
tensor: Input tensor to be padded
target_length: Desired length after padding
dim: Dimension along which to pad (default: 1)
value: Value to use for padding (default: 0)
Returns:
Padded tensor of the same dimensions as input except in the padded dimension
"""
current_length = tensor.size(dim)
if current_length >= target_length:
return tensor.narrow(dim, 0, target_length)
pad_size = list(tensor.shape)
pad_size[dim] = target_length - current_length
padding = torch.full(
pad_size,
value,
device=tensor.device,
dtype=tensor.dtype
)
return torch.cat([tensor, padding], dim=dim)
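

# Behaviour sketch for pad_to_length (illustrative only; `t` is a hypothetical tensor):
#   t = torch.randn(2, 5, 8)
#   pad_to_length(t, 7, dim=1).shape   # -> torch.Size([2, 7, 8]), right-padded with zeros
#   pad_to_length(t, 3, dim=1).shape   # -> torch.Size([2, 3, 8]), truncated via narrow()

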
if __name__ == "__main__":
# Example usage and testing
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
adapter_state_dict_path = "./file.safetensors"
dtype = torch.float32
adapter_config = {
"llm_dim": 2304,
"sdxl_seq_dim": 2048,
"sdxl_pooled_dim": 1280,
"max_input_len": 512,
"target_seq_len": 308,
"n_wide_blocks": 3,
"n_narrow_blocks": 3,
"num_heads": 16,
"dropout": 0.05
}
# Initialize the adapter model
adapter = LLMoSDXLAdapter(**adapter_config)
from safetensors.torch import load_file
# Load pre-trained weights
adapter_state_dict = load_file(adapter_state_dict_path)
adapter.load_state_dict(adapter_state_dict, strict=True)
adapter.to(dtype).to(device)
with torch.no_grad():
batch_size = 4
llm_states = {
"llm_hidden_states": torch.rand(batch_size, adapter.max_input_len, 2304, dtype=dtype).to(device),
"attention_mask": torch.ones((batch_size, adapter.max_input_len), dtype=torch.int64).to(device),
}
x = adapter(**llm_states)
text_embeddings = x['text_embeds']
pooled_embeddings = x['pooled_text_embeds']
print(f"Text embeddings shape: {text_embeddings.shape}")
print(f"Pooled shape: {pooled_embeddings.shape}")