#!/usr/bin/env python3
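"""gemma-embed / train_stage_3.py

Stage 3: self-supervised contrastive fine-tuning of the Gemma LoRA adapters.
Each batch of pretokenized queries is embedded twice; the two embeddings of a
query form a positive pair and every other embedding in the batch acts as a
negative (an NT-Xent / InfoNCE-style loss). Only the LoRA adapters are saved.
"""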
import os
# prevent HF tokenizers threads from hanging the process
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from multiprocessing import freeze_support


def main():
    # --- Config ---
    PRET_FILE = "pretokenized_queries.pt"
    MODEL_NAME = "google/gemma-3-1b-pt"
    LORA_DIR = "phase2_triplet_amp/final"
    BATCH_SIZE = 64
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    TEMP = 0.05
    OUTPUT_DIR = "phase3_self_contrast"
    GRAD_CLIP_NORM = 1.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

    # --- Load pretokenized queries safely ---
    # weights_only=True restricts torch.load to tensors and plain containers,
    # so no arbitrary pickled code is executed when reading the cache file.
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
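    # Each batch yields (input_ids, attention_mask); the other queries in the
    # same batch serve as in-batch negatives for the contrastive loss below.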

    # --- Load base model + LoRA adapters ---
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    # is_trainable=True keeps the loaded adapters unfrozen so they continue
    # training; PEFT freezes adapters loaded for inference by default.
    peft = PeftModel.from_pretrained(base, LORA_DIR, is_trainable=True).to(device)
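    # Only the adapter weights receive gradients; the Gemma backbone stays
    # frozen, and the LM head is unused because the projection head below
    # works on hidden states instead of logits.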

    # --- Projection head ---
    class GemmaSelfContrast(nn.Module):
        def __init__(self, peft_model):
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            # Two-layer MLP mapping the pooled hidden state to a 256-dim embedding.
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
            )

        def forward(self, ids, mask):
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True
            )
            # Mean-pool the last hidden layer over the full sequence length
            # (padding positions included).
            h = out.hidden_states[-1].mean(dim=1)
            # Guard against non-finite values from the mixed-precision forward pass.
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            # L2-normalize so the dot products below are cosine similarities.
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm
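    # The projection head is freshly initialized here and is discarded at save
    # time (only the LoRA adapters are persisted), so it acts as a
    # training-time-only head, as in SimCLR-style setups.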
    model = GemmaSelfContrast(peft).to(device)

    # --- Optimizer, scheduler, AMP scaler ---
    # Only parameters that require grad (LoRA adapters + projection head) are optimized.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LR,
        weight_decay=WEIGHT_DECAY,
    )
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    scaler = GradScaler()

    # --- Training loop ---
    model.train()
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        for ids, mask in tqdm(loader, desc=f"Epoch {epoch}", unit="batch"):
            ids, mask = ids.to(device), mask.to(device)
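            # Two forward passes over the same batch give two embeddings per
            # query; any stochasticity in the model (e.g. dropout) makes them
            # distinct views, in the spirit of SimCSE-style self-contrast.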
            with autocast():
                e1 = model(ids, mask)
                e2 = model(ids, mask)
                emb = torch.cat([e1, e2], dim=0)
                sim = (emb @ emb.T) / TEMP
                # mask diagonal with -inf
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))
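                # Row i (< B) has its positive at column i + B (the second view
                # of the same query) and vice versa; every other column is an
                # in-batch negative, giving an NT-Xent / InfoNCE objective.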
                B = e1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)
                loss = F.cross_entropy(sim, labels)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
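            # Unscale before clipping so the clip threshold applies to the true
            # gradient magnitudes rather than the loss-scaled ones.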
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} avg loss: {avg_loss:.6f}")

    # --- Save only LoRA adapters ---
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    peft.save_pretrained(final_dir)
    print("Phase 3 complete. LoRA adapters saved to", final_dir)
if __name__ == "__main__":
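    # freeze_support() only matters when the script is frozen into a Windows
    # executable; it is a no-op otherwise.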
    freeze_support()
    main()
    sys.exit(0)