|
|
""" |
|
|
Module implementing an adapter model that transforms LLM hidden states into SDXL UNET conditions. |
|
|
|
|
|
This module defines a neural network that converts hidden states from LLMs (or other text
encoders) into conditioning vectors for SDXL's UNET, filling the role of the CLIP text
encoders more efficiently. The adapter uses a multi-stage transformer-based approach to
process and compress sequences while preserving semantic information.
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Dict, Optional
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module): |
|
|
"""A standard transformer block with layer normalization, attention, and MLP layers.""" |
|
|
|
|
|
def __init__(self, dim: int, num_heads: int = 16, mlp_ratio: float = 4.0, dropout: float = 0.0): |
|
|
""" |
|
|
Initialize the transformer block. |
|
|
|
|
|
Args: |
|
|
dim: Dimension of the input features |
|
|
num_heads: Number of attention heads |
|
|
mlp_ratio: Ratio to determine the hidden dimension of the MLP |
|
|
dropout: Dropout probability for regularization |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
self.norm1 = nn.LayerNorm(dim) |
|
|
self.attn = nn.MultiheadAttention( |
|
|
embed_dim=dim, |
|
|
num_heads=num_heads, |
|
|
batch_first=True, |
|
|
dropout=dropout |
|
|
) |
|
|
|
|
|
self.norm2 = nn.LayerNorm(dim) |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(dim, int(dim * mlp_ratio)), |
|
|
nn.GELU(), |
|
|
nn.Linear(int(dim * mlp_ratio), dim) |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
""" |
|
|
Forward pass through the transformer block. |
|
|
|
|
|
Args: |
|
|
x: Input tensor of shape (batch_size, sequence_length, dim) |
|
|
mask: Attention mask where 1 indicates valid tokens and 0 indicates padding |
|
|
|
|
|
Returns: |
|
|
Output tensor of same shape as input |
|
|
""" |
|
|
|
|
|
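        # Pre-norm self-attention sublayer with a residual connection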
normed_x = self.norm1(x) |
|
|
|
|
|
if mask is not None: |
|
|
|
|
|
|
|
|
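            # nn.MultiheadAttention expects True at positions to ignore, so invert the mask (1 = valid -> True = padding)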
key_padding_mask = ~mask.bool() |
|
|
else: |
|
|
key_padding_mask = None |
|
|
|
|
|
attn_output, _ = self.attn( |
|
|
query=normed_x, |
|
|
key=normed_x, |
|
|
value=normed_x, |
|
|
key_padding_mask=key_padding_mask, |
|
|
need_weights=False |
|
|
) |
|
|
x = x + attn_output |
|
|
|
|
|
|
|
|
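        # Pre-norm feed-forward sublayer with a residual connection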
x = x + self.mlp(self.norm2(x)) |
|
|
|
|
|
return x |
|
|
|
|
|
class LLMoSDXLAdapter(nn.Module): |
|
|
"""An adapter model that transforms LLM hidden states into SDXL UNET conditions. |
|
|
|
|
|
This class implements a multi-stage transformer architecture that efficiently converts |
|
|
hidden states from language models (like T5-Gemma) into conditioning vectors suitable |
|
|
for Stable Diffusion XL's UNET. The adapter processes the full sequence through wide |
|
|
transformer blocks, compresses it via cross-attention, and then refines the compressed |
|
|
representation through narrow transformer blocks. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
llm_dim: int = 2304, |
|
|
sdxl_seq_dim: int = 2048, |
|
|
sdxl_pooled_dim: int = 1280, |
|
|
max_input_len: int = 512, |
|
|
target_seq_len: int = 308, |
|
|
n_wide_blocks: int = 3, |
|
|
n_narrow_blocks: int = 3, |
|
|
num_heads: int = 16, |
|
|
dropout: float = 0.05 |
|
|
): |
|
|
""" |
|
|
Initialize the adapter with specified dimensions and hyperparameters. |
|
|
|
|
|
Args: |
|
|
llm_dim: Dimension of input LLM hidden states |
|
|
sdxl_seq_dim: Target dimension for SDXL sequence embeddings |
|
|
sdxl_pooled_dim: Target dimension for SDXL pooled embeddings |
|
|
max_input_len: Maximum length of input sequences |
|
|
target_seq_len: Desired length after compression (e.g., 77*4=308) |
|
|
n_wide_blocks: Number of transformer blocks for full-sequence processing |
|
|
n_narrow_blocks: Number of transformer blocks for compressed-sequence processing |
|
|
num_heads: Number of attention heads in all multi-head attention layers |
|
|
dropout: Dropout probability for regularization throughout the network |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
self.max_input_len = max_input_len |
|
|
self.target_seq_len = target_seq_len |
|
|
self.num_heads = num_heads |
|
|
|
|
|
|
|
|
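        # Project LLM hidden states into the SDXL sequence dimension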
self.seq_projection = nn.Linear(llm_dim, sdxl_seq_dim) |
|
|
|
|
|
|
|
|
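        # Learned positional embeddings for the padded input and the compressed output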
self.input_position_embeddings = nn.Parameter( |
|
|
torch.randn(1, max_input_len, sdxl_seq_dim) |
|
|
) |
|
|
self.output_position_embeddings = nn.Parameter( |
|
|
torch.randn(1, target_seq_len, sdxl_seq_dim) |
|
|
) |
|
|
|
|
|
|
|
|
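        # Wide stage: self-attention blocks over the full input sequence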
self.wide_attention_blocks = nn.ModuleList([ |
|
|
TransformerBlock(sdxl_seq_dim, num_heads=num_heads, dropout=dropout) |
|
|
for _ in range(n_wide_blocks) |
|
|
]) |
|
|
|
|
|
|
|
|
|
|
|
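        # Compression stage: learned queries cross-attend to the full sequence,
        # reducing it to target_seq_len tokens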
self.compression_queries = nn.Parameter( |
|
|
torch.randn(1, target_seq_len, sdxl_seq_dim) |
|
|
) |
|
|
self.compression_attention = nn.MultiheadAttention( |
|
|
embed_dim=sdxl_seq_dim, |
|
|
num_heads=num_heads, |
|
|
batch_first=True, |
|
|
dropout=dropout |
|
|
) |
|
|
|
|
|
|
|
|
self.compression_norm = nn.LayerNorm(sdxl_seq_dim) |
|
|
|
|
|
|
|
|
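        # Gate that blends the attended output with the query embeddings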
self.compression_gate = nn.Sequential( |
|
|
nn.Linear(sdxl_seq_dim * 2, sdxl_seq_dim), |
|
|
nn.Sigmoid() |
|
|
) |
|
|
|
|
|
|
|
|
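        # Narrow stage: self-attention blocks that refine the compressed sequence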
self.narrow_attention_blocks = nn.ModuleList([ |
|
|
TransformerBlock(sdxl_seq_dim, num_heads=num_heads, dropout=dropout) |
|
|
for _ in range(n_narrow_blocks) |
|
|
]) |
|
|
|
|
|
|
|
|
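        # Attention pooling: a single learned token summarizes the compressed sequence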
self.pooling_token = nn.Parameter(torch.randn(1, 1, sdxl_seq_dim)) |
|
|
self.pooling_attention = nn.MultiheadAttention( |
|
|
embed_dim=sdxl_seq_dim, |
|
|
num_heads=num_heads, |
|
|
batch_first=True, |
|
|
dropout=dropout |
|
|
) |
|
|
|
|
|
|
|
|
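        # Project the pooled summary to SDXL's pooled-embedding dimension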
self.pooled_projection = nn.Sequential( |
|
|
nn.Linear(sdxl_seq_dim, sdxl_seq_dim), |
|
|
nn.LayerNorm(sdxl_seq_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(sdxl_seq_dim, sdxl_pooled_dim) |
|
|
) |
|
|
|
|
|
def forward(self, llm_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Forward pass of the adapter model. |
|
|
|
|
|
Args: |
|
|
llm_hidden_states: Hidden states from LLM of shape (batch_size, seq_len, llm_dim) |
|
|
attention_mask: Mask indicating valid tokens (1) vs padding (0) |
|
|
|
|
|
Returns: |
|
|
Dictionary containing: |
|
|
- text_embeds: Compressed sequence embeddings [batch_size, target_seq_len, sdxl_seq_dim] |
|
|
- pooled_text_embeds: Pooled text embeddings [batch_size, sdxl_pooled_dim] |
|
|
- attention_mask: Output mask for the compressed sequence |
|
|
            - attention_orig: Input attention mask, truncated or padded to max_input_len
|
|
""" |
|
|
batch_size, seq_len, _ = llm_hidden_states.shape |
|
|
|
|
|
|
|
|
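        # Project to the SDXL sequence dimension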
hidden_states = self.seq_projection(llm_hidden_states) |
|
|
|
|
|
|
|
|
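        # Truncate or pad the sequence (and its mask) to exactly max_input_len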
if seq_len > self.max_input_len: |
|
|
hidden_states = hidden_states[:, :self.max_input_len, :] |
|
|
if attention_mask is not None: |
|
|
attention_mask = attention_mask[:, :self.max_input_len] |
|
|
else: |
|
|
if seq_len < self.max_input_len: |
|
|
hidden_states = pad_to_length(hidden_states, self.max_input_len, dim=1) |
|
|
if attention_mask is not None: |
|
|
attention_mask = pad_to_length(attention_mask, self.max_input_len, dim=1, value=0) |
|
|
else: |
|
|
attention_mask = torch.ones(batch_size, self.max_input_len, device=hidden_states.device) |
|
|
attention_mask[:, seq_len:] = 0 |
|
|
|
|
|
|
|
|
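        # Add learned input positional embeddings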
hidden_states = hidden_states + self.input_position_embeddings |
|
|
|
|
|
|
|
|
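        # Full-sequence processing through the wide transformer blocks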
for block in self.wide_attention_blocks: |
|
|
hidden_states = block(hidden_states, attention_mask) |
|
|
|
|
|
|
|
|
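        # Compression: learned queries cross-attend over the processed sequence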
queries = self.compression_queries.expand(batch_size, -1, -1) |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
key_padding_mask = ~attention_mask.bool() |
|
|
else: |
|
|
key_padding_mask = None |
|
|
|
|
|
compressed_sequence, compression_weights = self.compression_attention( |
|
|
query=queries, |
|
|
key=hidden_states, |
|
|
value=hidden_states, |
|
|
key_padding_mask=key_padding_mask, |
|
|
need_weights=True, |
|
|
average_attn_weights=False |
|
|
) |
|
|
|
|
|
|
|
|
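        # Gated residual: mix the attended content with the query embeddings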
gate_input = torch.cat([queries, compressed_sequence], dim=-1) |
|
|
gate = self.compression_gate(gate_input) |
|
|
compressed_sequence = gate * compressed_sequence + (1 - gate) * queries |
|
|
|
|
|
|
|
|
compressed_sequence = self.compression_norm(compressed_sequence) |
|
|
|
|
|
|
|
|
compressed_sequence = compressed_sequence + self.output_position_embeddings |
|
|
|
|
|
|
|
|
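        # The compressed sequence contains no padding, so its mask is all ones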
output_mask = torch.ones(batch_size, self.target_seq_len, device=hidden_states.device) |
|
|
|
|
|
for block in self.narrow_attention_blocks: |
|
|
compressed_sequence = block(compressed_sequence, output_mask) |
|
|
|
|
|
|
|
|
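        # Attention-pool the compressed sequence into a single summary token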
pooling_token = self.pooling_token.expand(batch_size, -1, -1) |
|
|
pooled_output, pooling_weights = self.pooling_attention( |
|
|
query=pooling_token, |
|
|
key=compressed_sequence, |
|
|
value=compressed_sequence, |
|
|
need_weights=True |
|
|
) |
|
|
|
|
|
|
|
|
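        # Project the pooled token to the SDXL pooled-embedding dimension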
pooled_embeds = self.pooled_projection(pooled_output.squeeze(1)) |
|
|
|
|
|
return { |
|
|
"text_embeds": compressed_sequence, |
|
|
"pooled_text_embeds": pooled_embeds, |
|
|
"attention_mask": output_mask, |
|
|
"attention_orig": attention_mask, |
|
|
} |
|
|
|
|
|
def pad_to_length(tensor: torch.Tensor, target_length: int, dim: int = 1, value: float = 0) -> torch.Tensor: |
|
|
""" |
|
|
Pad a tensor to a target length along a specified dimension. |
|
|
|
|
|
Args: |
|
|
tensor: Input tensor to be padded |
|
|
target_length: Desired length after padding |
|
|
dim: Dimension along which to pad (default: 1) |
|
|
value: Value to use for padding (default: 0) |
|
|
|
|
|
Returns: |
|
|
Padded tensor of the same dimensions as input except in the padded dimension |
|
|
""" |
|
|
current_length = tensor.size(dim) |
|
|
|
|
|
if current_length >= target_length: |
|
|
return tensor.narrow(dim, 0, target_length) |
|
|
|
|
|
pad_size = list(tensor.shape) |
|
|
pad_size[dim] = target_length - current_length |
|
|
|
|
|
padding = torch.full( |
|
|
pad_size, |
|
|
value, |
|
|
device=tensor.device, |
|
|
dtype=tensor.dtype |
|
|
) |
|
|
|
|
|
return torch.cat([tensor, padding], dim=dim) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
adapter_state_dict_path = "./file.safetensors" |
|
|
dtype = torch.float32 |
|
|
|
|
|
adapter_config = { |
|
|
"llm_dim": 2304, |
|
|
"sdxl_seq_dim": 2048, |
|
|
"sdxl_pooled_dim": 1280, |
|
|
"max_input_len": 512, |
|
|
"target_seq_len": 308, |
|
|
"n_wide_blocks": 3, |
|
|
"n_narrow_blocks": 3, |
|
|
"num_heads": 16, |
|
|
"dropout": 0.05 |
|
|
} |
|
|
|
|
|
|
|
|
adapter = LLMoSDXLAdapter(**adapter_config) |
|
|
|
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
|
adapter_state_dict = load_file(adapter_state_dict_path) |
|
|
adapter.load_state_dict(adapter_state_dict, strict=True) |
|
|
|
|
|
    adapter.to(dtype).to(device)
    adapter.eval()  # disable dropout for deterministic inference
|
|
|
|
|
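    # Smoke test: run the adapter on random inputs and check the output shapes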
with torch.no_grad(): |
|
|
batch_size = 4 |
|
|
llm_states = { |
|
|
"llm_hidden_states": torch.rand(batch_size, adapter.max_input_len, 2304, dtype=dtype).to(device), |
|
|
"attention_mask": torch.ones((batch_size, adapter.max_input_len), dtype=torch.int64).to(device), |
|
|
} |
|
|
|
|
|
x = adapter(**llm_states) |
|
|
|
|
|
text_embeddings = x['text_embeds'] |
|
|
pooled_embeddings = x['pooled_text_embeds'] |
|
|
|
|
|
print(f"Text embeddings shape: {text_embeddings.shape}") |
|
|
print(f"Pooled shape: {pooled_embeddings.shape}") |