#!/usr/bin/env python3
import os
# prevent HF tokenizers threads from hanging the process
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from multiprocessing import freeze_support

def main():
    # --- Config ---
    PRET_FILE      = "pretokenized_queries.pt"
    MODEL_NAME     = "google/gemma-3-1b-pt"
    LORA_DIR       = "phase2_triplet_amp/final"
    BATCH_SIZE     = 64
    LR             = 1e-5
    WEIGHT_DECAY   = 0.01
    NUM_EPOCHS     = 1
    TEMP           = 0.05
    OUTPUT_DIR     = "phase3_self_contrast"
    GRAD_CLIP_NORM = 1.0
    SEED           = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

    # --- Load pretokenized queries safely ---
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids      = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    loader  = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # --- Load base model + LoRA adapters ---
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    # is_trainable=True keeps the LoRA adapter weights unfrozen for further
    # training (the default loads them frozen, for inference only)
    peft = PeftModel.from_pretrained(base, LORA_DIR, is_trainable=True).to(device)

    # --- Projection head ---
    class GemmaSelfContrast(nn.Module):
        def __init__(self, peft_model):
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
            )
        def forward(self, ids, mask):
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True
            )
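            # mean-pool the final hidden states over the sequence dimension;
            # note this simple mean also averages padded positions, so an
            # attention-mask-weighted mean is a possible refinement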
            h = out.hidden_states[-1].mean(dim=1)
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSelfContrast(peft).to(device)

    # --- Optimizer, scheduler, AMP scaler ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    scaler = GradScaler()

    # --- Training loop ---
    model.train()
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        for ids, mask in tqdm(loader, desc=f"Epoch {epoch}", unit="batch"):
            ids, mask = ids.to(device), mask.to(device)

            with autocast():
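                # two forward passes of the same batch: any active stochastic
                # layers (e.g. LoRA dropout) make e1 and e2 differ slightly,
                # so each query acts as its own positive (SimCSE-style)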
                e1  = model(ids, mask)
                e2  = model(ids, mask)
                emb = torch.cat([e1, e2], dim=0)
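                # embeddings are L2-normalized, so emb @ emb.T is a matrix of
                # cosine similarities, sharpened by the temperature TEMP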
                sim = (emb @ emb.T) / TEMP
                # mask diagonal with -inf
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))

                B = e1.size(0)
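                # for row i the positive is the other view of the same query:
                # index i+B for the first half, i-B for the second half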
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)
                loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
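            # unscale gradients before clipping so the clip threshold applies
            # to the true gradient magnitudes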
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} avg loss: {avg_loss:.6f}")

    # --- Save only LoRA adapters ---
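    # note: only the adapter weights are written; the projection head used for
    # the contrastive objective is discarded at this point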
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    peft.save_pretrained(final_dir)
    print("Phase 3 complete. LoRA adapters saved to", final_dir)

if __name__ == "__main__":
    freeze_support()
    main()
    sys.exit(0)