import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import datetime
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.normalization import GroupNorm
import base64
import numpy as np
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1)]
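# Usage sketch for the positional encoding (shapes are assumptions based on the
# defaults used below; the autoencoder applies it after projecting tokens to d_model):
#
#     pe = PositionalEncoding(d_model=512)
#     x = torch.zeros(2, 196, 512)
#     x = pe(x)  # unchanged shape; rows 0..195 of the sinusoidal table are added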
class AttentionAutoencoder(nn.Module):
    def __init__(self, input_dim=768, output_dim=1280, d_model=512, latent_dim=20, seq_len=196, num_heads=4, num_layers=3, out_intermediate=512):
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_dim = input_dim  # per-token feature dimension (default 768)
self.d_model = d_model
self.latent_dim = latent_dim
        self.seq_len = seq_len  # number of input tokens (default 196, i.e. a 14x14 grid)
self.out_intermediate = out_intermediate
self.output_dim = output_dim
# Positional Encoding
self.pos_encoder = PositionalEncoding(d_model)
        # Input projection from input_dim (768) to d_model (512)
self.input_proj = nn.Linear(input_dim, d_model)
# Latent Initialization
self.latent_init = nn.Parameter(torch.randn(1, d_model))
# Cross-Attention Encoder
self.num_layers = num_layers
self.attention_layers = nn.ModuleList([
nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
for _ in range(num_layers)
])
# Latent Space Refinement
self.latent_proj = nn.Linear(d_model, latent_dim)
self.latent_norm = nn.LayerNorm(latent_dim)
self.latent_to_d_model = nn.Linear(latent_dim, d_model)
        # Transformer decoder that maps the latent memory back to an output token sequence
self.transformer_decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, batch_first=True),
num_layers=2
)
# Output projection
self.output_proj = nn.Linear(d_model, output_dim)
self.tgt_init = nn.Parameter(torch.randn(1, d_model))
def encode(self, src):
# src shape: [batch_size, seq_len (196), input_dim (768)]
batch_size, seq_len, input_dim = src.shape
# Project input_dim (768) to d_model (512)
src = self.input_proj(src) # Shape: [batch_size, seq_len (196), d_model (512)]
src = self.pos_encoder(src) # Add positional encoding
# Latent initialization
latent = self.latent_init.repeat(batch_size, 1).unsqueeze(1) # Shape: [batch_size, 1, d_model]
# Cross-attend latent with input sequence
for i in range(self.num_layers):
latent, _ = self.attention_layers[i](latent, src, src)
# Project to latent dimension and normalize
latent = self.latent_proj(latent.squeeze(1)) # Shape: [batch_size, latent_dim]
latent = self.latent_norm(latent)
return latent
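    # Encoding alone yields the compact, LayerNorm'd latent, e.g. (assumed
    # instance name `model` and the default dimensions):
    #
    #     z = model.encode(torch.randn(4, 196, 768))  # -> shape [4, 20]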
def decode(self, latent, seq_w, seq_h):
batch_size = latent.size(0)
target_seq_len = seq_w * seq_h
# Project latent_dim back to d_model
memory = self.latent_to_d_model(latent).unsqueeze(1) # Shape: [batch_size, 1, d_model]
# Target initialization
# Repeat the learned target initialization to match the target sequence length
tgt = self.tgt_init.repeat(batch_size, target_seq_len, 1) # Shape: [batch_size, target_seq_len, d_model]
# Apply positional encoding
tgt = self.pos_encoder(tgt)
# Apply transformer decoder
output = self.transformer_decoder(tgt, memory) # Shape: [batch_size, target_seq_len, d_model]
# Project to output_dim
output = self.output_proj(output) # Shape: [batch_size, target_seq_len, output_dim]
# Reshape output to (batch_size, seq_w, seq_h, output_dim)
output = output.view(batch_size, seq_w, seq_h, self.output_dim)
# Permute dimensions to (batch_size, output_dim, seq_w, seq_h)
output = output.permute(0, 3, 1, 2) # Shape: [batch_size, output_dim, seq_w, seq_h]
return output
def forward(self, src, seq_w, seq_h):
latent = self.encode(src)
output = self.decode(latent, seq_w, seq_h)
return output
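    # Shape flow of a full forward pass with the defaults (ViT-style tokens are
    # an assumption; any [B, seq_len, input_dim] input works):
    #
    #     src [B, 196, 768] --input_proj--> [B, 196, 512]
    #         --cross-attention read-out--> latent [B, 20]
    #         --decode(seq_w, seq_h)------> output [B, 1280, seq_w, seq_h]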
    def encode_to_base64(self, latent_vector, bits_per_element):
        # Quantise the latent (assumed to lie roughly in [-1, 1]; values outside
        # that range are clipped) to bits_per_element-bit integers stored as uint8
        max_int = 2 ** bits_per_element - 1
        q_latent = ((latent_vector + 1) * (max_int / 2)).clip(0, max_int).astype(np.uint8)
byte_array = q_latent.tobytes()
encoded_string = base64.b64encode(byte_array).decode('utf-8')
# Remove padding characters
return encoded_string.rstrip('=')
def decode_from_base64(self, encoded_string, bits_per_element, latentdim):
# Add back padding if it's missing
missing_padding = len(encoded_string) % 4
if missing_padding:
encoded_string += '=' * (4 - missing_padding)
byte_array = base64.b64decode(encoded_string)
q_latent = np.frombuffer(byte_array, dtype=np.uint8)[:latentdim]
max_int = 2 ** bits_per_element - 1
latent_vector = q_latent.astype(np.float32) * 2 / max_int - 1
return latent_vector
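    # Worked quantisation example (illustrative values): with bits_per_element = 6,
    # max_int = 63, so a latent value of 0.5 maps to int((0.5 + 1) * 31.5) = 47,
    # and decoding gives 47 * 2 / 63 - 1 ≈ 0.492. The round trip is therefore
    # lossy, with error bounded by the 6-bit quantisation step.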
def forward_encoding(self, src, seq_w, seq_h):
"""
Encodes the input `src` into a latent representation, encodes it to a Base64 string,
decodes it back to the latent space, and then decodes it to the output.
Args:
src: The input data to encode.
Returns:
output: The decoded output from the latent representation.
"""
# Step 1: Encode the input to latent space
        latent = self.encode(src)  # latent has shape (batch_size, latent_dim)
batch_size, latentdim = latent.shape
        # Derive the per-element bit depth from a fixed ~120-bit code budget
        bits_per_element = int(120 / latentdim)  # e.g. latentdim = 20 -> 6 bits per element
if bits_per_element > 8:
raise ValueError("bits_per_element cannot exceed 8 when using uint8 for encoding.")
encoded_strings = []
# Step 2: Encode each latent vector to a Base64 string
for i in range(batch_size):
            latent_vector = latent[i].detach().cpu().numpy()
encoded_string = self.encode_to_base64(latent_vector, bits_per_element)
encoded_strings.append(encoded_string)
decoded_latents = []
# Step 3: Decode each Base64 string back to the latent vector
        for encoded_string in encoded_strings:
            decoded_latent = self.decode_from_base64(encoded_string, bits_per_element, latentdim)
decoded_latents.append(decoded_latent)
        # Step 4: Stack the decoded latents back into a tensor on the original device
        decoded_latents = torch.from_numpy(np.stack(decoded_latents)).to(dtype=latent.dtype, device=latent.device)
        # Step 5: Decode the latent tensor into the output
        output = self.decode(decoded_latents, seq_w, seq_h)
return output, encoded_strings
    def forward_from_stylecode(self, stylecode, seq_w, seq_h, dtype, device):
        # Decode a single Base64 style code back to its latent vector and then to
        # the output feature map. The quantisation settings must match the ones
        # used when the code was produced (20-dim latent, 6 bits per element).
        latentdim = 20
        bits_per_element = 6
        decoded_latent = self.decode_from_base64(stylecode, bits_per_element, latentdim)
        decoded_latents = torch.from_numpy(decoded_latent).unsqueeze(0).to(dtype=dtype, device=device)
        output = self.decode(decoded_latents, seq_w, seq_h)
return output
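    # Size check: each 6-bit value is stored in its own uint8 byte, so the 20-dim
    # latent occupies 20 bytes; Base64 encodes that to 28 characters, and stripping
    # the single '=' padding character leaves a 27-character style code.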
    @torch.no_grad()
    def make_stylecode(self, src):
        # Note: this helper assumes a CUDA device and moves both the model and
        # the input there before encoding.
        src = src.to("cuda")
        self.to("cuda")
        latent = self.encode(src)  # latent has shape (batch_size, latent_dim)
batch_size, latentdim = latent.shape
        # Derive the per-element bit depth from a fixed ~120-bit code budget
        bits_per_element = int(120 / latentdim)  # e.g. latentdim = 20 -> 6 bits per element
if bits_per_element > 8:
raise ValueError("bits_per_element cannot exceed 8 when using uint8 for encoding.")
encoded_strings = []
# Step 2: Encode each latent vector to a Base64 string
for i in range(batch_size):
latent_vector = latent[i].cpu().numpy()
encoded_string = self.encode_to_base64(latent_vector, bits_per_element)
encoded_strings.append(encoded_string)
        return encoded_strings
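

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumed usage, not part of the original module):
    # random tokens stand in for real [batch, 196, 768] features.
    model = AttentionAutoencoder()
    src = torch.randn(2, 196, 768)

    with torch.no_grad():
        out = model(src, seq_w=16, seq_h=16)
        print(out.shape)                      # torch.Size([2, 1280, 16, 16])

        out2, codes = model.forward_encoding(src, seq_w=16, seq_h=16)
        print(len(codes), len(codes[0]))      # 2 style codes, 27 characters each

        # make_stylecode / forward_from_stylecode move the model to CUDA,
        # so only exercise them when a GPU is present.
        if torch.cuda.is_available():
            code = model.make_stylecode(src)[0]
            feat = model.forward_from_stylecode(code, 16, 16, torch.float32, "cuda")
            print(feat.shape)                 # torch.Size([1, 1280, 16, 16])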