# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "transformers>=4.30.0",
#     "datasets>=2.14.0",
#     "click>=8.0.0",
#     "tqdm>=4.60.0",
#     "wandb>=0.15.0",
#     "python-dotenv>=1.0.0",
# ]
# ///
"""
Training script for Bamboo-1 Vietnamese Dependency Parser.
Supports multiple methods:
- baseline: BiLSTM + Biaffine (Dozat & Manning, 2017)
- trankit: XLM-RoBERTa + Biaffine (Nguyen et al., 2021)
Usage:
uv run scripts/train.py # Default baseline
uv run scripts/train.py --method trankit # Reproduce Trankit
uv run scripts/train.py --method trankit --dataset ud-vtb # Trankit on VTB
"""
import sys
from pathlib import Path
from collections import Counter
from dataclasses import dataclass
from typing import List
# Load environment variables
from dotenv import load_dotenv
load_dotenv()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm
import click
sys.path.insert(0, str(Path(__file__).parent.parent))
from bamboo1.corpus import UDD1Corpus
from bamboo1.ud_corpus import UDVietnameseVTB
from bamboo1.vndt_corpus import VnDTCorpus
from scripts.cost_estimate import CostTracker, detect_hardware
# ============================================================================
# Data Processing
# ============================================================================
@dataclass
class Sentence:
"""A dependency-parsed sentence."""
words: List[str]
heads: List[int]
rels: List[str]
def read_conllu(path: str) -> List[Sentence]:
"""Read CoNLL-U file and return list of sentences."""
sentences = []
words, heads, rels = [], [], []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
if words:
sentences.append(Sentence(words, heads, rels))
words, heads, rels = [], [], []
elif line.startswith('#'):
continue
else:
parts = line.split('\t')
                if '-' in parts[0] or '.' in parts[0]:  # Skip multi-word tokens and empty nodes
continue
words.append(parts[1]) # FORM
heads.append(int(parts[6])) # HEAD
rels.append(parts[7]) # DEPREL
if words:
sentences.append(Sentence(words, heads, rels))
return sentences
class Vocabulary:
"""Vocabulary for words, characters, and relations."""
PAD = '<pad>'
UNK = '<unk>'
def __init__(self, min_freq: int = 2):
self.min_freq = min_freq
self.word2idx = {self.PAD: 0, self.UNK: 1}
self.char2idx = {self.PAD: 0, self.UNK: 1}
self.rel2idx = {}
self.idx2rel = {}
def build(self, sentences: List[Sentence]):
"""Build vocabulary from sentences."""
word_counts = Counter()
char_counts = Counter()
rel_counts = Counter()
for sent in sentences:
for word in sent.words:
word_counts[word.lower()] += 1
for char in word:
char_counts[char] += 1
for rel in sent.rels:
rel_counts[rel] += 1
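        # Only word types are frequency-thresholded (min_freq); every character and relation label is kept.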
# Words
for word, count in word_counts.items():
if count >= self.min_freq and word not in self.word2idx:
self.word2idx[word] = len(self.word2idx)
# Characters
for char, count in char_counts.items():
if char not in self.char2idx:
self.char2idx[char] = len(self.char2idx)
# Relations
for rel in rel_counts:
if rel not in self.rel2idx:
idx = len(self.rel2idx)
self.rel2idx[rel] = idx
self.idx2rel[idx] = rel
def encode_word(self, word: str) -> int:
return self.word2idx.get(word.lower(), self.word2idx[self.UNK])
def encode_char(self, char: str) -> int:
return self.char2idx.get(char, self.char2idx[self.UNK])
def encode_rel(self, rel: str) -> int:
return self.rel2idx.get(rel, 0)
@property
def n_words(self) -> int:
return len(self.word2idx)
@property
def n_chars(self) -> int:
return len(self.char2idx)
@property
def n_rels(self) -> int:
return len(self.rel2idx)
class DependencyDataset(Dataset):
"""Dataset for dependency parsing."""
def __init__(self, sentences: List[Sentence], vocab: Vocabulary):
self.sentences = sentences
self.vocab = vocab
def __len__(self):
return len(self.sentences)
def __getitem__(self, idx):
sent = self.sentences[idx]
# Encode words
word_ids = [self.vocab.encode_word(w) for w in sent.words]
# Encode characters
char_ids = [[self.vocab.encode_char(c) for c in w] for w in sent.words]
# Heads and relations
heads = sent.heads
rels = [self.vocab.encode_rel(r) for r in sent.rels]
return word_ids, char_ids, heads, rels
def collate_fn(batch):
"""Collate function for DataLoader."""
word_ids, char_ids, heads, rels = zip(*batch)
# Get lengths
lengths = [len(w) for w in word_ids]
max_len = max(lengths)
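    # Index 0 is the PAD id for words and characters; heads and rels are zero-padded and excluded later via the mask.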
# Pad words
word_ids_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
for i, wids in enumerate(word_ids):
word_ids_padded[i, :len(wids)] = torch.tensor(wids)
# Pad characters
max_word_len = max(max(len(c) for c in chars) for chars in char_ids)
char_ids_padded = torch.zeros(len(batch), max_len, max_word_len, dtype=torch.long)
for i, chars in enumerate(char_ids):
for j, c in enumerate(chars):
char_ids_padded[i, j, :len(c)] = torch.tensor(c)
# Pad heads
heads_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
for i, h in enumerate(heads):
heads_padded[i, :len(h)] = torch.tensor(h)
# Pad rels
rels_padded = torch.zeros(len(batch), max_len, dtype=torch.long)
for i, r in enumerate(rels):
rels_padded[i, :len(r)] = torch.tensor(r)
# Mask
mask = torch.zeros(len(batch), max_len, dtype=torch.bool)
for i, l in enumerate(lengths):
mask[i, :l] = True
lengths = torch.tensor(lengths)
return word_ids_padded, char_ids_padded, heads_padded, rels_padded, mask, lengths
# ============================================================================
# Model
# ============================================================================
class CharLSTM(nn.Module):
"""Character-level LSTM embeddings."""
def __init__(self, n_chars: int, char_dim: int = 50, hidden_dim: int = 100):
super().__init__()
self.embed = nn.Embedding(n_chars, char_dim, padding_idx=0)
self.lstm = nn.LSTM(char_dim, hidden_dim // 2, batch_first=True, bidirectional=True)
self.hidden_dim = hidden_dim
def forward(self, chars):
"""
Args:
chars: (batch, seq_len, max_word_len)
Returns:
(batch, seq_len, hidden_dim)
"""
batch, seq_len, max_word_len = chars.shape
# Flatten
chars_flat = chars.view(-1, max_word_len) # (batch * seq_len, max_word_len)
# Get word lengths
word_lens = (chars_flat != 0).sum(dim=1)
word_lens = word_lens.clamp(min=1)
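        # Padding word slots contain no characters; clamp their length to 1 so pack_padded_sequence accepts them
        # (those positions lie beyond each sentence's length and are ignored by the word-level mask).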
# Embed
char_embeds = self.embed(chars_flat) # (batch * seq_len, max_word_len, char_dim)
# Pack and run LSTM
packed = pack_padded_sequence(char_embeds, word_lens.cpu(), batch_first=True, enforce_sorted=False)
_, (hidden, _) = self.lstm(packed)
# Concatenate forward and backward hidden states
hidden = torch.cat([hidden[0], hidden[1]], dim=-1) # (batch * seq_len, hidden_dim)
return hidden.view(batch, seq_len, self.hidden_dim)
class MLP(nn.Module):
"""Multi-layer perceptron."""
def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.33):
super().__init__()
self.linear = nn.Linear(input_dim, hidden_dim)
self.activation = nn.LeakyReLU(0.1)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.activation(self.linear(x)))
class Biaffine(nn.Module):
"""Biaffine attention layer."""
def __init__(self, input_dim: int, output_dim: int = 1, bias_x: bool = True, bias_y: bool = True):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.bias_x = bias_x
self.bias_y = bias_y
self.weight = nn.Parameter(torch.zeros(output_dim, input_dim + bias_x, input_dim + bias_y))
nn.init.xavier_uniform_(self.weight)
def forward(self, x, y):
"""
Args:
x: (batch, seq_len, input_dim) - dependent
y: (batch, seq_len, input_dim) - head
Returns:
(batch, seq_len, seq_len, output_dim) or (batch, seq_len, seq_len) if output_dim=1
"""
if self.bias_x:
x = torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)
if self.bias_y:
y = torch.cat([y, torch.ones_like(y[..., :1])], dim=-1)
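        # Bilinear scoring: scores[b, x, y, o] = x[b, x]^T W[o] y[b, y], computed as two einsum contractions.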
# (batch, seq_len, output_dim, input_dim+1)
x = torch.einsum('bxi,oij->bxoj', x, self.weight)
# (batch, seq_len, seq_len, output_dim)
scores = torch.einsum('bxoj,byj->bxyo', x, y)
if self.output_dim == 1:
scores = scores.squeeze(-1)
return scores
class BiaffineDependencyParser(nn.Module):
"""Biaffine Dependency Parser (Dozat & Manning, 2017)."""
def __init__(
self,
n_words: int,
n_chars: int,
n_rels: int,
word_dim: int = 100,
char_dim: int = 50,
char_hidden: int = 100,
lstm_hidden: int = 400,
lstm_layers: int = 3,
arc_hidden: int = 500,
rel_hidden: int = 100,
dropout: float = 0.33,
):
super().__init__()
self.word_embed = nn.Embedding(n_words, word_dim, padding_idx=0)
self.char_lstm = CharLSTM(n_chars, char_dim, char_hidden)
input_dim = word_dim + char_hidden
self.lstm = nn.LSTM(
input_dim, lstm_hidden // 2,
num_layers=lstm_layers,
batch_first=True,
bidirectional=True,
dropout=dropout if lstm_layers > 1 else 0
)
self.mlp_arc_dep = MLP(lstm_hidden, arc_hidden, dropout)
self.mlp_arc_head = MLP(lstm_hidden, arc_hidden, dropout)
self.mlp_rel_dep = MLP(lstm_hidden, rel_hidden, dropout)
self.mlp_rel_head = MLP(lstm_hidden, rel_hidden, dropout)
self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)
self.dropout = nn.Dropout(dropout)
self.n_rels = n_rels
def forward(self, words, chars, mask):
"""
Args:
words: (batch, seq_len)
chars: (batch, seq_len, max_word_len)
mask: (batch, seq_len)
Returns:
arc_scores: (batch, seq_len, seq_len)
rel_scores: (batch, seq_len, seq_len, n_rels)
"""
# Embeddings
word_embeds = self.word_embed(words)
char_embeds = self.char_lstm(chars)
embeds = torch.cat([word_embeds, char_embeds], dim=-1)
embeds = self.dropout(embeds)
# BiLSTM
lengths = mask.sum(dim=1).cpu()
packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
lstm_out, _ = self.lstm(packed)
lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True, total_length=mask.size(1))
lstm_out = self.dropout(lstm_out)
# MLP
arc_dep = self.mlp_arc_dep(lstm_out)
arc_head = self.mlp_arc_head(lstm_out)
rel_dep = self.mlp_rel_dep(lstm_out)
rel_head = self.mlp_rel_head(lstm_out)
# Biaffine
arc_scores = self.arc_attn(arc_dep, arc_head) # (batch, seq_len, seq_len)
rel_scores = self.rel_attn(rel_dep, rel_head) # (batch, seq_len, seq_len, n_rels)
return arc_scores, rel_scores
def loss(self, arc_scores, rel_scores, heads, rels, mask):
"""Compute loss."""
batch_size, seq_len = mask.shape
# Arc loss
arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
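        # arc_scores[mask] keeps one row of head scores per real token, so padded dependents never contribute to the loss.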
arc_loss = F.cross_entropy(
arc_scores[mask].view(-1, seq_len),
heads[mask],
reduction='mean'
)
# Rel loss - select scores for gold heads
rel_scores_gold = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), heads]
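        # rel_scores_gold: (batch, seq_len, n_rels), the relation scores taken at each token's gold head.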
rel_loss = F.cross_entropy(
rel_scores_gold[mask],
rels[mask],
reduction='mean'
)
return arc_loss + rel_loss
def decode(self, arc_scores, rel_scores, mask):
"""Decode predictions."""
# Greedy decoding
arc_preds = arc_scores.argmax(dim=-1)
batch_size, seq_len = mask.shape
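        # Gather the relation scores at each token's predicted head, then take the highest-scoring label.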
rel_scores_pred = rel_scores[torch.arange(batch_size).unsqueeze(1), torch.arange(seq_len), arc_preds]
rel_preds = rel_scores_pred.argmax(dim=-1)
return arc_preds, rel_preds
# ============================================================================
# Trankit-style Transformer Parser (XLM-RoBERTa + Biaffine)
# ============================================================================
class TransformerDependencyParser(nn.Module):
"""
Trankit-style dependency parser using XLM-RoBERTa.
Architecture follows Nguyen et al. 2021 EACL:
- XLM-RoBERTa encoder
- Word-level pooling (first subword)
- Biaffine attention for arc/rel prediction
"""
def __init__(
self,
n_rels: int,
encoder: str = "xlm-roberta-base",
arc_hidden: int = 500,
rel_hidden: int = 100,
dropout: float = 0.33,
):
super().__init__()
from transformers import AutoModel, AutoTokenizer
self.encoder_name = encoder
self.tokenizer = AutoTokenizer.from_pretrained(encoder)
self.encoder = AutoModel.from_pretrained(encoder)
self.hidden_size = self.encoder.config.hidden_size
# Biaffine layers
self.mlp_arc_dep = MLP(self.hidden_size, arc_hidden, dropout)
self.mlp_arc_head = MLP(self.hidden_size, arc_hidden, dropout)
self.mlp_rel_dep = MLP(self.hidden_size, rel_hidden, dropout)
self.mlp_rel_head = MLP(self.hidden_size, rel_hidden, dropout)
self.arc_attn = Biaffine(arc_hidden, 1, bias_x=True, bias_y=False)
self.rel_attn = Biaffine(rel_hidden, n_rels, bias_x=True, bias_y=True)
self.dropout = nn.Dropout(dropout)
self.n_rels = n_rels
def encode_batch(self, sentences: List[List[str]], device):
"""Tokenize and encode sentences, return word-level representations."""
batch_size = len(sentences)
max_words = max(len(s) for s in sentences)
# Tokenize each word and track subword positions
all_input_ids = []
all_attention_mask = []
word_starts = [] # (batch, max_words) -> position of first subword
for sent in sentences:
input_ids = [self.tokenizer.cls_token_id]
starts = []
for word in sent:
starts.append(len(input_ids))
tokens = self.tokenizer.encode(word, add_special_tokens=False)
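                # encode() can return an empty list for some words; fall back to <unk> so every word keeps one subword position.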
input_ids.extend(tokens if tokens else [self.tokenizer.unk_token_id])
input_ids.append(self.tokenizer.sep_token_id)
all_input_ids.append(input_ids)
word_starts.append(starts)
# Pad sequences
max_len = max(len(ids) for ids in all_input_ids)
        # Pad with the tokenizer's pad id (these positions are zeroed out by attention_mask anyway)
        padded_ids = torch.full((batch_size, max_len), self.tokenizer.pad_token_id or 0, dtype=torch.long, device=device)
attention_mask = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
for i, ids in enumerate(all_input_ids):
padded_ids[i, :len(ids)] = torch.tensor(ids)
attention_mask[i, :len(ids)] = 1
# Encode with transformer
outputs = self.encoder(padded_ids, attention_mask=attention_mask)
hidden = outputs.last_hidden_state # (batch, seq_len, hidden)
# Extract word-level representations (first subword)
word_hidden = torch.zeros(batch_size, max_words, self.hidden_size, device=device)
word_mask = torch.zeros(batch_size, max_words, dtype=torch.bool, device=device)
for i, starts in enumerate(word_starts):
for j, pos in enumerate(starts):
word_hidden[i, j] = hidden[i, pos]
word_mask[i, j] = True
return word_hidden, word_mask
def forward(self, word_hidden, word_mask):
"""Compute arc and relation scores from word representations."""
word_hidden = self.dropout(word_hidden)
# Biaffine scoring
arc_dep = self.mlp_arc_dep(word_hidden)
arc_head = self.mlp_arc_head(word_hidden)
rel_dep = self.mlp_rel_dep(word_hidden)
rel_head = self.mlp_rel_head(word_hidden)
arc_scores = self.arc_attn(arc_dep, arc_head)
rel_scores = self.rel_attn(rel_dep, rel_head)
return arc_scores, rel_scores
def loss(self, arc_scores, rel_scores, heads, rels, mask):
"""Compute cross-entropy loss."""
batch_size, seq_len = mask.shape
# Arc loss
arc_scores = arc_scores.masked_fill(~mask.unsqueeze(2), float('-inf'))
arc_loss = F.cross_entropy(
arc_scores[mask].view(-1, seq_len),
heads[mask],
reduction='mean'
)
# Rel loss
rel_scores_gold = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1),
torch.arange(seq_len, device=mask.device), heads]
rel_loss = F.cross_entropy(
rel_scores_gold[mask],
rels[mask],
reduction='mean'
)
return arc_loss + rel_loss
def decode(self, arc_scores, rel_scores, mask):
"""Greedy decoding."""
arc_preds = arc_scores.argmax(dim=-1)
batch_size, seq_len = mask.shape
rel_scores_pred = rel_scores[torch.arange(batch_size, device=mask.device).unsqueeze(1),
torch.arange(seq_len, device=mask.device), arc_preds]
rel_preds = rel_scores_pred.argmax(dim=-1)
return arc_preds, rel_preds
class TransformerDataset(Dataset):
"""Dataset for transformer-based parser (stores raw sentences)."""
def __init__(self, sentences: List[Sentence], vocab):
self.sentences = sentences
self.vocab = vocab
def __len__(self):
return len(self.sentences)
def __getitem__(self, idx):
sent = self.sentences[idx]
heads = sent.heads
rels = [self.vocab.encode_rel(r) for r in sent.rels]
return sent.words, heads, rels
def transformer_collate_fn(batch):
"""Collate for transformer-based parser."""
words_list, heads_list, rels_list = zip(*batch)
max_len = max(len(w) for w in words_list)
batch_size = len(batch)
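    # Words stay as raw strings; subword tokenization happens later in the model's encode_batch.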
# Pad heads and rels
heads_padded = torch.zeros(batch_size, max_len, dtype=torch.long)
rels_padded = torch.zeros(batch_size, max_len, dtype=torch.long)
mask = torch.zeros(batch_size, max_len, dtype=torch.bool)
for i, (h, r) in enumerate(zip(heads_list, rels_list)):
heads_padded[i, :len(h)] = torch.tensor(h)
rels_padded[i, :len(r)] = torch.tensor(r)
mask[i, :len(h)] = True
return list(words_list), heads_padded, rels_padded, mask
def evaluate_transformer(model, dataloader, device):
"""Evaluate transformer-based model."""
model.eval()
total_arcs = 0
correct_arcs = 0
correct_rels = 0
with torch.no_grad():
for words_list, heads, rels, mask in dataloader:
heads = heads.to(device)
rels = rels.to(device)
mask = mask.to(device)
word_hidden, word_mask = model.encode_batch(words_list, device)
arc_scores, rel_scores = model(word_hidden, word_mask)
arc_preds, rel_preds = model.decode(arc_scores, rel_scores, word_mask)
arc_correct = (arc_preds == heads) & mask
rel_correct = (rel_preds == rels) & mask & arc_correct
total_arcs += mask.sum().item()
correct_arcs += arc_correct.sum().item()
correct_rels += rel_correct.sum().item()
uas = correct_arcs / total_arcs * 100
las = correct_rels / total_arcs * 100
return uas, las
# ============================================================================
# Training
# ============================================================================
def evaluate(model, dataloader, device):
"""Evaluate model and return UAS/LAS."""
model.eval()
total_arcs = 0
correct_arcs = 0
correct_rels = 0
with torch.no_grad():
for batch in dataloader:
words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
arc_scores, rel_scores = model(words, chars, mask)
arc_preds, rel_preds = model.decode(arc_scores, rel_scores, mask)
# Count correct
arc_correct = (arc_preds == heads) & mask
rel_correct = (rel_preds == rels) & mask & arc_correct
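            # UAS counts correct heads; LAS additionally requires the correct relation label.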
total_arcs += mask.sum().item()
correct_arcs += arc_correct.sum().item()
correct_rels += rel_correct.sum().item()
uas = correct_arcs / total_arcs * 100
las = correct_rels / total_arcs * 100
return uas, las
@click.command()
@click.option('--method', type=click.Choice(['baseline', 'trankit']), default='baseline',
help='Parser method: baseline (BiLSTM) or trankit (XLM-RoBERTa)')
@click.option('--dataset', type=click.Choice(['udd1', 'ud-vtb', 'vndt']), default='udd1',
help='Dataset: udd1 (UDD-1), ud-vtb (UD Vietnamese VTB), or vndt (VnDT v1.1)')
@click.option('--encoder', default='xlm-roberta-base',
help='Transformer encoder for trankit method')
@click.option('--output', '-o', default='models/bamboo-1', help='Output directory')
@click.option('--epochs', default=100, type=int, help='Number of epochs')
@click.option('--batch-size', default=32, type=int, help='Batch size')
@click.option('--lr', default=2e-3, type=float, help='Learning rate for baseline')
@click.option('--bert-lr', default=1e-5, type=float, help='Encoder learning rate for trankit')
@click.option('--head-lr', default=1e-4, type=float, help='Head learning rate for trankit')
@click.option('--warmup-steps', default=500, type=int, help='Warmup steps for trankit')
@click.option('--lstm-hidden', default=400, type=int, help='LSTM hidden size (baseline)')
@click.option('--lstm-layers', default=3, type=int, help='LSTM layers (baseline)')
@click.option('--patience', default=10, type=int, help='Early stopping patience')
@click.option('--force-download', is_flag=True, help='Force re-download dataset')
@click.option('--data-dir', default=None, help='Custom data directory')
@click.option('--gpu-type', default='RTX_A4000', help='GPU type for cost estimation')
@click.option('--cost-interval', default=300, type=int, help='Cost report interval in seconds')
@click.option('--wandb', 'use_wandb', is_flag=True, help='Enable W&B logging')
@click.option('--wandb-project', default='bamboo-1', help='W&B project name')
@click.option('--max-time', default=0, type=int, help='Max training time in minutes (0=unlimited)')
@click.option('--sample', default=0, type=int, help='Sample N sentences from each split (0=all)')
@click.option('--eval-every', default=1, type=int, help='Evaluate every N epochs')
@click.option('--fp16/--no-fp16', default=True, help='Use mixed precision training (disable with --no-fp16)')
def train(method, dataset, encoder, output, epochs, batch_size, lr, bert_lr, head_lr, warmup_steps,
lstm_hidden, lstm_layers, patience, force_download, data_dir, gpu_type, cost_interval,
use_wandb, wandb_project, max_time, sample, eval_every, fp16):
"""Train Bamboo-1 Vietnamese Dependency Parser."""
# Detect hardware
hardware = detect_hardware()
detected_gpu_type = hardware.get_gpu_type()
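    # If the user kept the default GPU type, use the detected hardware for cost estimation.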
if gpu_type == "RTX_A4000":
gpu_type = detected_gpu_type
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
click.echo(f"Using device: {device}")
click.echo(f"Hardware: {hardware}")
# CUDA optimizations
if torch.cuda.is_available():
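        # cuDNN autotuning and TF32 matmuls trade a little numerical precision for speed on Ampere and newer GPUs.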
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Mixed precision
use_amp = fp16 and torch.cuda.is_available()
scaler = torch.amp.GradScaler('cuda') if use_amp else None
if use_amp:
click.echo("Mixed precision (FP16): enabled")
# Initialize wandb
if use_wandb:
import wandb
wandb.init(
project=wandb_project,
config={
"method": method,
"dataset": dataset,
"encoder": encoder if method == "trankit" else "bilstm",
"epochs": epochs,
"batch_size": batch_size,
"lr": lr if method == "baseline" else bert_lr,
"head_lr": head_lr if method == "trankit" else None,
"lstm_hidden": lstm_hidden if method == "baseline" else None,
"lstm_layers": lstm_layers if method == "baseline" else None,
"patience": patience,
"gpu_type": gpu_type,
"hardware": hardware.to_dict(),
}
)
click.echo(f"W&B logging enabled: {wandb.run.url}")
click.echo("=" * 60)
click.echo(f"Bamboo-1: Vietnamese Dependency Parser ({method.upper()})")
click.echo("=" * 60)
# Load corpus
click.echo(f"\nLoading {dataset.upper()} corpus...")
if dataset == 'udd1':
corpus = UDD1Corpus(data_dir=data_dir, force_download=force_download)
elif dataset == 'ud-vtb':
corpus = UDVietnameseVTB(data_dir=data_dir, force_download=force_download)
else: # vndt
corpus = VnDTCorpus(data_dir=data_dir, force_download=force_download)
train_sents = read_conllu(corpus.train)
dev_sents = read_conllu(corpus.dev)
test_sents = read_conllu(corpus.test)
# Sample subset if requested
if sample > 0:
train_sents = train_sents[:sample]
dev_sents = dev_sents[:min(sample // 2, len(dev_sents))]
test_sents = test_sents[:min(sample // 2, len(test_sents))]
click.echo(f" Sampling {sample} sentences...")
click.echo(f" Train: {len(train_sents)} sentences")
click.echo(f" Dev: {len(dev_sents)} sentences")
click.echo(f" Test: {len(test_sents)} sentences")
# Build vocabulary
click.echo("\nBuilding vocabulary...")
vocab = Vocabulary(min_freq=2)
vocab.build(train_sents)
if method == "baseline":
click.echo(f" Words: {vocab.n_words}")
click.echo(f" Chars: {vocab.n_chars}")
click.echo(f" Relations: {vocab.n_rels}")
# Create datasets and model based on method
if method == "trankit":
# Trankit method: XLM-RoBERTa + Biaffine
train_dataset = TransformerDataset(train_sents, vocab)
dev_dataset = TransformerDataset(dev_sents, vocab)
test_dataset = TransformerDataset(test_sents, vocab)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
collate_fn=transformer_collate_fn, num_workers=0)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size,
collate_fn=transformer_collate_fn, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
collate_fn=transformer_collate_fn, num_workers=0)
click.echo(f"\nInitializing model with {encoder}...")
model = TransformerDependencyParser(
n_rels=vocab.n_rels,
encoder=encoder,
).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
encoder_params = sum(p.numel() for p in model.encoder.parameters())
head_params = n_params - encoder_params
click.echo(f" Total parameters: {n_params:,}")
click.echo(f" Encoder parameters: {encoder_params:,}")
click.echo(f" Head parameters: {head_params:,}")
# Differential learning rates
encoder_params_list = list(model.encoder.parameters())
head_params_list = [p for n, p in model.named_parameters() if 'encoder' not in n]
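        # The pretrained encoder is fine-tuned with a small LR; the randomly initialized biaffine head uses a larger one.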
optimizer = AdamW([
{'params': encoder_params_list, 'lr': bert_lr},
{'params': head_params_list, 'lr': head_lr},
], weight_decay=0.01)
# Learning rate scheduler with warmup
total_steps = len(train_loader) * epochs
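        # Linear warmup to the base LRs, then linear decay to zero over the remaining steps (assumes total_steps > warmup_steps).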
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
return max(0.0, (total_steps - step) / (total_steps - warmup_steps))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
eval_fn = evaluate_transformer
else:
# Baseline method: BiLSTM + Biaffine
train_dataset = DependencyDataset(train_sents, vocab)
dev_dataset = DependencyDataset(dev_sents, vocab)
test_dataset = DependencyDataset(test_sents, vocab)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)
click.echo("\nInitializing BiLSTM model...")
model = BiaffineDependencyParser(
n_words=vocab.n_words,
n_chars=vocab.n_chars,
n_rels=vocab.n_rels,
lstm_hidden=lstm_hidden,
lstm_layers=lstm_layers,
).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
click.echo(f" Parameters: {n_params:,}")
optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.9))
scheduler = ExponentialLR(optimizer, gamma=0.75 ** (1 / 5000))
eval_fn = evaluate
# Training
click.echo(f"\nTraining for {epochs} epochs...")
if max_time > 0:
click.echo(f"Time limit: {max_time} minutes")
output_path = Path(output)
output_path.mkdir(parents=True, exist_ok=True)
# Cost tracking
cost_tracker = CostTracker(gpu_type=gpu_type)
cost_tracker.report_interval = cost_interval
cost_tracker.start()
click.echo(f"Cost tracking: {gpu_type} @ ${cost_tracker.hourly_rate}/hr")
best_las = -1
no_improve = 0
time_limit_seconds = max_time * 60 if max_time > 0 else float('inf')
for epoch in range(1, epochs + 1):
# Check time limit
if cost_tracker.elapsed_seconds() >= time_limit_seconds:
click.echo(f"\nTime limit reached ({max_time} minutes)")
break
model.train()
total_loss = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}", leave=False)
for batch in pbar:
optimizer.zero_grad()
if method == "trankit":
words_list, heads, rels, mask = batch
heads = heads.to(device)
rels = rels.to(device)
mask = mask.to(device)
with torch.amp.autocast('cuda', enabled=use_amp):
word_hidden, word_mask = model.encode_batch(words_list, device)
arc_scores, rel_scores = model(word_hidden, word_mask)
loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
else:
words, chars, heads, rels, mask, lengths = [x.to(device) for x in batch]
arc_scores, rel_scores = model(words, chars, mask)
loss = model.loss(arc_scores, rel_scores, heads, rels, mask)
if use_amp and scaler:
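                # Scale the loss to avoid FP16 gradient underflow, then unscale before clipping so the max-norm applies to the true gradients.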
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(model.parameters(), 5.0)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 5.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
pbar.set_postfix({'loss': f'{loss.item():.4f}'})
# Evaluate (skip if not eval epoch, unless last epoch)
if epoch % eval_every != 0 and epoch != epochs:
avg_loss = total_loss / len(train_loader)
current_lr = optimizer.param_groups[0]['lr']
click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")
continue
dev_uas, dev_las = eval_fn(model, dev_loader, device)
# Cost update
progress = epoch / epochs
current_cost = cost_tracker.current_cost()
estimated_total_cost = cost_tracker.estimate_total_cost(progress)
elapsed_minutes = cost_tracker.elapsed_seconds() / 60
cost_status = cost_tracker.update(epoch, epochs)
if cost_status:
click.echo(f" [{cost_status}]")
avg_loss = total_loss / len(train_loader)
click.echo(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
f"Dev UAS: {dev_uas:.2f}% | Dev LAS: {dev_las:.2f}%")
# Log to wandb
if use_wandb:
wandb.log({
"epoch": epoch,
"train/loss": avg_loss,
"dev/uas": dev_uas,
"dev/las": dev_las,
"cost/current_usd": current_cost,
"cost/estimated_total_usd": estimated_total_cost,
"cost/elapsed_minutes": elapsed_minutes,
})
# Save best model
if dev_las >= best_las:
best_las = dev_las
no_improve = 0
if method == "trankit":
config = {
'method': 'trankit',
'encoder': encoder,
'n_rels': vocab.n_rels,
}
else:
config = {
'method': 'baseline',
'n_words': vocab.n_words,
'n_chars': vocab.n_chars,
'n_rels': vocab.n_rels,
'lstm_hidden': lstm_hidden,
'lstm_layers': lstm_layers,
}
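            # Bundle weights, vocabulary, and config so the checkpoint is self-contained for inference.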
torch.save({
'model': model.state_dict(),
'vocab': vocab,
'config': config,
}, output_path / 'model.pt')
click.echo(f" -> Saved best model (LAS: {best_las:.2f}%)")
else:
no_improve += 1
if no_improve >= patience:
click.echo(f"\nEarly stopping after {patience} epochs without improvement")
break
# Final evaluation
click.echo("\nLoading best model for final evaluation...")
checkpoint = torch.load(output_path / 'model.pt', weights_only=False)
model.load_state_dict(checkpoint['model'])
test_uas, test_las = eval_fn(model, test_loader, device)
click.echo(f"\nTest Results:")
click.echo(f" UAS: {test_uas:.2f}%")
click.echo(f" LAS: {test_las:.2f}%")
click.echo(f"\nModel saved to: {output_path}")
# Final cost summary
final_cost = cost_tracker.current_cost()
click.echo(f"\n{cost_tracker.summary(epoch, epochs)}")
# Log final metrics to wandb
if use_wandb:
wandb.log({
"test/uas": test_uas,
"test/las": test_las,
"cost/final_usd": final_cost,
})
wandb.finish()
if __name__ == '__main__':
train()