"""Phase 3: self-contrastive fine-tuning of the LoRA-adapted Gemma model.

Loads the phase 2 LoRA adapters, continues training them with a dropout-based
contrastive objective over pre-tokenized queries, and saves the updated
adapters.
"""

import os

# Must be set before transformers/tokenizers are imported to silence
# fork-related parallelism warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from torch.cuda.amp import GradScaler, autocast
from tqdm.auto import tqdm
from multiprocessing import freeze_support


def main():
    # Configuration
    PRET_FILE = "pretokenized_queries.pt"   # pre-tokenized queries: input_ids + attention_mask
    MODEL_NAME = "google/gemma-3-1b-pt"
    LORA_DIR = "phase2_triplet_amp/final"   # LoRA adapters produced by phase 2
    BATCH_SIZE = 64
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    TEMP = 0.05                             # temperature for the contrastive softmax
    OUTPUT_DIR = "phase3_self_contrast"
    GRAD_CLIP_NORM = 1.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

    # Load the pre-tokenized queries and build a shuffled DataLoader.
    data = torch.load(PRET_FILE, weights_only=True)
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    dataset = TensorDataset(input_ids, attention_mask)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

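    # `pretokenized_queries.pt` is expected to hold two aligned (num_queries, seq_len)
    # tensors. A minimal sketch of how such a file could be produced (illustrative
    # only; the preprocessing step is not part of this script, and the `queries`
    # list of raw strings is an assumption):
    #
    #     from transformers import AutoTokenizer
    #     tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    #     enc = tok(queries, padding=True, truncation=True, return_tensors="pt")
    #     torch.save({"input_ids": enc["input_ids"],
    #                 "attention_mask": enc["attention_mask"]}, PRET_FILE)
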
    # Load the base model and attach the phase 2 LoRA adapters. "eager" selects the
    # plain PyTorch attention implementation rather than SDPA/FlashAttention kernels.
    # `is_trainable=True` is needed so the adapter weights stay unfrozen for further
    # training (the default loads them in inference mode).
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    peft = PeftModel.from_pretrained(base, LORA_DIR, is_trainable=True).to(device)

    class GemmaSelfContrast(nn.Module):
        """Mean-pools the model's last hidden states, projects them to 256-d,
        and returns L2-normalized embeddings."""

        def __init__(self, peft_model):
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            # Training-time projection head; it is not saved with the adapters.
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
            )

        def forward(self, ids, mask):
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True,
            )
            # Mean-pool the last hidden states over non-padding positions.
            last = out.hidden_states[-1]
            m = mask.unsqueeze(-1).to(last.dtype)
            h = (last * m).sum(dim=1) / m.sum(dim=1).clamp_min(1.0)
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            # L2-normalize so dot products in the loss are cosine similarities.
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSelfContrast(peft).to(device)

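    # On its own, the wrapper maps a tokenized batch to unit-norm 256-d vectors,
    # e.g. (illustrative):
    #     with torch.no_grad():
    #         vecs = model(ids.to(device), mask.to(device))   # shape (batch, 256)
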
    # Optimizer, linear warmup/decay schedule, and AMP loss scaler.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
    scaler = GradScaler()

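    # Contrastive objective: each batch is encoded twice; dropout (active in
    # train mode) makes the two passes differ, giving two "views" per query.
    # Row i's positive is its counterpart at i + B (or i - B); the remaining
    # 2B - 2 rows act as in-batch negatives. The diagonal is masked with -inf so
    # an embedding never counts itself, and cross-entropy over the
    # temperature-scaled similarity matrix yields the InfoNCE loss.
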
    model.train()
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        for ids, mask in tqdm(loader, desc=f"Epoch {epoch}", unit="batch"):
            ids, mask = ids.to(device), mask.to(device)

            with autocast():
                # Two stochastic forward passes of the same batch.
                e1 = model(ids, mask)
                e2 = model(ids, mask)
                emb = torch.cat([e1, e2], dim=0)
                sim = (emb @ emb.T) / TEMP

                # Exclude each embedding's similarity to itself.
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))

                # Label row i with the index of its other view.
                B = e1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device),
                ], dim=0)
                loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold sees true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} avg loss: {avg_loss:.6f}")

    # Persist only the updated LoRA adapters; the projection head is a
    # training-time scaffold and is discarded.
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    peft.save_pretrained(final_dir)
    print("Phase 3 complete. LoRA adapters saved to", final_dir)

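
# The saved adapters can be re-attached to a fresh base model the same way this
# script loads the phase 2 ones, e.g. (illustrative):
#     base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt")
#     embedder = PeftModel.from_pretrained(base, "phase3_self_contrast/final")
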
if __name__ == "__main__":
    # freeze_support() is only needed in frozen Windows executables that use
    # multiprocessing; it is a harmless no-op elsewhere.
    freeze_support()
    main()
    sys.exit(0)