# ExtLWM / lwm_model.py
# Login2025's picture
# Upload 5 files
# 53ca419 verified
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 19:23:54 2024
This script defines the LWM model architecture together with the patching, tokenization and data-loading helpers that prepare its inputs.
@author: Sadjad Alikhani
"""
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
def create_dataloader(grouped_data, batch_size, shuffle, generator=None):
    """Build one DataLoader per group of masked samples.

    Args:
        grouped_data: Mapping from a group key (e.g. the normalised keys produced
            by ``lwm_tokenizer(..., mask=True)``) to a list of
            ``(input_ids, masked_tokens, masked_pos)`` samples. All samples in a
            group must have matching shapes so they can be stacked.
        batch_size: Batch size used by every loader.
        shuffle: Whether each loader shuffles its samples.
        generator: Optional ``torch.Generator`` forwarded to ``DataLoader`` so
            shuffling can be made reproducible.

    Returns:
        dict mapping each key of ``grouped_data`` to a ``DataLoader`` over a
        ``TensorDataset(input_ids, masked_tokens, masked_pos)`` whose tensors are
        float32, float32 and int64 respectively (memory pinning requested).
    """
    dataloaders = {}
    for seq_length, group in grouped_data.items():
        print(f"dataloader in progress ...\nkey: {seq_length}")
        ## Uncomment the following line if you run out of memory during pre-training
        # batch_size = batch_size // 8 if seq_length >= 5 else batch_size
        # Unpack samples for the current group
        input_ids, masked_tokens, masked_pos = zip(*group)
        # Stack each field into one ndarray before handing it to torch:
        # torch.tensor on a sequence of ndarrays converts element by element,
        # which is very slow (and warns) for large groups.
        input_ids_tensor = torch.tensor(np.asarray(input_ids), dtype=torch.float32)
        masked_tokens_tensor = torch.tensor(np.asarray(masked_tokens), dtype=torch.float32)
        masked_pos_tensor = torch.tensor(np.asarray(masked_pos), dtype=torch.long)
        dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
        dataloaders[seq_length] = DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True, generator=generator
        )
    return dataloaders
def lwm_tokenizer(manual_data, patch_rows, patch_cols, masking_percent=.40, mask=False, mask_pos=None, seed=42):
    """Patch raw channel data and build one [CLS]-prefixed sequence per sample.

    Args:
        manual_data: Channel data converted with ``np.array`` and passed to
            ``patch_maker``, which unpacks it as ``(n_samples, n_rows, n_cols)``.
        patch_rows, patch_cols: Patch size forwarded to ``patch_maker``.
        masking_percent: Fraction of patches masked per sample; the count is
            ``int(masking_percent * n_patches)``.
        mask: If True, build masked-prediction samples; otherwise build plain
            sequences.
        mask_pos: Optional fixed mask positions forwarded to ``make_sample``.
        seed: Forwarded to ``make_sample``. NOTE: ``make_sample`` re-seeds
            NumPy's global RNG with this value for every sample, so with a
            non-None seed and ``mask_pos=None`` every sample gets the same
            randomly chosen mask positions.

    Returns:
        If ``mask`` is True: dict mapping 0, 1, 2, ... (one key per distinct
        sequence length, ascending) to a list of
        ``[input_ids, masked_tokens, masked_pos]`` samples.
        Otherwise: tensor stacking every sample's sequence along a new dim 0.
    """
    # patch_maker returns one (n_samples, n_patches, patch_size) array, so every
    # sample shares the same patch count and size: the special tokens and the
    # mask count are loop-invariant and are built once here.
    patches = patch_maker(np.array(manual_data), patch_rows, patch_cols)
    _, n_patches, patch_size = patches.shape
    n_masks = int(masking_percent * n_patches)
    word2id = {
        '[CLS]': 0.2 * np.ones((patch_size)),
        '[MASK]': 0.1 * np.ones((patch_size))
    }

    grouped_data = defaultdict(list)  # mask=True: samples grouped by sequence length
    unmasked_samples = []             # mask=False: one tensor per sample
    for user_idx in range(len(patches)):
        sample = make_sample(
            user_idx, patches, word2id, n_patches, n_masks, mask_pos, mask=mask, seed=seed
        )
        if mask:
            grouped_data[len(sample[0])].append(sample)
        else:
            unmasked_samples.append(sample)

    if mask:
        # Re-key the groups as 0, 1, 2, ... in ascending order of sequence length.
        return {i: grouped_data[key] for i, key in enumerate(sorted(grouped_data))}
    return torch.stack(unmasked_samples, dim=0)
def make_sample(user_idx, patch, word2id, n_patches, n_masks, mask_pos=None, mask=True, seed=None):
    """Build the [CLS]-prefixed token sequence for one sample, optionally masked.

    Args:
        user_idx: Index of the sample inside ``patch``.
        patch: Indexable collection of per-sample 2-D patch arrays.
        word2id: Dict holding the ``'[CLS]'`` and ``'[MASK]'`` token vectors.
        n_patches: Number of real patches; random mask positions are drawn
            from ``1 .. n_patches`` (position 0 is [CLS]).
        n_masks: How many positions to draw when ``mask_pos`` is None.
        mask_pos: Explicit positions to use instead of a random draw.
        mask: If True, overwrite the chosen positions with the [MASK] vector.
        seed: If not None, re-seeds NumPy's *global* RNG before drawing.

    Returns:
        ``[input_ids, masked_tokens, masked_pos]`` when ``mask`` is True, where
        ``masked_tokens`` holds copies of the original rows; otherwise
        ``torch.tensor(input_ids)`` (unmasked).
    """
    if seed is not None:
        np.random.seed(seed)

    sequence = np.vstack((word2id['[CLS]'], patch[user_idx]))

    n_tokens = int(n_patches)
    if mask_pos is None:
        # Draw distinct positions among the real tokens (never the [CLS] slot).
        chosen = np.random.choice(range(1, n_tokens + 1), size=n_masks, replace=False)
    else:
        chosen = mask_pos

    mask_vector = word2id['[MASK]']
    originals = []
    for position in chosen:
        # Copy before overwriting so the target keeps the original values.
        originals.append(sequence[position].copy())
        if mask:
            sequence[position] = mask_vector

    if not mask:
        return torch.tensor(sequence)
    return [sequence, originals, chosen]
def patch_maker(original_ch, patch_rows=1, patch_cols=16, scale=1e6):
    """Split complex matrices into flattened, real/imag-interleaved patches.

    Args:
        original_ch: Array-like of shape ``(n_samples, n_rows, n_cols)``; real and
            imaginary parts are read via ``.real`` / ``.imag``.
        patch_rows: Matrix rows per patch.
        patch_cols: Complex matrix columns per patch.
        scale: Multiplier applied to the final patches. The default ``1e6`` is
            the factor that used to be hard-coded here.

    Returns:
        float32 array of shape ``(n_samples, n_patches, patch_rows * patch_cols * 2)``
        with ``n_patches = ceil(n_rows / patch_rows) * ceil(n_cols / patch_cols)``.
        Patches are ordered row-block by row-block (all column blocks of the first
        row block first). Each patch is the row-major flattening of its block with
        real and imaginary parts interleaved (``re, im, re, im, ...``). Blocks that
        overhang the matrix are zero-padded.
    """
    original_ch = np.asarray(original_ch)
    n_samples, n_rows, n_cols = original_ch.shape

    # Interleave real and imaginary parts along the last axis.
    interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
    interleaved[:, :, 0::2] = original_ch.real
    interleaved[:, :, 1::2] = original_ch.imag

    n_patches_rows = int(np.ceil(n_rows / patch_rows))
    n_patches_cols = int(np.ceil(n_cols / patch_cols))

    # Zero-pad so both dimensions divide evenly into patches; each complex
    # column occupies two interleaved slots, hence the doubled column padding.
    pad_rows = n_patches_rows * patch_rows - n_rows
    pad_cols = n_patches_cols * patch_cols - n_cols
    if pad_rows > 0 or pad_cols > 0:
        interleaved = np.pad(
            interleaved,
            ((0, 0), (0, pad_rows), (0, pad_cols * 2)),
            mode='constant',
            constant_values=0,
        )

    # Explicit flattened size (not -1) so zero-sample inputs reshape cleanly.
    width = patch_cols * 2
    flat_size = patch_rows * width
    patches = [
        interleaved[:, r:r + patch_rows, c:c + width].reshape(n_samples, flat_size)
        for r in range(0, n_patches_rows * patch_rows, patch_rows)
        for c in range(0, n_patches_cols * width, width)
    ]
    patches = np.stack(patches, axis=1)
    return patches * scale
#%%
class LayerNormalization(nn.Module):
    """Normalise the last dimension, then apply a learnable gain and shift.

    Unlike ``nn.LayerNorm``, this uses ``Tensor.std`` (Bessel-corrected by
    default) and adds ``eps`` to the standard deviation, not to the variance.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Parameter names are state_dict keys; keep them stable for checkpoints.
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        centred = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centred / denom + self.bias
class Embedding(nn.Module):
    """Linear token projection plus learned positional embedding, then normalisation.

    ``forward`` reads dim 1 of its input as the sequence length and projects the
    last dim from ``element_length`` to ``d_model``. Sequences longer than
    ``max_len`` cannot be embedded (the position index would be out of range).
    """

    def __init__(self, element_length, d_model, max_len=513):
        super().__init__()
        self.element_length = element_length
        self.d_model = d_model
        # Construction order fixes parameter-initialisation RNG order; keep it.
        self.proj = nn.Linear(element_length, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.norm = LayerNormalization(d_model)

    def forward(self, x):
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        summed = self.proj(x.float()) + self.pos_embed(positions)
        return self.norm(summed)
class ScaledDotProductAttention(nn.Module):
    """Unmasked attention ``softmax(Q K^T / sqrt(d_k)) V`` over the last two dims.

    Returns ``(context, attention_weights)``.
    """

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        temperature = np.sqrt(self.d_k)
        weights = (Q @ K.transpose(-1, -2) / temperature).softmax(dim=-1)
        return weights @ V, weights
class MultiHeadAttention(nn.Module):
    """Multi-head attention that also adds its query input back as a residual.

    ``forward`` returns ``Q + dropout(linear(merged_heads))`` together with the
    per-head attention weights.
    """

    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        # NOTE(review): floors silently when d_model is not divisible by n_heads.
        head_dim = d_model // n_heads
        self.d_k = head_dim
        self.d_v = head_dim
        self.n_heads = n_heads
        # Construction order fixes parameter-initialisation RNG order; keep it.
        self.W_Q = nn.Linear(d_model, self.d_k * n_heads)
        self.W_K = nn.Linear(d_model, self.d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.linear = nn.Linear(n_heads * self.d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V):
        batch_size = Q.size(0)

        def to_heads(projection, tensor, dim):
            # assumes 3-D (batch, seq, d_model) input:
            # -> (batch, n_heads, seq, dim)
            return projection(tensor).view(batch_size, -1, self.n_heads, dim).transpose(1, 2)

        q_heads = to_heads(self.W_Q, Q, self.d_k)
        k_heads = to_heads(self.W_K, K, self.d_k)
        v_heads = to_heads(self.W_V, V, self.d_v)
        context, attn = self.scaled_dot_attn(q_heads, k_heads, v_heads)
        # Merge the heads back into one feature axis before the output projection.
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        return Q + self.dropout(self.linear(merged)), attn
class PoswiseFeedForwardNet(nn.Module):
    """Feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear on the last dim.

    No residual connection or normalisation is applied here.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # expand to d_ff
        self.fc2 = nn.Linear(d_ff, d_model)   # project back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)
class EncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention then feed-forward, each followed by Add & Norm.

    ``forward`` returns ``(encoded, attention_weights)``.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        # Construction order fixes parameter-initialisation RNG order; keep it.
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff, dropout)
        self.norm1 = LayerNormalization(d_model)
        self.norm2 = LayerNormalization(d_model)

    def forward(self, enc_inputs):
        x = enc_inputs
        attended, attn = self.enc_self_attn(x, x, x)
        # NOTE(review): MultiHeadAttention already returns its query input plus the
        # attention output, so this sum contributes enc_inputs twice. Changing it
        # would alter the outputs of any checkpoint trained with this code.
        x = self.norm1(x + attended)
        x = self.norm2(x + self.pos_ffn(x))
        return x, attn
class lwm(nn.Module):
    """Transformer encoder with an optional masked-token prediction head.

    Tokens of length ``element_length`` are embedded to ``d_model`` (projection +
    learned positions + normalisation) and passed through ``n_layers`` encoder
    layers. When ``masked_pos`` is given, the hidden states at those positions
    are decoded back to ``element_length`` values.
    """

    def __init__(self, element_length=32, d_model=128, n_layers=12, max_len=321, n_heads=8, dropout=0.1):
        super().__init__()
        self.embedding = Embedding(element_length, d_model, max_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_model*4, dropout) for _ in range(n_layers)]
        )
        # Masked-token head: linear -> ReLU -> norm -> decoder (+ bias).
        self.linear = nn.Linear(d_model, d_model)
        self.norm = LayerNormalization(d_model)
        # proj.weight is (d_model, element_length); only its size is used, so the
        # decoder weights are NOT tied to the embedding projection.
        embed_weight = self.embedding.proj.weight
        _, n_dim = embed_weight.size()
        self.decoder = nn.Linear(d_model, n_dim, bias=False)
        self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

    @classmethod
    def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', **model_kwargs):
        """Build a model and load a saved ``state_dict`` into it.

        Args:
            ckpt_name: Checkpoint path passed to ``torch.load``.
            device: Device the model is moved to and the checkpoint is mapped
                onto. The default ``'cuda'`` requires a CUDA-enabled PyTorch.
            **model_kwargs: Optional constructor arguments (``element_length``,
                ``d_model``, ``n_layers``, ...) for checkpoints trained with a
                non-default architecture. Omitting them keeps the old behaviour.

        Returns:
            The model with weights loaded. It is left in training mode (dropout
            active); call ``.eval()`` before inference.
        """
        model = cls(**model_kwargs).to(device)
        # SECURITY: torch.load unpickles the file and can execute arbitrary code.
        # Only load trusted checkpoints (or pass weights_only=True on PyTorch
        # versions that support it).
        model.load_state_dict(torch.load(ckpt_name, map_location=device))
        print(f"Model loaded successfully from {ckpt_name}")
        return model

    def forward(self, input_ids, masked_pos=None):
        """Encode ``input_ids`` and optionally predict the tokens at ``masked_pos``.

        Args:
            input_ids: Token tensor; ``Embedding`` reads dim 1 as the sequence
                length and projects the last dim (``element_length``).
            masked_pos: Optional 2-D integer tensor ``(batch, n_masked)`` of
                positions whose hidden states are decoded.

        Returns:
            ``(logits_lm, output, attention_maps)`` when ``masked_pos`` is given,
            otherwise ``(output, attention_maps)``. ``output`` is the final
            encoder hidden state; ``attention_maps`` has one tensor per layer.
        """
        output = self.embedding(input_ids)
        attention_maps = []
        for layer in self.layers:
            output, attn = layer(output)
            attention_maps.append(attn)
        if masked_pos is not None:
            # Broadcast positions over the hidden dim so gather picks whole vectors.
            masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
            h_masked = torch.gather(output, 1, masked_pos)
            h_masked = self.norm(F.relu(self.linear(h_masked)))
            logits_lm = self.decoder(h_masked) + self.decoder_bias
            return logits_lm, output, attention_maps
        else:
            return output, attention_maps