import os

# Debug aid: forces synchronous CUDA kernel launches so errors surface at the failing call (slows execution).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from tqdm.auto import tqdm
from multiprocessing import freeze_support
def main():
    MODEL_NAME = "google/gemma-3-1b-pt"
    DATA_FILE = "text.txt"
    BATCH_SIZE = 12
    MAX_LENGTH = 128
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    VAL_RATIO = 0.1
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.1  # unsupervised SimCSE relies on dropout noise to produce two distinct views of each input
    PROJ_HIDDEN = 512
    PROJ_OUT = 256
    TEMP = 0.05
    OUTPUT_DIR = "stage1_simcse"
    GRAD_CLIP_NORM = 1.0
    SIM_CLAMP_MIN = -10.0
    SIM_CLAMP_MAX = 10.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
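    # Tokenizer and base causal LM; "eager" selects the plain PyTorch attention implementation.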
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager",
    )
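    # Wrap the base model with LoRA adapters on the attention q/v projections;
    # only the adapter weights (and the projection head defined below) receive gradients.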
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "v_proj"],
    )
    model_lora = get_peft_model(base_model, lora_cfg)
    class GemmaSimCSE(nn.Module):
        """Sentence encoder: masked mean pooling over the last hidden state,
        followed by an MLP projection head and L2 normalization."""

        def __init__(self, base):
            super().__init__()
            self.base = base
            hs = base.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, PROJ_HIDDEN),
                nn.ReLU(),
                nn.Linear(PROJ_HIDDEN, PROJ_OUT),
            )

        def forward(self, input_ids, attention_mask):
            out = self.base(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = out.hidden_states[-1]
            # Mask-aware mean pooling so padding positions do not dilute the embedding.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6)
            emb = torch.nan_to_num(emb, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(emb)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm
    model = GemmaSimCSE(model_lora).to(device)
    # Debug aid: anomaly detection pinpoints NaN/Inf in the backward pass (adds overhead).
    torch.autograd.set_detect_anomaly(True)
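    # Load the plain-text corpus (one example per line), drop empty lines,
    # and hold out a validation split.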
    raw = load_dataset("text", data_files={"train": DATA_FILE}, split="train")
    raw = raw.filter(lambda x: x["text"].strip() != "")
    split = raw.train_test_split(test_size=VAL_RATIO, seed=SEED)
    train_ds = split["train"]
    val_ds = split["test"]
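    # Tokenize every example to a fixed length of MAX_LENGTH tokens with truncation and padding.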
    def tokenize_fn(batch):
        toks = tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
        )
        return {"input_ids": toks["input_ids"], "attention_mask": toks["attention_mask"]}

    train_ds = train_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"],
    )
    val_ds = val_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"],
    )

    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    total_steps = len(train_loader) * NUM_EPOCHS
    # Linear decay with 10% warmup.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
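    # Training objective: InfoNCE over the 2B x 2B cosine-similarity matrix, with the
    # rest of the batch serving as in-batch negatives.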
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            # Unsupervised SimCSE: encode the same batch twice; dropout noise (here from
            # the LoRA adapters) makes the passes differ, so each sentence's second
            # encoding is its positive pair.
            emb1 = model(ids, mask)
            emb2 = model(ids, mask)
            emb = torch.cat([emb1, emb2], dim=0)
            sim = (emb @ emb.T) / TEMP
            sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)

            # Exclude self-similarity from the softmax denominator.
            sim.fill_diagonal_(-1e9)

            B = emb1.size(0)

            # Row i's positive sits at column i + B, and row i + B's at column i.
            labels = torch.cat([
                torch.arange(B, device=device) + B,
                torch.arange(B, device=device)
            ], dim=0)

            loss = F.cross_entropy(sim, labels)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch} training complete. avg train loss: {avg_train_loss:.6f}")
        # Validation: dropout is disabled in eval mode, so the two passes produce identical
        # embeddings and the loss mainly reflects separation from in-batch negatives.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validate Epoch {epoch}", unit="batch"):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)

                emb1 = model(ids, mask)
                emb2 = model(ids, mask)
                emb = torch.cat([emb1, emb2], dim=0)
                sim = (emb @ emb.T) / TEMP
                sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
                sim.fill_diagonal_(-1e9)

                B = emb1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)

                loss = F.cross_entropy(sim, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} validation complete. avg val loss: {avg_val_loss:.6f}")
        # Per-epoch checkpoint: LoRA adapter weights and tokenizer.
        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        model_lora.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
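        # The projection head is not part of the PEFT checkpoint; saving it separately
        # (filename below is illustrative) lets a later stage reuse it.
        torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "proj_head.pt"))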
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model_lora.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
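    torch.save(model.proj.state_dict(), os.path.join(final_dir, "proj_head.pt"))  # projection head saved separately; filename is illustrative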
    print("Training and validation complete. Final model saved to", final_dir)
if __name__ == "__main__":
    freeze_support()  # no-op unless the script is frozen into a Windows executable; harmless otherwise
    main()