#!/usr/bin/env python3
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from tqdm.auto import tqdm
from multiprocessing import freeze_support
def main():
    # Config
    MODEL_NAME = "google/gemma-3-1b-pt"
    DATA_FILE = "text.txt"  # one sequence per line
    BATCH_SIZE = 12
    MAX_LENGTH = 128
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    VAL_RATIO = 0.1  # 10% for validation
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.0
    PROJ_HIDDEN = 512
    PROJ_OUT = 256
    TEMP = 0.05
    OUTPUT_DIR = "stage1_simcse"
    GRAD_CLIP_NORM = 1.0
    SIM_CLAMP_MIN = -10.0
    SIM_CLAMP_MAX = 10.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # tokenizer + model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager"
    )

    # LoRA on q_proj & v_proj
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "v_proj"],
    )
    model_lora = get_peft_model(base_model, lora_cfg)
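    # only the LoRA adapter matrices injected into q_proj/v_proj are trainable in the
    # base model from here on; its original weights stay frozen. The projection head
    # defined below is a plain nn.Module, so it trains as well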
    # Encoder + projection head
    class GemmaSimCSE(nn.Module):
        def __init__(self, base):
            super().__init__()
            self.base = base
            hs = base.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, PROJ_HIDDEN),
                nn.ReLU(),
                nn.Linear(PROJ_HIDDEN, PROJ_OUT),
            )

        def forward(self, input_ids, attention_mask):
            out = self.base(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )
            hidden = out.hidden_states[-1]  # (B, T, H)
            # mean-pooling over all T positions; with padding="max_length" this
            # average also includes the hidden states at padding positions
            emb = hidden.mean(dim=1)
            emb = torch.nan_to_num(emb, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(emb)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            # L2-normalize so the dot products in the loss are cosine similarities
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSimCSE(model_lora).to(device)
    torch.autograd.set_detect_anomaly(True)
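    # CUDA_LAUNCH_BLOCKING=1 (set at the top) and anomaly detection are debugging aids:
    # they make CUDA/NaN failures easier to localize but slow training down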
    # Load and split dataset
    raw = load_dataset("text", data_files={"train": DATA_FILE}, split="train")
    raw = raw.filter(lambda x: x["text"].strip() != "")
    split = raw.train_test_split(test_size=VAL_RATIO, seed=SEED)
    train_ds = split["train"]
    val_ds = split["test"]

    # Tokenization
    def tokenize_fn(batch):
        toks = tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length"
        )
        return {"input_ids": toks["input_ids"], "attention_mask": toks["attention_mask"]}

    train_ds = train_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"]
    )
    val_ds = val_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"]
    )
    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
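    # BATCH_SIZE matters beyond throughput here: the contrastive loss uses the other
    # sentences in the batch as negatives, so every anchor sees 2*B - 2 in-batch negatives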
    # Optimizer & scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    total_steps = len(train_loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
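    # model.parameters() also yields the frozen base weights, but AdamW skips any
    # parameter that never receives a gradient, so only the LoRA adapters and the
    # projection head are actually updated (and weight-decayed)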
    # Training + validation loop
    for epoch in range(1, NUM_EPOCHS + 1):
        # --- train ---
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            emb1 = model(ids, mask)
            emb2 = model(ids, mask)
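            # unsupervised SimCSE: the batch is encoded twice, and the two views differ
            # only through whatever dropout is active in the model (LoRA dropout is 0.0,
            # so any augmentation comes from dropout inside the base model, if present);
            # each sentence's other view is its positive, everything else is a negative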
            emb = torch.cat([emb1, emb2], dim=0)  # (2B, PROJ_OUT)
            sim = (emb @ emb.T) / TEMP
            sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
            # fill diagonal with large negative so self-sim won't be selected
            sim.fill_diagonal_(-1e9)
            B = emb1.size(0)
            # labels: [B..2B-1, 0..B-1], i.e. row i of the first view is paired with
            # column i + B of the second view, and vice versa
            labels = torch.cat([
                torch.arange(B, device=device) + B,
                torch.arange(B, device=device)
            ], dim=0)
            loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch} training complete. avg train loss: {avg_train_loss:.6f}")
        # --- validate ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validate Epoch {epoch}", unit="batch"):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)
                # in eval mode dropout is off, so both views are identical and the
                # validation loss mainly reflects how well different sentences in the
                # batch are separated
                emb1 = model(ids, mask)
                emb2 = model(ids, mask)
                emb = torch.cat([emb1, emb2], dim=0)
                sim = (emb @ emb.T) / TEMP
                sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
                sim.fill_diagonal_(-1e9)
                B = emb1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)
                loss = F.cross_entropy(sim, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} validation complete. avg val loss: {avg_val_loss:.6f}")

        # save checkpoint (LoRA adapter + tokenizer) for this epoch
        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        model_lora.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
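        # note: save_pretrained on the PEFT model stores only the adapter weights and
        # adapter config; the projection head (model.proj) lives outside model_lora
        # and is not saved here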
    # save final model
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model_lora.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("Training and validation complete. Final model saved to", final_dir)

if __name__ == "__main__":
    # freeze_support() matters only for frozen Windows executables that use
    # multiprocessing (datasets.map runs with num_proc=4); it is a no-op otherwise
    freeze_support()
    main()
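
# A minimal sketch (an assumption, not part of the script above) of how the saved
# adapter could be reloaded for a later stage; the paths and the proj_head.pt file
# name are hypothetical.
#
#   from peft import PeftModel
#
#   base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt",
#                                               attn_implementation="eager")
#   encoder = PeftModel.from_pretrained(base, "stage1_simcse/final")
#   tok = AutoTokenizer.from_pretrained("stage1_simcse/final")
#
#   # the projection head is not in the adapter checkpoint; to reuse it, persist it
#   # during training, e.g. torch.save(model.proj.state_dict(),
#   # "stage1_simcse/final/proj_head.pt"), and load it into a fresh GemmaSimCSE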