import base64
import math

import numpy as np
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a [batch, seq_len, d_model] input."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the encoding for the first x.size(1) positions
        return x + self.pe[:, :x.size(1)]


class AttentionAutoencoder(nn.Module):
    def __init__(self, input_dim=768, output_dim=1280, d_model=512, latent_dim=20, seq_len=196, num_heads=4, num_layers=3, out_intermediate=512):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.input_dim = input_dim  # Adjusted to 768
        self.d_model = d_model
        self.latent_dim = latent_dim
        self.seq_len = seq_len  # Adjusted to 196
        self.out_intermediate = out_intermediate
        self.output_dim = output_dim

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Input Projection (adjusted to project from input_dim=768 to d_model=512)
        self.input_proj = nn.Linear(input_dim, d_model)

        # Latent Initialization
        self.latent_init = nn.Parameter(torch.randn(1, d_model))

        # Cross-Attention Encoder
        self.num_layers = num_layers
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])

        # Latent Space Refinement
        self.latent_proj = nn.Linear(d_model, latent_dim)
        self.latent_norm = nn.LayerNorm(latent_dim)
        self.latent_to_d_model = nn.Linear(latent_dim, d_model)

        # Mapping latent to intermediate feature map
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, batch_first=True),
            num_layers=2
        )

        # Output projection
        self.output_proj = nn.Linear(d_model, output_dim)
        self.tgt_init = nn.Parameter(torch.randn(1, d_model))

    def encode(self, src):
        # src shape: [batch_size, seq_len (196), input_dim (768)]
        batch_size, seq_len, input_dim = src.shape
    
        # Project input_dim (768) to d_model (512)
        src = self.input_proj(src)  # Shape: [batch_size, seq_len (196), d_model (512)]
        src = self.pos_encoder(src)  # Add positional encoding
        
        # Latent initialization
        latent = self.latent_init.repeat(batch_size, 1).unsqueeze(1)  # Shape: [batch_size, 1, d_model]
        
        # Cross-attend latent with input sequence
        for i in range(self.num_layers):
            latent, _ = self.attention_layers[i](latent, src, src)
        
        # Project to latent dimension and normalize
        latent = self.latent_proj(latent.squeeze(1))  # Shape: [batch_size, latent_dim]
        latent = self.latent_norm(latent)
    
        return latent
    
    def decode(self, latent, seq_w, seq_h):
        batch_size = latent.size(0)
        
        target_seq_len = seq_w * seq_h

        # Project latent_dim back to d_model
        memory = self.latent_to_d_model(latent).unsqueeze(1)  # Shape: [batch_size, 1, d_model]

        # Target initialization
        # Repeat the learned target initialization to match the target sequence length
        tgt = self.tgt_init.repeat(batch_size, target_seq_len, 1)  # Shape: [batch_size, target_seq_len, d_model]

        # Apply positional encoding
        tgt = self.pos_encoder(tgt)

        # Apply transformer decoder
        output = self.transformer_decoder(tgt, memory)  # Shape: [batch_size, target_seq_len, d_model]

        # Project to output_dim
        output = self.output_proj(output)  # Shape: [batch_size, target_seq_len, output_dim]

        # Reshape output to (batch_size, seq_w, seq_h, output_dim)
        output = output.view(batch_size, seq_w, seq_h, self.output_dim)

        # Permute dimensions to (batch_size, output_dim, seq_w, seq_h)
        output = output.permute(0, 3, 1, 2)  # Shape: [batch_size, output_dim, seq_w, seq_h]

        return output
    
    def forward(self, src, seq_w, seq_h):
        latent = self.encode(src)
        output = self.decode(latent, seq_w, seq_h)
        return output

    def encode_to_base64(self, latent_vector, bits_per_element):
        # Quantize the latent values from [-1, 1] to integers in [0, 2**bits_per_element - 1].
        # Note: each quantized value is still stored in a full uint8 byte.
        max_int = 2 ** bits_per_element - 1
        q_latent = ((latent_vector + 1) * (max_int / 2)).clip(0, max_int).astype(np.uint8)
        byte_array = q_latent.tobytes()
        encoded_string = base64.b64encode(byte_array).decode('utf-8')
        # Remove padding characters ('=') to keep the style code compact
        return encoded_string.rstrip('=')

    def decode_from_base64(self, encoded_string, bits_per_element, latentdim):
        # Restore the Base64 padding that was stripped during encoding
        missing_padding = len(encoded_string) % 4
        if missing_padding:
            encoded_string += '=' * (4 - missing_padding)
        byte_array = base64.b64decode(encoded_string)
        q_latent = np.frombuffer(byte_array, dtype=np.uint8)[:latentdim]
        # Map the quantized integers back to the [-1, 1] range
        max_int = 2 ** bits_per_element - 1
        latent_vector = q_latent.astype(np.float32) * 2 / max_int - 1
        return latent_vector

    def forward_encoding(self, src, seq_w, seq_h):
        """

        Encodes the input `src` into a latent representation, encodes it to a Base64 string,

        decodes it back to the latent space, and then decodes it to the output.

        

        Args:

            src: The input data to encode.

        

        Returns:

            output: The decoded output from the latent representation.

        """
        # Step 1: Encode the input to latent space
        latent = self.encode(src)  # Shape: [batch_size, latent_dim]
        batch_size, latentdim = latent.shape
        
        # Ensure bits_per_element is appropriate
        bits_per_element = int(120 / latentdim)  # Example: latentdim = 20, bits_per_element = 6
        if bits_per_element > 8:
            raise ValueError("bits_per_element cannot exceed 8 when using uint8 for encoding.")
        
        encoded_strings = []
        
        # Step 2: Encode each latent vector to a Base64 string
        for i in range(batch_size):
            latent_vector = latent[i].cpu().numpy()
            encoded_string = self.encode_to_base64(latent_vector, bits_per_element)
            encoded_strings.append(encoded_string)
        
        decoded_latents = []
        
        # Step 3: Decode each Base64 string back to a latent vector
        for encoded_string in encoded_strings:
            decoded_latent = self.decode_from_base64(encoded_string, bits_per_element, latentdim)
            decoded_latents.append(decoded_latent)
        
        # Step 4: Stack the decoded latents back into a tensor
        decoded_latents = torch.tensor(np.stack(decoded_latents), dtype=latent.dtype, device=latent.device)
        
        # Step 5: Decode the latent tensor into the output
        output = self.decode(decoded_latents, seq_w, seq_h)
        
        return output, encoded_strings
    
    def forward_from_stylecode(self, stylecode, seq_w, seq_h, dtype, device):
        # Fixed settings matching make_stylecode's defaults: 20 latent values at 6 bits each
        latentdim = 20
        bits_per_element = 6

        # Decode the Base64 style code back to a single latent vector
        decoded_latent = self.decode_from_base64(stylecode, bits_per_element, latentdim)

        # Convert to a tensor with a batch dimension of 1
        decoded_latents = torch.tensor(decoded_latent, dtype=dtype, device=device).unsqueeze(0)

        output = self.decode(decoded_latents, seq_w, seq_h)
        return output
    
    @torch.no_grad()
    def make_stylecode(self, src):
        # Move the input and the model to the configured device (GPU if available)
        src = src.to(self.device)
        self.to(self.device)

        latent = self.encode(src)  # Shape: [batch_size, latent_dim]
        batch_size, latentdim = latent.shape

        # Ensure bits_per_element is appropriate
        bits_per_element = int(120 / latentdim)  # Example: latentdim = 20, bits_per_element = 6
        if bits_per_element > 8:
            raise ValueError("bits_per_element cannot exceed 8 when using uint8 for encoding.")

        encoded_strings = []

        # Encode each latent vector to a Base64 string
        for i in range(batch_size):
            latent_vector = latent[i].cpu().numpy()
            encoded_string = self.encode_to_base64(latent_vector, bits_per_element)
            encoded_strings.append(encoded_string)
        return encoded_strings
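

# Minimal usage sketch (not part of the original module): assumes a 196-token,
# 768-dim feature sequence as input and an illustrative 16x16 target grid; the
# dummy tensor and sizes below are examples, not the repository's training setup.
if __name__ == "__main__":
    model = AttentionAutoencoder()
    features = torch.randn(2, 196, 768)  # dummy encoder features: [batch, seq_len, input_dim]

    # Compact Base64 "style codes", one per batch element
    stylecodes = model.make_stylecode(features)  # moves model/input to GPU if available
    print(stylecodes)

    # Reconstruct a [1, output_dim, 16, 16] feature map from a single style code
    device = next(model.parameters()).device
    restored = model.forward_from_stylecode(stylecodes[0], seq_w=16, seq_h=16,
                                            dtype=torch.float32, device=device)
    print(restored.shape)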