import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class ResidualBlock(nn.Module):
    """Feed-forward block (dim -> 2*dim -> dim) with a residual connection and LayerNorm."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # Post-norm residual: LayerNorm(x + MLP(x)); the input shape is preserved.
        return self.norm(x + self.net(x))

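# Illustrative shape check (the dimensions below are arbitrary, not taken from any
# project config): ResidualBlock maps [B, L, D] -> [B, L, D], so it can be stacked
# or mixed with plain Linear layers inside an nn.Sequential without reshaping.
#
# >>> block = ResidualBlock(768)
# >>> x = torch.randn(2, 16, 768)   # [batch, seq_len, hidden]
# >>> block(x).shape
# torch.Size([2, 16, 768])
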
class ResidualAutoencoder(nn.Module):
    """Autoencoder over a frozen pretrained encoder, with width-preserving residual-MLP adapters."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Frozen pretrained encoder: only the adapters, the decoder stack and the
        # LM head below receive gradients.
        print(f"Loading Encoder: {cfg.encoder_name}...")
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        for p in self.encoder.parameters():
            p.requires_grad = False

        # Width-preserving adapters built from ResidualBlocks.
        self.compressor = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            ResidualBlock(self.hidden_dim),
        )

        self.decompressor = nn.Sequential(
            ResidualBlock(self.hidden_dim),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        )

        # Second copy of the same pretrained stack, fine-tuned as the reconstruction
        # "decoder" that reads the decompressed hidden states.
        print(f"Loading Decoder: {cfg.encoder_name}...")
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.decoder.config.is_decoder = False  # bidirectional attention, no causal mask

        # LM head initialised from the encoder's input embeddings, then fine-tuned.
        self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        with torch.no_grad():
            enc_out = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.compressor(enc_out)

    def decode(self, z, attention_mask):
        h = self.decompressor(z)
        dec_out = self.decoder(inputs_embeds=h, attention_mask=attention_mask).last_hidden_state
        return self.lm_head(dec_out)

    def forward(self, input_ids, attention_mask):
        z = self.encode(input_ids, attention_mask)
        logits = self.decode(z, attention_mask)
        return logits, z

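# Minimal usage sketch for ResidualAutoencoder (illustrative only; the checkpoint
# name is an assumption, not the project's config). `cfg` just needs to expose an
# `encoder_name` attribute. Logits come back at the encoder's vocabulary size and
# the latent keeps the encoder's hidden width:
#
# >>> from types import SimpleNamespace
# >>> model = ResidualAutoencoder(SimpleNamespace(encoder_name="jinaai/jina-embeddings-v2-base-en"))
# >>> ids = torch.randint(0, model.encoder.config.vocab_size, (2, 16))
# >>> mask = torch.ones_like(ids)
# >>> logits, z = model(ids, mask)   # logits: [2, 16, vocab], z: [2, 16, hidden]
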
class ReshapedAutoencoder(nn.Module):
    """
    Sequence-to-Sequence Autoencoder with Spherical Latent Space.
    Logic: Token -> Jina -> Linear -> Linear -> Decoder -> Token
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Target norm of the latent vectors (see `encode`).
        self.latent_scale = getattr(cfg, "latent_scale", 10.0)

        # Frozen pretrained encoder.
        print(f"Loading Pretrained Encoder: {cfg.encoder_name}...")
        self.encoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.hidden_dim = self.encoder.config.hidden_size
        self.vocab_size = self.encoder.config.vocab_size

        for param in self.encoder.parameters():
            param.requires_grad = False

        # Trainable bottleneck: hidden_dim -> latent_dim.
        self.compress = nn.Sequential(
            nn.Linear(self.hidden_dim, cfg.latent_dim),
            nn.GELU(),
            nn.Linear(cfg.latent_dim, cfg.latent_dim),
            nn.LayerNorm(cfg.latent_dim),
        )

        # Trainable expansion back to the decoder width: latent_dim -> hidden_dim.
        self.decompress = nn.Sequential(
            nn.Linear(cfg.latent_dim, self.hidden_dim),
            nn.GELU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim),
        )

        # Second copy of the pretrained stack, fine-tuned as the reconstruction decoder.
        print(f"Loading Pretrained Decoder: {cfg.encoder_name}...")
        self.decoder = AutoModel.from_pretrained(cfg.encoder_name, trust_remote_code=True)
        self.decoder.config.is_decoder = False  # bidirectional attention, no causal mask

        # LM head initialised from the encoder's input embeddings, then fine-tuned.
        self.lm_head = nn.Linear(self.hidden_dim, self.encoder.config.vocab_size, bias=False)
        with torch.no_grad():
            self.lm_head.weight.copy_(self.encoder.embeddings.word_embeddings.weight)
        self.lm_head.weight.requires_grad = True

    def encode(self, input_ids, attention_mask):
        """
        Input: [B, L]
        Output: [B, L, Latent_Dim]
        """
        with torch.no_grad():
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        z = self.compress(outputs.last_hidden_state)

        # Spherical latent space (see class docstring): project each latent vector
        # onto a hypersphere of radius `latent_scale`.
        z = F.normalize(z, dim=-1) * self.latent_scale

        return z

    def decode(self, latents, attention_mask=None):
        """
        Input: [B, L, Latent_Dim]
        Output: [B, L, Vocab]
        """
        # Expand the latents back to the decoder's hidden width.
        hidden = self.decompress(latents)

        # Feed the decompressed states through the trainable decoder stack as input
        # embeddings, bypassing its token embedding layer.
        decoder_outputs = self.decoder(
            inputs_embeds=hidden,
            attention_mask=attention_mask,
        )
        sequence_output = decoder_outputs.last_hidden_state

        # Project to vocabulary logits.
        return self.lm_head(sequence_output)

    def forward(self, input_ids, encoder_mask, decoder_mask=None):
        if decoder_mask is None:
            decoder_mask = encoder_mask
        z = self.encode(input_ids, encoder_mask)
        logits = self.decode(z, attention_mask=decoder_mask)
        return logits, z
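

# Hedged usage sketch: a quick smoke test of the ReshapedAutoencoder reconstruction
# pipeline. The checkpoint name, latent_dim and latent_scale below are illustrative
# assumptions, not values taken from the project's actual config object.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    cfg = SimpleNamespace(
        encoder_name="jinaai/jina-embeddings-v2-base-en",  # assumed checkpoint
        latent_dim=128,                                    # assumed bottleneck width
        latent_scale=10.0,
    )

    tokenizer = AutoTokenizer.from_pretrained(cfg.encoder_name, trust_remote_code=True)
    model = ReshapedAutoencoder(cfg)

    batch = tokenizer(
        ["a tiny reconstruction example", "another short sentence"],
        padding=True,
        return_tensors="pt",
    )
    logits, z = model(batch["input_ids"], batch["attention_mask"])

    # Token-level reconstruction loss, ignoring padded positions.
    keep = batch["attention_mask"].reshape(-1).bool()
    loss = F.cross_entropy(
        logits.reshape(-1, model.vocab_size)[keep],
        batch["input_ids"].reshape(-1)[keep],
    )
    print(f"latents: {tuple(z.shape)}, logits: {tuple(logits.shape)}, loss: {loss.item():.3f}")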