APE / APE.py
Nitin2004's picture
Upload full model_ADPB folder
f881e8c verified
import os
import time
import math
import pickle
import random
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# We use Hugging Face’s transformers only for pretrained weight loading and tokenizer.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from dataclasses import dataclass
# ----------------------------
# Helper: ALiBi slopes computation
# ----------------------------
def get_alibi_slopes(n_head):
"""Compute ALiBi slopes for each head.
This implementation follows the approach used in several ALiBi implementations.
"""
def get_slopes_power_of_2(n):
start = 2 ** (-2 ** -(math.log2(n) - 3))
ratio = start
return [start * (ratio ** i) for i in range(n)]
if math.log2(n_head).is_integer():
slopes = get_slopes_power_of_2(n_head)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
slopes = get_slopes_power_of_2(closest_power_of_2)
extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: n_head - closest_power_of_2]
slopes.extend(extra_slopes)
return torch.tensor(slopes, dtype=torch.float32)
# ----------------------------
# Model Components
# ----------------------------
class LayerNorm(nn.Module):
"""LayerNorm with an optional bias."""
def __init__(self, ndim, bias: bool):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
self.use_rope = config.use_rope
self.rope_base = config.rope_base
# Existing APE support.
self.use_ape = getattr(config, 'use_ape', False)
# New: ALiBi support.
self.use_alibi = getattr(config, 'use_alibi', False)
if self.use_alibi and self.use_ape:
raise ValueError("Cannot use both ALiBi and APE simultaneously.")
# For APE, learn a parameter beta.
if self.use_ape:
self.beta = nn.Parameter(torch.tensor(1.0))
# Use Flash Attention if available (but disable when APE is enabled).
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if (not self.flash) or self.use_ape:
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
"""
Args:
x: Input tensor [B, T, C]
return_attn_entropy (bool): If True, return attention entropy.
aggregate_heads (bool): If True, average entropy across heads.
Returns:
y: Output tensor [B, T, C] or (y, entropy)
"""
B, T, C = x.size()
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
head_dim = C // self.n_head
# Reshape to [B, n_head, T, head_dim]
q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
# Optionally, apply RoPE if enabled.
if self.use_rope:
hs = head_dim
d = hs // 2
if self.use_ape:
theta = 1.0 / (self.rope_base ** (2 * torch.arange(0, d, dtype=x.dtype, device=x.device) / hs))
else:
theta = 1.0 / (self.rope_base ** (2 * torch.arange(0, d, dtype=x.dtype, device=x.device) / hs))
t_pos = torch.arange(T, device=x.device, dtype=x.dtype)
freqs = torch.outer(t_pos, theta)
freqs_cos = torch.cos(freqs).unsqueeze(0).unsqueeze(0)
freqs_sin = torch.sin(freqs).unsqueeze(0).unsqueeze(0)
def apply_rope(tensor, cos, sin):
tensor = tensor.reshape(*tensor.shape[:-1], -1, 2)
x0 = tensor[..., 0]
x1 = tensor[..., 1]
x0_rot = x0 * cos - x1 * sin
x1_rot = x0 * sin + x1 * cos
return torch.stack([x0_rot, x1_rot], dim=-1).flatten(start_dim=-2)
q = apply_rope(q, freqs_cos, freqs_sin)
k = apply_rope(k, freqs_cos, freqs_sin)
# Compute scaled dot-product attention scores.
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
# --- Apply positional biases ---
if self.use_alibi:
slopes = get_alibi_slopes(self.n_head).to(x.device) # shape: (n_head,)
rel_positions = torch.arange(T, device=x.device).unsqueeze(0) - torch.arange(T, device=x.device).unsqueeze(1)
alibi_bias = slopes.view(1, self.n_head, 1, 1) * rel_positions.view(1, 1, T, T)
att = att - alibi_bias
elif self.use_ape:
pos_ids = torch.arange(T, device=x.device)
rel_dist = pos_ids.unsqueeze(0) - pos_ids.unsqueeze(1)
abs_rel = rel_dist.abs().float()
temp_matrix = 1.0 / (1.0 + abs_rel)
bias_matrix = - self.beta * torch.log(1.0 + abs_rel)
temp_matrix = temp_matrix.unsqueeze(0).unsqueeze(0)
bias_matrix = bias_matrix.unsqueeze(0).unsqueeze(0)
att = temp_matrix * att + bias_matrix
p_att = F.softmax(att, dim=-1)
entropy = -(p_att * torch.log(p_att + 1e-9)).sum(dim=-1) # [B, n_head, T, T]
if self.flash and not self.use_ape:
y = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=self.dropout if self.training else 0,
is_causal=True
)
else:
if T > self.bias.size(-1):
bias = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
else:
bias = self.bias[:, :, :T, :T]
att = att.masked_fill(bias == 0, float('-inf'))
p_att = F.softmax(att, dim=-1)
entropy = -(p_att * torch.log(p_att + 1e-9)).sum(dim=-1)
att = self.attn_dropout(p_att)
y = att @ v # [B, n_head, T, head_dim]
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.resid_dropout(self.c_proj(y))
if return_attn_entropy:
if aggregate_heads:
entropy = entropy.mean(dim=1) # [B, T, T]
return y, entropy
else:
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
self.gelu = nn.GELU()
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x = self.c_fc(x)
x = self.gelu(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.attn = CausalSelfAttention(config)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.mlp = MLP(config)
def forward(self, x, return_attn_entropy=False, aggregate_heads=False):
if return_attn_entropy:
attn_output, entropy = self.attn(self.ln_1(x), return_attn_entropy=True, aggregate_heads=aggregate_heads)
x = x + attn_output
x = x + self.mlp(self.ln_2(x))
return x, entropy
else:
attn_output = self.attn(self.ln_1(x), return_attn_entropy=False)
x = x + attn_output
x = x + self.mlp(self.ln_2(x))
return x
@dataclass
class GPTConfig:
block_size: int = 128
vocab_size: int = 50304 # For GPT-2
n_layer: int = 6
n_head: int = 6
n_embd: int = 384
dropout: float = 0.0
bias: bool = True
use_rope: bool = True
rope_base: float = 10000.0
use_ape: bool = False
lambda_temp: float = 0.1
use_alibi: bool = False
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
assert config.vocab_size is not None and config.block_size is not None
self.config = config
# If using ALiBi, disable RoPE.
self.use_rope = config.use_rope and not config.use_alibi
print(f"Using RoPE in GPT init: {self.use_rope}")
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = None if self.use_rope else nn.Embedding(config.block_size, config.n_embd),
drop = nn.Dropout(config.dropout),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = LayerNorm(config.n_embd, bias=config.bias),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer.wte.weight = self.lm_head.weight
self.apply(self._init_weights)
for pn, p in self.named_parameters():
if pn.endswith('c_proj.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
def get_num_params(self, non_embedding=True):
n_params = sum(p.numel() for p in self.parameters())
if non_embedding and (not self.use_rope) and (self.transformer.wpe is not None):
n_params -= self.transformer.wpe.weight.numel()
return n_params
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None, return_attn_entropy=False, aggregate_heads=False):
device = idx.device
b, t = idx.size()
pos = torch.arange(0, t, dtype=torch.long, device=device)
tok_emb = self.transformer.wte(idx)
if self.use_rope or self.config.use_alibi:
x = self.transformer.drop(tok_emb)
else:
pos_emb = self.transformer.wpe(pos) if self.transformer.wpe is not None else 0
x = self.transformer.drop(tok_emb + pos_emb)
attn_entropies = []
for block in self.transformer.h:
if return_attn_entropy:
x, entropy = block(x, return_attn_entropy=True, aggregate_heads=aggregate_heads)
attn_entropies.append(entropy)
else:
x = block(x)
x = self.transformer.ln_f(x)
if targets is not None:
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
logits = self.lm_head(x[:, [-1], :])
loss = None
if return_attn_entropy:
return logits, loss, attn_entropies
else:
return logits, loss
@torch.no_grad()
def generate_and_compute_perplexity(self, prompt, ground_truth, temperature=1.0, return_attn_entropy=False, aggregate_heads=False):
if return_attn_entropy:
_, _, attn_entropies = self(prompt, return_attn_entropy=True, aggregate_heads=aggregate_heads)
per_layer_avgs = [entropy.mean().item() for entropy in attn_entropies]
avg_entropy = np.mean(per_layer_avgs)
else:
avg_entropy = None
total_loss = 0.0
total_tokens = 0
prompt_length = prompt.size(1)
num_target_tokens = ground_truth.size(1) - prompt_length
idx = prompt.clone()
for i in range(num_target_tokens):
logits, _ = self(idx)
logits = logits[:, -1, :] / temperature
target = ground_truth[:, prompt_length + i]
loss = F.cross_entropy(logits, target, reduction='sum')
total_loss += loss.item()
total_tokens += target.numel()
target_token = target.unsqueeze(1)
idx = torch.cat((idx, target_token), dim=1)
avg_neg_log_likelihood = total_loss / total_tokens if total_tokens > 0 else float('inf')
perplexity = math.exp(avg_neg_log_likelihood)
return idx, perplexity, avg_entropy
@torch.no_grad()
def generate_until_end(self, idx, temperature=1.0, top_k=None, max_new_tokens=1000):
for i in range(max_new_tokens):
idx_cond = idx
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
if idx_next.item() == 50256:
break
return idx
# ----------------------------
# Utility Functions for Training & Evaluation
# ----------------------------
# Data Loader Functions
train_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/train_tinystories.bin"
val_data_path = "/data1/home/nitinvetcha/Topics in AI/Streamlined/COLM2025/val_tinystories.bin"
def get_batch(split):
data_path = train_data_path if split == 'train' else val_data_path
data = np.memmap(data_path, dtype=np.uint16, mode='r')
total_tokens = len(data)
max_ix = max(1, total_tokens - gptconf.block_size)
ix = torch.randint(0, max_ix, (batch_size,))
X = torch.stack([torch.from_numpy(data[i:i+gptconf.block_size].astype(np.int64)) for i in ix])
Y = torch.stack([torch.from_numpy(data[i+1:i+1+gptconf.block_size].astype(np.int64)) for i in ix])
return X.to(device), Y.to(device)
def evaluate_prompt_perplexity(model, token_file, prompt_length, num_trials, generation_params, device):
tokens = np.fromfile(token_file, dtype=np.uint16)
total_tokens = len(tokens)
perplexities = []
entropy_trials = []
max_new_tokens = generation_params.get("max_new_tokens", 50)
total_length = prompt_length + max_new_tokens
for trial in range(num_trials):
start_idx = random.randint(0, total_tokens - total_length)
sequence_tokens = tokens[start_idx : start_idx + total_length]
prompt_tokens = sequence_tokens[:prompt_length]
ground_truth_tokens = sequence_tokens
prompt_tensor = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0).to(device)
ground_truth_tensor = torch.tensor(ground_truth_tokens, dtype=torch.long).unsqueeze(0).to(device)
_, ppl, trial_entropy = model.generate_and_compute_perplexity(
prompt_tensor, ground_truth_tensor,
temperature=generation_params.get("temperature", 1.0),
return_attn_entropy=True, aggregate_heads=True
)
perplexities.append(ppl)
entropy_trials.append(trial_entropy)
print(f"Trial {trial+1}/{num_trials} for prompt length {prompt_length}: Perplexity = {ppl:.2f}, Avg Entropy = {trial_entropy:.4f}")
avg_ppl = np.mean(perplexities)
avg_entropy = np.mean(entropy_trials)
print(f"Prompt Length {prompt_length} - Avg Perplexity: {avg_ppl:.2f}, Avg Attention Entropy: {avg_entropy:.4f}\n")
return avg_ppl, avg_entropy
# ----------------------------
# Training Loop
# ----------------------------
# Training hyperparameters
batch_size = 12
max_iters = 25001
save_interval = 5000
learning_rate = 6e-4
weight_decay = 1e-1
grad_clip = 1.0
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1337)
# Model configuration: adjust these flags as needed.
model_args = dict(
n_layer=6,
n_head=6,
n_embd=384,
block_size=64, # You can change this as needed.
bias=False,
use_rope=True,
use_ape=True, # Set to True if you want APE.
use_alibi=False, # Set to True to use ALiBi.
rope_base=10000.0,
vocab_size=50304,
dropout=0.0
)
gptconf = GPTConfig(**model_args)
model = GPT(gptconf).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
iter_num = 0
start_time = time.time()
training_losses = []
validation_losses = []
save_iters = []
# Build a flag string for naming: e.g. "rope_ape" or "alibi" etc.
flag_parts = []
if gptconf.use_rope:
flag_parts.append("rope")
if gptconf.use_ape:
flag_parts.append("ape")
if gptconf.use_alibi:
flag_parts.append("alibi")
flag_str = "_".join(flag_parts) if flag_parts else "none"
weight_dir = f"weights_{flag_str}_{gptconf.block_size}"
os.makedirs(weight_dir, exist_ok=True)
while iter_num < max_iters:
X_train, Y_train = get_batch('train')
optimizer.zero_grad()
logits, loss_train = model(X_train, Y_train)
loss_train.backward()
if grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
training_losses.append(loss_train.item())
model.eval()
X_val, Y_val = get_batch('val')
with torch.no_grad():
logits_val, loss_val = model(X_val, Y_val)
validation_losses.append(loss_val.item())
model.train()
if iter_num % 100 == 0:
elapsed = time.time() - start_time
print(f"Iter {iter_num:5d}: train loss = {loss_train.item():.4f}, val loss = {loss_val.item():.4f}, time/iter = {elapsed/(iter_num+1):.4f}s")
if iter_num > 0 and iter_num % save_interval == 0:
save_iters.append(iter_num)
ckpt = {
'iter_num': iter_num,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'training_losses': training_losses,
'validation_losses': validation_losses,
'save_iters': save_iters,
}
ckpt_path = os.path.join(weight_dir, f"ckpt_{iter_num}.pt")
torch.save(ckpt, ckpt_path)
print(f"Checkpoint saved to {ckpt_path}")
iter_num += 1
print("Training complete.")
plt.figure(figsize=(10, 6))
plt.plot(range(len(training_losses)), training_losses, label="Training Loss")
plt.plot(range(len(validation_losses)), validation_losses, label="Validation Loss", alpha=0.7)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Training and Validation Loss per Iteration")
plt.legend()
plt.grid(True)
plt.show()
# ----------------------------
# Perplexity & Entropy Evaluation
# ----------------------------
token_file = val_data_path # Use validation data for evaluation.
prompt_lengths = [64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_trials = 5
generation_params = {"temperature": 1.0, "max_new_tokens": 50}
avg_perplexities = []
avg_entropies = []
for pl in prompt_lengths:
print(f"Evaluating for prompt length: {pl}")
avg_ppl, avg_entropy = evaluate_prompt_perplexity(model, token_file, pl, num_trials, generation_params, device)
avg_perplexities.append(avg_ppl)
avg_entropies.append(avg_entropy)
results = {
"prompt_lengths": prompt_lengths,
"avg_perplexities": avg_perplexities,
"avg_entropies": avg_entropies
}
results_filename = f"results_{flag_str}_{gptconf.block_size}.json"
with open(results_filename, "w") as f:
json.dump(results, f)
print(f"Results saved to {results_filename}")
plt.figure(figsize=(8, 6))
plt.plot(prompt_lengths, avg_perplexities, marker='o')
plt.xlabel("Prompt Length")
plt.ylabel("Avg Generated Perplexity")
plt.title("Avg Generated Perplexity vs Prompt Length")
plt.grid(True)
plt.xscale('log')
plt.savefig(f"avg_generated_perplexity_{flag_str}_{gptconf.block_size}.png")
plt.show()
plt.figure(figsize=(8, 6))
plt.plot(prompt_lengths, avg_entropies, marker='o', color='red')
plt.xlabel("Prompt Length")
plt.ylabel("Avg Attention Entropy")
plt.title("Avg Attention Entropy vs Prompt Length\n(Averaged over Layers)")
plt.grid(True)
plt.xscale('log')
plt.savefig(f"avg_attention_entropy_{flag_str}_{gptconf.block_size}.png")
plt.show()