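"""Train a binary token-classification model that tags link-anchor tokens.

Reads pre-tokenized JSONL windows (produced by prep.py), fine-tunes a
HuggingFace token-classification model with a class-weighted cross-entropy
loss, and writes the model, tokenizer, eval metrics, and label map to the
output directory.

Example invocation (flags as defined in parse_args below; the filename
train.py is an assumption about how this script is saved):

    python train.py --train_path train_windows.jsonl --val_path val_windows.jsonl --fp16
"""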
import os
import json
import argparse
import inspect
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)


class JsonlTokenDataset(Dataset):
    """Loads JSONL produced by prep.py. Masks special tokens in labels to -100."""

    def __init__(self, path: str, tokenizer: AutoTokenizer):
        self.path = path
        self.tokenizer = tokenizer
        self.samples: List[Dict[str, Any]] = []
        with open(self.path, "r", encoding="utf-8") as f:
            for line in f:
                rec = json.loads(line)
                self.samples.append(rec)

        # Set labels on special tokens (CLS/SEP/PAD, etc.) to -100 so the loss
        # and metrics ignore those positions.
        for rec in self.samples:
            input_ids = rec["input_ids"]
            labels = rec["labels"]
            try:
                special_mask = tokenizer.get_special_tokens_mask(
                    input_ids, already_has_special_tokens=True
                )
            except Exception:
                # Fallback: flag any id that appears in the tokenizer's
                # special-id list.
                spec = set(tokenizer.all_special_ids or [])
                special_mask = [1 if t in spec else 0 for t in input_ids]
            rec["labels"] = [
                -100 if sm == 1 else int(l) for l, sm in zip(labels, special_mask)
            ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        r = self.samples[idx]
        return {
            "input_ids": torch.tensor(r["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(r["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(r["labels"], dtype=torch.long),
        }
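# A sketch of the record shape this dataset expects (field names taken from
# the code above; the concrete values are illustrative only, with one label
# per position of input_ids):
#   {"input_ids": [0, 4783, 1809, 2], "attention_mask": [1, 1, 1, 1],
#    "labels": [0, 1, 1, 0]}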
class SimpleTokenCollator:
    """Pads input_ids with pad_token_id, attention_mask with 0, labels with -100."""

    def __init__(self, tokenizer: AutoTokenizer, pad_to_multiple_of: Optional[int] = None):
        # Fall back to 0 for tokenizers that define no pad token.
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.pad_to_multiple = pad_to_multiple_of

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        ids = [f["input_ids"].tolist() for f in features]
        att = [f["attention_mask"].tolist() for f in features]
        lab = [f["labels"].tolist() for f in features]
        max_len = max(len(x) for x in ids)
        # Round the padded length up to the next multiple, if requested.
        if self.pad_to_multiple and max_len % self.pad_to_multiple != 0:
            max_len = ((max_len // self.pad_to_multiple) + 1) * self.pad_to_multiple

        def pad(seq, val):
            return seq + [val] * (max_len - len(seq))

        ids = [pad(x, self.pad_id) for x in ids]
        att = [pad(x, 0) for x in att]
        lab = [pad(x, -100) for x in lab]
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(att, dtype=torch.long),
            "labels": torch.tensor(lab, dtype=torch.long),
        }
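# Worked example of the padding above (values illustrative): with
# pad_to_multiple_of=8, a batch whose longest sequence has 13 tokens is padded
# to length 16, so that 13-token feature gains three pad ids, three 0s in
# attention_mask, and three -100 labels that the loss will skip.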
def compute_class_weights(dataset: JsonlTokenDataset) -> torch.Tensor:
    """Returns [weight_for_O, weight_for_LINK]; upweights the rarer positive class."""
    pos = 0
    neg = 0
    for rec in dataset.samples:
        for l in rec["labels"]:
            if l == -100:
                continue
            if l == 1:
                pos += 1
            else:
                neg += 1
    return torch.tensor([1.0, (neg / max(1, pos)) if pos > 0 else 1.0], dtype=torch.float)
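# Example of the arithmetic: with 9,000 O tokens and 1,000 LINK tokens the
# weights come out as [1.0, 9.0], so each positive token contributes nine
# times as much to the loss as a negative one.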
def compute_metrics_fn(eval_pred):
    """Token-level binary metrics, skipping positions labeled -100."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    y_true, y_pred = [], []
    for p, l in zip(preds, labels):
        for pi, li in zip(p, l):
            if li == -100:
                continue
            y_true.append(int(li))
            y_pred.append(int(pi))
    if not y_true:
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0,
                "pos_rate_true": 0.0, "pos_rate_pred": 0.0}
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    acc = (tp + tn) / max(1, tp + tn + fp + fn)
    prec = tp / max(1, tp + fp)
    rec = tp / max(1, tp + fn)
    f1 = (2 * prec * rec / max(1e-12, prec + rec)) if (prec + rec) > 0 else 0.0
    return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1,
            "pos_rate_true": float(np.mean(y_true)), "pos_rate_pred": float(np.mean(y_pred))}
class WeightedCELossTrainer(Trainer):
    """Trainer whose loss applies per-class weights to counter label imbalance."""

    def __init__(self, class_weights: Optional[torch.Tensor] = None, **kwargs):
        super().__init__(**kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs["labels"]
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.logits
        loss_fct = torch.nn.CrossEntropyLoss(
            weight=(self.class_weights.to(logits.device) if self.class_weights is not None else None)
        )
        # Flatten to (tokens, 2) and keep only positions with real labels.
        mask = labels.ne(-100)
        loss = loss_fct(logits.view(-1, 2)[mask.view(-1)], labels.view(-1)[mask.view(-1)])
        return (loss, outputs) if return_outputs else loss
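# Note on indexing: CrossEntropyLoss applies weight[c] to targets of class c,
# so the tensor from compute_class_weights maps index 0 to "O" and index 1 to
# "LINK", matching the id2label mapping defined in main().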
def build_training_arguments(args) -> TrainingArguments:
    """Builds TrainingArguments, probing the signature for version compatibility."""
    sig = set(inspect.signature(TrainingArguments.__init__).parameters.keys())

    # `evaluation_strategy` was renamed `eval_strategy` in newer transformers
    # releases; accept whichever spelling this version supports.
    eval_key = "evaluation_strategy" if "evaluation_strategy" in sig else "eval_strategy"
    supports_eval_strategy = eval_key in sig
    supports_save_strategy = "save_strategy" in sig
    supports_log_strategy = "logging_strategy" in sig
    supports_report_to = "report_to" in sig
    supports_load_best = "load_best_model_at_end" in sig
    supports_metric_for_best = "metric_for_best_model" in sig
    supports_workers = "dataloader_num_workers" in sig

    kw = {
        "output_dir": args.output_dir,
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.train_batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "learning_rate": args.lr,
        "weight_decay": args.weight_decay,
        "logging_steps": args.logging_steps,
        "eval_steps": args.eval_steps,
        "save_steps": args.save_steps,
        "save_total_limit": 2,
        "seed": args.seed,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "fp16": args.fp16,
        "bf16": args.bf16,
        "gradient_checkpointing": args.gradient_checkpointing,
        "log_level": "info",
    }
    if supports_workers:
        kw["dataloader_num_workers"] = args.num_workers
    if supports_report_to:
        # The string "none" is the documented way to disable reporting;
        # passing None falls back to all installed integrations.
        kw["report_to"] = ("none" if args.report_to == "none" else ["wandb"])

    if supports_eval_strategy and supports_save_strategy:
        kw[eval_key] = "steps"
        kw["save_strategy"] = "steps"
        if supports_log_strategy:
            kw["logging_strategy"] = "steps"
        if supports_load_best:
            kw["load_best_model_at_end"] = True
        if supports_metric_for_best:
            kw["metric_for_best_model"] = "f1"
        if "greater_is_better" in sig:
            kw["greater_is_better"] = True
    else:
        for k in ("evaluation_strategy", "eval_strategy", "save_strategy", "logging_strategy",
                  "load_best_model_at_end", "metric_for_best_model", "greater_is_better"):
            kw.pop(k, None)
        # Very old transformers used a boolean flag instead of a strategy.
        if "evaluate_during_training" in sig and args.eval_steps > 0:
            kw["evaluate_during_training"] = True

    # Drop anything this transformers version does not accept.
    kw = {k: v for k, v in kw.items() if k in sig}
    return TrainingArguments(**kw)
def parse_args():
    ap = argparse.ArgumentParser(description="Train binary token classification model for link anchors.")
    ap.add_argument("--model_name", default="microsoft/mdeberta-v3-base", help="HF model name or local path.")
    ap.add_argument("--train_path", default="train_windows.jsonl", help="Training JSONL.")
    ap.add_argument("--val_path", default="val_windows.jsonl", help="Validation JSONL.")
    ap.add_argument("--output_dir", default="model_link_token_cls", help="Output directory.")

    ap.add_argument("--epochs", type=int, default=3)
    ap.add_argument("--lr", type=float, default=2e-5)
    ap.add_argument("--weight_decay", type=float, default=0.01)
    ap.add_argument("--train_batch_size", type=int, default=16)
    ap.add_argument("--eval_batch_size", type=int, default=32)
    ap.add_argument("--logging_steps", type=int, default=50)
    ap.add_argument("--eval_steps", type=int, default=500)
    ap.add_argument("--save_steps", type=int, default=500)

    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--gradient_accumulation_steps", type=int, default=1)
    ap.add_argument("--fp16", action="store_true")
    ap.add_argument("--bf16", action="store_true")
    ap.add_argument("--gradient_checkpointing", action="store_true")
    ap.add_argument("--report_to", default="wandb", choices=["wandb", "none"])
    ap.add_argument("--pad_to_multiple_of", type=int, default=8)
    ap.add_argument("--num_workers", type=int, default=2)
    return ap.parse_args()
def main():
    args = parse_args()
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)

    train_ds = JsonlTokenDataset(args.train_path, tokenizer)
    val_ds = JsonlTokenDataset(args.val_path, tokenizer)

    id2label = {0: "O", 1: "LINK"}
    label2id = {"O": 0, "LINK": 1}
    config = AutoConfig.from_pretrained(args.model_name, num_labels=2, id2label=id2label, label2id=label2id)
    model = AutoModelForTokenClassification.from_pretrained(args.model_name, config=config)

    class_weights = compute_class_weights(train_ds)

    # Pad to a multiple of 8 only on GPU, where the aligned shapes can help.
    collator = SimpleTokenCollator(
        tokenizer=tokenizer,
        pad_to_multiple_of=(args.pad_to_multiple_of if torch.cuda.is_available() else None),
    )

    training_args = build_training_arguments(args)

    trainer = WeightedCELossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
        class_weights=class_weights,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    trainer.save_state()

    with open(os.path.join(args.output_dir, "label_map.json"), "w", encoding="utf-8") as f:
        json.dump({"0": "O", "1": "LINK"}, f)

    print("=== Training complete ===")
    print(f"Output dir: {args.output_dir}")
    print(f"Class weights [neg, pos]: [{class_weights[0].item():.4f}, {class_weights[1].item():.4f}]")
    print(f"Eval metrics: {metrics}")


if __name__ == "__main__":
    main()