# stylecodes-sd15-demo / controlnet / attention_autoencoder.py
# Uploaded by CiaraRowles ("Upload 4 files", commit 934bde2, verified); 9.51 kB.
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import datetime
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.normalization import GroupNorm
import base64
import numpy as np
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * f_i)`` and odd columns hold
    ``cos(pos * f_i)``, with frequencies ``f_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        max_len: number of positions precomputed; longer sequences cannot be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # One frequency per (sin, cos) column pair: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: saved in state_dict and moved by .to(), not trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is expected batch-first, i.e. ``(batch, seq, d_model)``.
        """
        return x + self.pe[:, :x.size(1)]
class AttentionAutoencoder(nn.Module):
    """Cross-attention autoencoder that compresses a token sequence to a small latent.

    The encoder has a single learned query token cross-attend over the projected,
    position-encoded input and projects the result to a ``latent_dim`` vector.
    The decoder turns that vector into one memory token and runs a
    ``nn.TransformerDecoder`` over ``seq_w * seq_h`` learned target tokens.
    It returns a feature map of shape ``(batch, output_dim, seq_w, seq_h)``.

    Latents can be serialised to short Base64 "style codes". Each element is
    quantised to ``bits_per_element`` bits and stored in one byte, and the codes
    can be decoded back into latents.
    """

    # Total bit budget of a style code; bits per element = this // latent size.
    _STYLECODE_BITS = 120

    def __init__(self, input_dim=768, output_dim=1280, d_model=512, latent_dim=20,
                 seq_len=196, num_heads=4, num_layers=3, out_intermediate=512):
        super().__init__()
        # Device picked at construction time (CUDA if available, else CPU).
        # make_stylecode() moves the module and its input here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_dim = input_dim
        self.d_model = d_model
        self.latent_dim = latent_dim
        self.seq_len = seq_len  # stored only; not used elsewhere in this class
        self.out_intermediate = out_intermediate  # stored only; not used elsewhere in this class
        self.output_dim = output_dim

        # Positional encoding shared by the encoder input and the decoder targets.
        self.pos_encoder = PositionalEncoding(d_model)
        # Input projection input_dim -> d_model.
        self.input_proj = nn.Linear(input_dim, d_model)
        # Learned query token that summarises the input sequence.
        self.latent_init = nn.Parameter(torch.randn(1, d_model))

        # Cross-attention encoder stack.
        self.num_layers = num_layers
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

        # Bottleneck: d_model -> latent_dim (normalised) -> d_model.
        self.latent_proj = nn.Linear(d_model, latent_dim)
        self.latent_norm = nn.LayerNorm(latent_dim)
        self.latent_to_d_model = nn.Linear(latent_dim, d_model)

        # Decoder mapping the latent memory to a grid of d_model features.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, batch_first=True),
            num_layers=2,
        )

        # Output projection d_model -> output_dim, plus the learned target token.
        self.output_proj = nn.Linear(d_model, output_dim)
        self.tgt_init = nn.Parameter(torch.randn(1, d_model))

    def encode(self, src):
        """Encode ``src`` of shape ``(batch, seq, input_dim)`` into ``(batch, latent_dim)``.

        Raises:
            ValueError: if ``src`` is not 3-D.
        """
        if src.dim() != 3:
            raise ValueError(f"expected src of shape (batch, seq, input_dim), got {tuple(src.shape)}")
        batch_size = src.size(0)
        src = self.pos_encoder(self.input_proj(src))  # (batch, seq, d_model)
        # One query token per sample: (batch, 1, d_model).
        latent = self.latent_init.repeat(batch_size, 1).unsqueeze(1)
        # Each layer replaces the query with its attention read-out over src (no residual).
        for attention in self.attention_layers:
            latent, _ = attention(latent, src, src)
        latent = self.latent_proj(latent.squeeze(1))  # (batch, latent_dim)
        return self.latent_norm(latent)

    def decode(self, latent, seq_w, seq_h):
        """Decode ``(batch, latent_dim)`` latents into ``(batch, output_dim, seq_w, seq_h)``."""
        batch_size = latent.size(0)
        target_seq_len = seq_w * seq_h
        # The decoder attends to a single memory token derived from the latent.
        memory = self.latent_to_d_model(latent).unsqueeze(1)  # (batch, 1, d_model)
        # The same learned token at every target position; positions are told
        # apart only by the positional encoding.
        tgt = self.tgt_init.repeat(batch_size, target_seq_len, 1)  # (batch, seq_w*seq_h, d_model)
        tgt = self.pos_encoder(tgt)
        output = self.transformer_decoder(tgt, memory)  # (batch, seq_w*seq_h, d_model)
        output = self.output_proj(output)  # (batch, seq_w*seq_h, output_dim)
        # Unflatten to a (seq_w, seq_h) grid, then move channels first.
        output = output.view(batch_size, seq_w, seq_h, self.output_dim)
        return output.permute(0, 3, 1, 2)

    def forward(self, src, seq_w, seq_h):
        """Encode ``src`` and decode it to a ``(batch, output_dim, seq_w, seq_h)`` map."""
        latent = self.encode(src)
        return self.decode(latent, seq_w, seq_h)

    @classmethod
    def _bits_per_element(cls, latentdim):
        """Return the per-element bit depth for a style code with ``latentdim`` elements.

        Raises:
            ValueError: if the depth would exceed 8 (one uint8 per element) or be 0.
        """
        bits_per_element = cls._STYLECODE_BITS // latentdim  # e.g. 20 elements -> 6 bits
        if bits_per_element > 8:
            raise ValueError("bits_per_element cannot exceed 8 when using uint8 for encoding.")
        if bits_per_element < 1:
            raise ValueError(
                f"latent dimension {latentdim} is too large for a "
                f"{cls._STYLECODE_BITS}-bit style code."
            )
        return bits_per_element

    def _latents_to_stylecodes(self, latent, bits_per_element):
        """Encode each row of a ``(batch, latent_dim)`` tensor to a Base64 style code."""
        # detach(): the latent may carry autograd history, and .numpy() refuses it.
        # float(): numpy has no bfloat16.
        rows = latent.detach().float().cpu().numpy()
        return [self.encode_to_base64(row, bits_per_element) for row in rows]

    def encode_to_base64(self, latent_vector, bits_per_element):
        """Quantise a 1-D numpy latent to a Base64 string with the '=' padding removed.

        Values are clipped to the range that maps onto ``[0, 2**bits_per_element - 1]``,
        i.e. roughly ``[-1, 1]``, and each element is stored as one byte.
        """
        max_int = 2 ** bits_per_element - 1
        # NOTE: astype truncates (floor after clipping) rather than rounding; kept
        # as-is so the same latent keeps producing the same style code.
        q_latent = ((latent_vector + 1) * (max_int / 2)).clip(0, max_int).astype(np.uint8)
        encoded_string = base64.b64encode(q_latent.tobytes()).decode('utf-8')
        return encoded_string.rstrip('=')

    def decode_from_base64(self, encoded_string, bits_per_element, latentdim):
        """Inverse of encode_to_base64: return a float32 numpy vector of length ``latentdim``.

        Surrounding whitespace (e.g. from a pasted code) is ignored.

        Raises:
            ValueError: if the code decodes to fewer than ``latentdim`` bytes
                (binascii.Error, a ValueError subclass, for malformed Base64).
        """
        encoded_string = encoded_string.strip()
        # Restore the '=' padding that encode_to_base64 strips.
        encoded_string += '=' * (-len(encoded_string) % 4)
        byte_array = base64.b64decode(encoded_string)
        if len(byte_array) < latentdim:
            raise ValueError(
                f"style code decodes to {len(byte_array)} bytes; expected at least {latentdim}"
            )
        q_latent = np.frombuffer(byte_array, dtype=np.uint8)[:latentdim]
        max_int = 2 ** bits_per_element - 1
        return q_latent.astype(np.float32) * 2 / max_int - 1

    def forward_encoding(self, src, seq_w, seq_h):
        """Round-trip ``src`` through the Base64 style-code format, then decode it.

        Args:
            src: input accepted by encode().
            seq_w, seq_h: output grid size passed to decode().

        Returns:
            (output, encoded_strings): the feature map decoded from the quantised
            latents, and one style code per batch item.
        """
        latent = self.encode(src)
        latentdim = latent.size(1)
        bits_per_element = self._bits_per_element(latentdim)
        encoded_strings = self._latents_to_stylecodes(latent, bits_per_element)
        decoded = np.stack([
            self.decode_from_base64(code, bits_per_element, latentdim)
            for code in encoded_strings
        ])
        decoded_latents = torch.from_numpy(decoded).to(dtype=latent.dtype, device=latent.device)
        output = self.decode(decoded_latents, seq_w, seq_h)
        return output, encoded_strings

    def forward_from_stylecode(self, stylecode, seq_w, seq_h, dtyle, device):
        """Decode one style code, or a list of them, into a feature map.

        Args:
            stylecode: a Base64 style code string, or an iterable of them (one per batch item).
            seq_w, seq_h: output grid size.
            dtyle: dtype of the latent tensor fed to decode() (name kept for compatibility).
            device: device of that latent tensor.

        Returns:
            Tensor of shape ``(num_codes, output_dim, seq_w, seq_h)``.
        """
        codes = [stylecode] if isinstance(stylecode, str) else list(stylecode)
        latentdim = self.latent_dim
        bits_per_element = self._bits_per_element(latentdim)  # 6 for the default latent_dim=20
        decoded = np.stack([
            self.decode_from_base64(code, bits_per_element, latentdim) for code in codes
        ])
        decoded_latents = torch.from_numpy(decoded).to(dtype=dtyle, device=device)
        return self.decode(decoded_latents, seq_w, seq_h)

    @torch.no_grad()
    def make_stylecode(self, src):
        """Encode ``src`` and return one Base64 style code per batch item.

        Moves the module and ``src`` to ``self.device`` first. That is CUDA when it
        was available at construction time, otherwise CPU.
        """
        self.to(self.device)
        src = src.to(self.device)
        latent = self.encode(src)
        bits_per_element = self._bits_per_element(latent.size(1))
        return self._latents_to_stylecodes(latent, bits_per_element)