import multiprocessing
import time
from multiprocessing.managers import Namespace

import torch
import numpy as np
from omegaconf import DictConfig, open_dict
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import (
    LRScheduler,
    SequentialLR,
    LinearLR,
    CosineAnnealingLR,
)

from osuT5.model.osu_t import OsuT
from osuT5.tokenizer import Tokenizer


def get_shared_training_state() -> Namespace:
    """Create a manager-backed namespace for sharing training state across processes."""
    mgr = multiprocessing.Manager()
    shared = mgr.Namespace()
    shared.current_train_step = 1
    shared.current_epoch = 1
    shared.last_log = time.time()
    shared.current_loss = np.inf  # np.Infinity was removed in NumPy 2.0
    shared.best_loss = np.inf
    return shared
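
# A minimal usage sketch (illustrative, not part of the training code): the
# namespace can be handed to a background process and mutated from the trainer.
#
#   state = get_shared_training_state()
#   proc = multiprocessing.Process(target=log_worker, args=(state,))  # log_worker is hypothetical
#   proc.start()
#   state.current_train_step += 1  # the update is visible to log_worker via the Manager proxy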


def get_model(args: DictConfig, tokenizer: Tokenizer) -> OsuT:
    return OsuT(args, tokenizer)
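
# Typical wiring (a sketch): build the tokenizer first, since the model is
# constructed from it (presumably to size its vocabulary).
#
#   tokenizer = get_tokenizer(args)
#   model = get_model(args, tokenizer)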


def get_tokenizer(args: DictConfig) -> Tokenizer:
    return Tokenizer(args)


def get_optimizer(model: OsuT, args: DictConfig) -> Optimizer:
    # Parameters whose names contain any of these substrings are exempt
    # from weight decay (biases and layer-norm weights).
    no_decay = ["bias", "LayerNorm", "layernorm", "layer_norm", "ln"]

    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.optim.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    if args.optim.name == 'adamw':
        # transformers.AdamW was deprecated and later removed;
        # torch.optim.AdamW is the recommended replacement.
        from torch.optim import AdamW

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adamwscale':
        from .copied_utils import AdamWScale

        optimizer = AdamWScale(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
        )
    elif args.optim.name == 'adafactor':
        from transformers import Adafactor

        # With an explicit lr, Adafactor requires relative_step=False;
        # scale_parameter/warmup_init are disabled as recommended when the
        # learning rate is driven by an external scheduler.
        optimizer = Adafactor(
            optimizer_grouped_parameters,
            lr=args.optim.base_lr,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )
    else:
        raise NotImplementedError(f"Unknown optimizer: {args.optim.name}")

    return optimizer
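
# Expected config shape for the branches above (a sketch; the key names are
# taken from the attribute lookups in this function, the values are made-up
# examples):
#
#   args = OmegaConf.create({
#       "optim": {"name": "adamwscale", "base_lr": 2e-2, "weight_decay": 0.0}
#   })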


def get_scheduler(optimizer: Optimizer, args: DictConfig) -> LRScheduler:
    # Linear warmup from half the base learning rate up to the full rate.
    scheduler_p1 = LinearLR(
        optimizer,
        start_factor=0.5,
        end_factor=1.0,
        total_iters=args.optim.warmup_steps,
        last_epoch=-1,
    )

    # Cosine decay over the remaining steps down to final_cosine.
    scheduler_p2 = CosineAnnealingLR(
        optimizer,
        T_max=args.optim.total_steps - args.optim.warmup_steps,
        eta_min=args.optim.final_cosine,
    )

    return SequentialLR(
        optimizer,
        schedulers=[scheduler_p1, scheduler_p2],
        milestones=[args.optim.warmup_steps],
    )
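
# How the combined schedule is meant to be driven (assumed training loop, not
# part of this module): LinearLR ramps from 0.5 * base_lr to base_lr over
# warmup_steps, then CosineAnnealingLR decays toward final_cosine.
#
#   optimizer = get_optimizer(model, args)
#   scheduler = get_scheduler(optimizer, args)
#   for step in range(args.optim.total_steps):
#       ...                # forward/backward pass (omitted)
#       optimizer.step()
#       scheduler.step()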


def worker_init_fn(worker_id: int) -> None:
    """Give each dataloader worker a unique slice of the full dataset."""
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    overall_start = dataset.start
    overall_end = dataset.end

    # Split [overall_start, overall_end) into num_workers contiguous chunks.
    per_worker = int(
        np.ceil((overall_end - overall_start) / float(worker_info.num_workers)),
    )
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
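
# Illustrative hookup (assumes an IterableDataset exposing integer `start` and
# `end` attributes, which the slicing above requires):
#
#   loader = DataLoader(
#       dataset,
#       batch_size=batch_size,      # batch_size/num_workers are placeholders
#       num_workers=num_workers,
#       worker_init_fn=worker_init_fn,
#   )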