| import torch |
| import torch.optim as optim |
| from transformers import AutoTokenizer |
| import os |
| import argparse |
|
|
| from src.config import ModelConfig, TrainConfig |
| from src.models.autoencoder import ReshapedAutoencoder,ResidualAutoencoder |
| from src.trainer import Trainer |
| from src.utils.data_utils import prepare_data |
|
|
| def _pick_stop_id(tokenizer): |
| return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id |
|
|
def main():
    """Train the reshaped autoencoder and checkpoint best/last weights.

    Command-line args:
        --save_dir: directory where ``ae_best.pt`` / ``ae_last.pt`` are written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        type=str,
        default="/mnt/hdfs/user/lixinyu.222/CodeFlow/robust_checkpoints",
        help="Directory to save checkpoints",
    )
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Checkpoints will be saved to: {args.save_dir}")

    model_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128,
    )
    train_cfg = TrainConfig(
        batch_size=16,
        num_epochs_ae=20,
        grad_accum_steps=4,
        use_amp=False,
        lr_ae=1e-4,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_cfg.encoder_name, local_files_only=True, trust_remote_code=False
    )
    loader = prepare_data(
        "wiki", tokenizer, model_cfg.max_seq_len, train_cfg.batch_size, split="train"
    )

    autoencoder = ReshapedAutoencoder(model_cfg).to(train_cfg.device).float()

    # Some encoder configs ship without a pad id; borrow the tokenizer's so
    # padded batches are handled consistently.
    if autoencoder.encoder.config.pad_token_id is None:
        autoencoder.encoder.config.pad_token_id = tokenizer.pad_token_id

    trainer = Trainer(
        ae=autoencoder,
        flow=None,  # autoencoder-only phase; no flow model yet
        cfg=train_cfg,
        loader=loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer),
    )

    # Optimize only trainable parameters (frozen encoder weights are skipped).
    optimizer = optim.AdamW(
        (p for p in autoencoder.parameters() if p.requires_grad),
        lr=train_cfg.lr_ae,
    )

    best_loss = float('inf')
    print("\n>>> Start Training Autoencoder...")

    for epoch in range(train_cfg.num_epochs_ae):
        loss = trainer.train_ae_combined(optimizer, epoch, train_cfg.num_epochs_ae)
        print(f"AE Epoch {epoch}: Loss {loss:.4f}")

        # Keep the lowest-loss snapshot separately from the rolling "last".
        if loss < best_loss:
            best_loss = loss
            save_path = os.path.join(args.save_dir, "ae_best.pt")
            torch.save(autoencoder.state_dict(), save_path)
            print(f" Saved Best AE to {save_path}")

        torch.save(autoencoder.state_dict(), os.path.join(args.save_dir, "ae_last.pt"))

    print(f"AE Training Done. Best Loss: {best_loss:.4f}")
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()