|
|
| import os
|
| os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| import sys
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torch.utils.data import DataLoader
|
| from transformers import (
|
| AutoTokenizer,
|
| AutoModelForCausalLM,
|
| get_linear_schedule_with_warmup
|
| )
|
| from peft import LoraConfig, get_peft_model, TaskType
|
| from datasets import load_dataset
|
| from tqdm.auto import tqdm
|
| from multiprocessing import freeze_support
|
|
|
class GemmaSimCSE(nn.Module):
    """Sentence encoder: a causal-LM backbone followed by an MLP projection head.

    Args:
        base: Backbone model. It must expose ``config.hidden_size`` and, when
            called with ``output_hidden_states=True`` and ``return_dict=True``,
            return an object with a ``hidden_states`` sequence.
        proj_hidden: Width of the hidden layer of the projection head.
    """

    def __init__(self, base, proj_hidden=512):
        super().__init__()
        self.base = base
        hidden_size = base.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(hidden_size, proj_hidden),
            nn.ReLU(),
            nn.Linear(proj_hidden, hidden_size),
        )

    def forward(self, input_ids, attention_mask):
        """Return one L2-normalised embedding per input sequence.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).
            attention_mask: Same layout as ``input_ids``; non-zero marks real
                tokens, zero marks padding.

        Returns:
            Tensor with one unit-norm row per sequence.
        """
        out = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = out.hidden_states[-1]

        # Masked mean pooling: inputs are padded to a fixed length, so a plain
        # mean over the sequence axis would average in the pad positions.
        # masked_fill (rather than multiplying by the mask) also discards any
        # NaN/inf that sits on a pad position.
        keep = attention_mask.unsqueeze(-1).bool()
        summed = hidden.masked_fill(~keep, 0.0).sum(dim=1)
        # clamp_min(1) avoids 0/0 for a row that contains no real token.
        counts = keep.sum(dim=1).clamp_min(1).to(summed.dtype)
        emb = summed / counts

        emb = torch.nan_to_num(emb, nan=0.0, posinf=1e-6, neginf=-1e-6)
        z = self.proj(emb)
        z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
        # Lower-bound the norm so an all-zero projection does not divide by 0.
        norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
        return z / norm


def _simcse_loss(model, input_ids, attention_mask, temp, clamp_min, clamp_max):
    """Encode the batch twice and return the in-batch contrastive loss.

    Row ``i`` of the first pass and row ``i`` of the second pass are each
    other's positive; every other row of the 2B x 2B similarity matrix is a
    negative. Self-similarity on the diagonal is masked out.
    """
    view_a = model(input_ids, attention_mask)
    view_b = model(input_ids, attention_mask)
    emb = torch.cat([view_a, view_b], dim=0)

    sim = (emb @ emb.T) / temp
    # clamp passes no gradient for entries outside [clamp_min, clamp_max].
    sim = sim.clamp(clamp_min, clamp_max)
    sim.fill_diagonal_(-1e9)

    batch = view_a.size(0)
    idx = torch.arange(batch, device=emb.device)
    # Rows 0..B-1 look for their twin at i + B; rows B..2B-1 at i - B.
    labels = torch.cat([idx + batch, idx], dim=0)
    return F.cross_entropy(sim, labels)


def main():
    """Fine-tune Gemma with LoRA using a SimCSE-style contrastive objective.

    Reads one sentence per line from ``DATA_FILE``, holds out ``VAL_RATIO`` of
    the lines for validation, trains for ``NUM_EPOCHS`` epochs and writes the
    LoRA adapter and tokenizer to ``OUTPUT_DIR/epoch<N>`` after every epoch
    and to ``OUTPUT_DIR/final`` at the end.
    """
    MODEL_NAME = "google/gemma-3-1b-pt"
    DATA_FILE = "text.txt"
    BATCH_SIZE = 12
    MAX_LENGTH = 128
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    VAL_RATIO = 0.1
    LORA_R = 8
    LORA_ALPHA = 16
    # NOTE(review): the two forward passes per batch can only differ through
    # dropout noise. With 0.0 the LoRA layers add none, so unless the backbone
    # applies dropout itself both views are identical -- confirm intended.
    LORA_DROPOUT = 0.0
    PROJ_HIDDEN = 512
    TEMP = 0.05
    OUTPUT_DIR = "stage1_simcse"
    GRAD_CLIP_NORM = 1.0
    # NOTE(review): embeddings are unit-norm, so sim / TEMP lies in [-20, 20].
    # Clamping to [-10, 10] flattens every pair with |cosine| > 0.5 and those
    # entries then receive no gradient -- confirm this is intended.
    SIM_CLAMP_MIN = -10.0
    SIM_CLAMP_MAX = 10.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Seed torch as well as the dataset split so LoRA / projection-head
    # initialisation and the shuffle order are reproducible.
    torch.manual_seed(SEED)

    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    # --- model -----------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager",
    )

    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "v_proj"],
    )
    model_lora = get_peft_model(base_model, lora_cfg)

    model = GemmaSimCSE(model_lora, proj_hidden=PROJ_HIDDEN).to(device)
    # Debugging aid: anomaly detection slows the backward pass noticeably.
    torch.autograd.set_detect_anomaly(True)

    # --- data ------------------------------------------------------------
    raw = load_dataset("text", data_files={"train": DATA_FILE}, split="train")
    raw = raw.filter(lambda x: x["text"].strip() != "")
    split = raw.train_test_split(test_size=VAL_RATIO, seed=SEED)
    train_ds = split["train"]
    val_ds = split["test"]

    def tokenize_fn(batch):
        """Tokenise a batch of lines, padded/truncated to MAX_LENGTH."""
        toks = tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
        )
        return {
            "input_ids": toks["input_ids"],
            "attention_mask": toks["attention_mask"],
        }

    map_kwargs = dict(
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"],
    )
    train_ds = train_ds.map(tokenize_fn, **map_kwargs)
    val_ds = val_ds.map(tokenize_fn, **map_kwargs)

    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # --- optimisation ----------------------------------------------------
    # Only the LoRA matrices and the projection head require gradients; the
    # frozen backbone weights are kept out of the optimizer and the clipping.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )

    for epoch in range(1, NUM_EPOCHS + 1):
        # ---- train ----
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            loss = _simcse_loss(
                model, ids, mask, TEMP, SIM_CLAMP_MIN, SIM_CLAMP_MAX
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()

        # max(..., 1) keeps the average defined for an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch} training complete. avg train loss: {avg_train_loss:.6f}")

        # ---- validate ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validate Epoch {epoch}", unit="batch"):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)

                loss = _simcse_loss(
                    model, ids, mask, TEMP, SIM_CLAMP_MIN, SIM_CLAMP_MAX
                )
                val_loss += loss.item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        print(f"Epoch {epoch} validation complete. avg val loss: {avg_val_loss:.6f}")

        # ---- checkpoint (LoRA adapter + tokenizer) ----
        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        model_lora.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)

    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model_lora.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("Training and validation complete. Final model saved to", final_dir)
|
|
|
if __name__ == "__main__":
    # Script entry point. freeze_support() is multiprocessing's hook for
    # frozen executables and is called before main(), whose dataset.map()
    # calls request worker processes (num_proc).
    freeze_support()
    main()
|
|
|