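"""Contrastive retriever training script.

Trains a text encoder with a contrastive objective (a MoCo-style queue or
in-batch negatives, see src.moco / src.inbatch), periodically evaluates it as
a retriever on BEIR datasets, and saves checkpoints. Supports single-GPU and
DistributedDataParallel runs.
"""
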
import os
import time
import sys
import torch
import logging
import json
import numpy as np
import random
import pickle

import torch.distributed as dist
from torch.utils.data import DataLoader, RandomSampler

from src.options import Options
from src import data, beir_utils, slurm, dist_utils, utils
from src import moco, inbatch

logger = logging.getLogger(__name__)


def train(opt, model, optimizer, scheduler, step):

    run_stats = utils.WeightedAvgStats()

    tb_logger = utils.init_tb_logger(opt.output_dir)
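
    # Load the tokenizer (from model.module when wrapped in DDP), then build the collator and training dataset.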
    logger.info("Data loading")
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        tokenizer = model.module.tokenizer
    else:
        tokenizer = model.tokenizer
    collator = data.Collator(opt=opt)
    train_dataset = data.load_data(opt, tokenizer)
    logger.warning(f"Data loading finished for rank {dist_utils.get_rank()}")

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=True,
        num_workers=opt.num_workers,
        collate_fn=collator,
    )

    epoch = 1

    model.train()
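
    # Main training loop: keep iterating over the dataloader until opt.total_steps optimizer updates have been taken.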
    while step < opt.total_steps:
        train_dataset.generate_offset()

        logger.info(f"Start epoch {epoch}")
        for i, batch in enumerate(train_dataloader):
            step += 1
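
            # Move batch tensors to the GPU; the forward pass returns the contrastive loss and per-iteration stats.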
            batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            train_loss, iter_stats = model(**batch, stats_prefix="train")

            train_loss.backward()
            optimizer.step()
            scheduler.step()
            model.zero_grad()

            run_stats.update(iter_stats)
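
            # Every log_freq steps, log running averages of the training stats, the learning rate and peak GPU memory.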
            if step % opt.log_freq == 0:
                log = f"{step} / {opt.total_steps}"
                for k, v in sorted(run_stats.average_stats.items()):
                    log += f" | {k}: {v:.3f}"
                    if tb_logger:
                        tb_logger.add_scalar(k, v, step)
                log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
                log += f" | Memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GiB"

                logger.info(log)
                run_stats.reset()
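
            # Every eval_freq steps, evaluate the current encoder on BEIR and checkpoint the latest state.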
            if step % opt.eval_freq == 0:
                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                    encoder = model.module.get_encoder()
                else:
                    encoder = model.get_encoder()
                eval_model(
                    opt, query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer, tb_logger=tb_logger, step=step
                )

                if dist_utils.is_main():
                    utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, "lastlog")

                model.train()

            if dist_utils.is_main() and step % opt.save_freq == 0:
                utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, f"step-{step}")

            if step >= opt.total_steps:
                break

        epoch += 1
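

# Evaluate the query and document encoders on each dataset in opt.eval_datasets and log retrieval metrics.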
def eval_model(opt, query_encoder, doc_encoder, tokenizer, tb_logger, step):
    for datasetname in opt.eval_datasets:
        metrics = beir_utils.evaluate_model(
            query_encoder,
            doc_encoder,
            tokenizer,
            dataset=datasetname,
            batch_size=opt.per_gpu_eval_batch_size,
            norm_doc=opt.norm_doc,
            norm_query=opt.norm_query,
            beir_dir=opt.eval_datasets_dir,
            score_function=opt.score_function,
            lower_case=opt.lower_case,
            normalize_text=opt.eval_normalize_text,
        )

        message = []
        if dist_utils.is_main():
            for metric in ["NDCG@10", "Recall@10", "Recall@100"]:
                message.append(f"{datasetname}/{metric}: {metrics[metric]:.2f}")
                if tb_logger is not None:
                    tb_logger.add_scalar(f"{datasetname}/{metric}", metrics[metric], step)
            logger.info(" | ".join(message))


if __name__ == "__main__":
    logger.info("Start")

    options = Options()
    opt = options.parse()
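
    # Seed torch and set up distributed training (SLURM-aware) before creating the output directory.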
    torch.manual_seed(opt.seed)
    slurm.init_distributed_mode(opt)
    slurm.init_signal_handler()

    directory_exists = os.path.isdir(opt.output_dir)
    if dist.is_initialized():
        dist.barrier()
    os.makedirs(opt.output_dir, exist_ok=True)
    if not directory_exists and dist_utils.is_main():
        options.print_options(opt)
    if dist.is_initialized():
        dist.barrier()
    utils.init_logger(opt)

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if opt.contrastive_mode == "moco":
        model_class = moco.MoCo
    elif opt.contrastive_mode == "inbatch":
        model_class = inbatch.InBatch
    else:
        raise ValueError(f"contrastive mode: {opt.contrastive_mode} not recognised")
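
    # Three initialization paths: train from scratch, resume from this run's latest checkpoint,
    # or warm-start from an external checkpoint given by opt.model_path.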
    if not directory_exists and opt.model_path == "none":
        model = model_class(opt)
        model = model.cuda()
        optimizer, scheduler = utils.set_optim(opt, model)
        step = 0
    elif directory_exists:
        model_path = os.path.join(opt.output_dir, "checkpoint", "latest")
        model, optimizer, scheduler, opt_checkpoint, step = utils.load(
            model_class,
            model_path,
            opt,
            reset_params=False,
        )
        logger.info(f"Model loaded from {opt.output_dir}")
    else:
        model, optimizer, scheduler, opt_checkpoint, step = utils.load(
            model_class,
            opt.model_path,
            opt,
            reset_params=not opt.continue_training,
        )
        if not opt.continue_training:
            step = 0
        logger.info(f"Model loaded from {opt.model_path}")

    logger.info(utils.get_parameters(model))
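
    # Wrap the model in DistributedDataParallel when torch.distributed has been initialized.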
    if dist.is_initialized():
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            find_unused_parameters=False,
        )
        dist.barrier()

    logger.info("Start training")
    train(opt, model, optimizer, scheduler, step)