# Source: Hugging Face upload by nickmuchi ("Upload 17 files", commit 50dd923, 8.53 kB).
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pdb
import os
import time
import sys
import torch
from torch.utils.tensorboard import SummaryWriter
import logging
import json
import numpy as np
import torch.distributed as dist
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from src.options import Options
from src import data, beir_utils, slurm, dist_utils, utils, contriever, finetuning_data, inbatch
import train
logger = logging.getLogger(__name__)

# Disable the HuggingFace `tokenizers` library's internal parallelism for this
# process (presumably to avoid its fork warnings/deadlocks once DataLoader
# worker processes are forked).
os.environ.update(TOKENIZERS_PARALLELISM="false")
def _build_train_dataloader(opt, tokenizer):
    """Return the shuffled, drop-last training DataLoader over ``opt.train_data``."""
    train_dataset = finetuning_data.Dataset(
        datapaths=opt.train_data,
        negative_ctxs=opt.negative_ctxs,
        negative_hard_ratio=opt.negative_hard_ratio,
        negative_hard_min_idx=opt.negative_hard_min_idx,
        normalize=opt.eval_normalize_text,
        # NOTE(review): rank and world size are passed in, presumably so each
        # process loads its own shard of the data.
        global_rank=dist_utils.get_rank(),
        world_size=dist_utils.get_world_size(),
        maxload=opt.maxload,
        training=True,
    )
    collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
    return DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=opt.per_gpu_batch_size,
        drop_last=True,
        num_workers=opt.num_workers,
        collate_fn=collator,
    )


def _log_train_stats(opt, step, run_stats, scheduler, tb_logger):
    """Log averaged running stats, current LR and peak GPU memory; mirror stats to TensorBoard."""
    parts = [f"{step} / {opt.total_steps}"]
    for key, value in sorted(run_stats.average_stats.items()):
        parts.append(f"{key}: {value:.3f}")
        if tb_logger is not None:
            tb_logger.add_scalar(key, value, step)
    parts.append(f"lr: {scheduler.get_last_lr()[0]:0.3g}")
    # max_memory_allocated() reports bytes; divide by 2**30 so the value matches the "GiB" label.
    parts.append(f"Memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
    logger.info(" | ".join(parts))


def finetuning(opt, model, optimizer, scheduler, tokenizer, step):
    """Fine-tune ``model`` on ``opt.train_data`` until ``opt.total_steps`` is reached.

    Runs an initial evaluation, then iterates over the training data epoch by
    epoch. Running statistics are logged every ``opt.log_freq`` steps, the
    encoder is evaluated every ``opt.eval_freq`` steps, and a checkpoint of the
    encoder is written by rank 0 every ``opt.save_freq`` steps.

    Args:
        opt: parsed options namespace.
        model: training model, optionally wrapped (e.g. by DistributedDataParallel,
            detected via a ``module`` attribute). ``model(**batch, stats_prefix=...)``
            must return ``(loss, stats)`` and the unwrapped model must provide
            ``get_encoder()``.
        optimizer: optimizer; when ``opt.optim`` is ``"sam"`` or ``"asam"`` it must
            provide ``first_step(zero_grad=...)`` and ``second_step(zero_grad=...)``.
        scheduler: learning-rate scheduler, stepped once per training step.
        tokenizer: tokenizer used by the collator and the evaluation routines.
        step: global step to resume counting from.

    Raises:
        ValueError: if training would start but the dataloader yields no batches
            (otherwise the epoch loop would never terminate).
    """
    run_stats = utils.WeightedAvgStats()
    tb_logger = utils.init_tb_logger(opt.output_dir)

    # Evaluation and checkpointing operate on the bare encoder, unwrapped from DDP if needed.
    eval_model = model.module if hasattr(model, "module") else model
    eval_model = eval_model.get_encoder()

    train_dataloader = _build_train_dataloader(opt, tokenizer)
    if step < opt.total_steps and len(train_dataloader) == 0:
        raise ValueError(
            f"training dataloader yields no batches: {len(train_dataloader.dataset)} examples "
            f"with per_gpu_batch_size={opt.per_gpu_batch_size} and drop_last=True"
        )

    train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
    evaluate(opt, eval_model, tokenizer, tb_logger, step)

    epoch = 1
    model.train()
    while step < opt.total_steps:
        logger.info(f"Start epoch {epoch}, number of batches: {len(train_dataloader)}")
        for batch in train_dataloader:
            # Move tensors to the GPU; non-tensor entries pass through unchanged.
            batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            step += 1

            train_loss, iter_stats = model(**batch, stats_prefix="train")
            train_loss.backward()
            if opt.optim in ("sam", "asam"):
                # Two-pass update: first_step after the first backward, then a
                # second forward/backward on the same batch before second_step.
                optimizer.first_step(zero_grad=True)
                sam_loss, _ = model(**batch, stats_prefix="train/sam_opt")
                sam_loss.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            run_stats.update(iter_stats)

            if step % opt.log_freq == 0:
                _log_train_stats(opt, step, run_stats, scheduler, tb_logger)
                run_stats.reset()

            if step % opt.eval_freq == 0:
                train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
                evaluate(opt, eval_model, tokenizer, tb_logger, step)
                # Evaluation switches the encoder to eval mode; restore training mode.
                model.train()

            # Checkpoint independently of evaluation so save_freq is honoured even
            # when it is not a multiple of eval_freq.
            if step % opt.save_freq == 0 and dist_utils.get_rank() == 0:
                utils.save(
                    eval_model,
                    optimizer,
                    scheduler,
                    step,
                    opt,
                    opt.output_dir,
                    f"step-{step}",
                )

            if step >= opt.total_steps:
                break
        epoch += 1
def evaluate(opt, model, tokenizer, tb_logger, step):
    """Compute retrieval accuracy and MRR on ``opt.eval_data`` and log them.

    Every local query is scored by dot product against the gold passages and
    the negative passages gathered from all ranks. A query counts as correct
    when its own gold passage receives the highest score. MRR uses the rank of
    the gold passage among all scored passages. Both metrics are combined across
    ranks with ``dist_utils.weighted_average`` (weighted by the local query
    count) and logged on the main process only.

    Args:
        opt: parsed options namespace.
        model: encoder called as ``model(input_ids=..., attention_mask=...,
            normalize=...)``, returning one embedding row per input sequence.
            A ``module`` attribute (e.g. a DDP wrapper) is unwrapped.
        tokenizer: tokenizer used by the collator.
        tb_logger: TensorBoard writer, or ``None`` to skip TensorBoard logging.
        step: global step used as the TensorBoard x-axis.

    Note:
        ``model`` is left in eval mode; callers that continue training must
        call ``train()`` again.
    """
    dataset = finetuning_data.Dataset(
        datapaths=opt.eval_data,
        normalize=opt.eval_normalize_text,
        global_rank=dist_utils.get_rank(),
        world_size=dist_utils.get_world_size(),
        maxload=opt.maxload,
        training=False,
    )
    collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
    dataloader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=opt.per_gpu_batch_size,
        drop_last=False,
        num_workers=opt.num_workers,
        collate_fn=collator,
    )

    model.eval()
    if hasattr(model, "module"):
        model = model.module

    all_q, all_g, all_n = [], [], []
    with torch.no_grad():
        for batch in dataloader:
            batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            # Encode gold and negative passages in one forward pass, then split the embeddings back apart.
            passage_tokens = torch.cat([batch["g_tokens"], batch["n_tokens"]], dim=0)
            passage_mask = torch.cat([batch["g_mask"], batch["n_mask"]], dim=0)
            q_emb = model(input_ids=batch["q_tokens"], attention_mask=batch["q_mask"], normalize=opt.norm_query)
            passage_emb = model(input_ids=passage_tokens, attention_mask=passage_mask, normalize=opt.norm_doc)
            g_emb, n_emb = torch.split(passage_emb, [len(batch["g_tokens"]), len(batch["n_tokens"])])
            all_q.append(q_emb)
            all_g.append(g_emb)
            all_n.append(n_emb)

        all_q = torch.cat(all_q, dim=0)
        all_g = torch.cat(all_g, dim=0)
        all_n = torch.cat(all_n, dim=0)

        # Query i's gold passage is local row i of all_g. After gathering passages
        # from every rank, offset the label by the number of gold rows on lower
        # ranks (get_varsize presumably returns per-rank row counts).
        labels = torch.arange(0, len(all_q), device=all_q.device, dtype=torch.long)
        all_sizes = dist_utils.get_varsize(all_g)
        all_g = dist_utils.varsize_gather_nograd(all_g)
        all_n = dist_utils.varsize_gather_nograd(all_n)
        labels = labels + sum(all_sizes[: dist_utils.get_rank()])

        # Score columns: [gathered gold passages | gathered negative passages].
        scores_pos = torch.einsum("id, jd->ij", all_q, all_g)
        scores_neg = torch.einsum("id, jd->ij", all_q, all_n)
        scores = torch.cat([scores_pos, scores_neg], dim=-1)

        argmax_idx = torch.argmax(scores, dim=1)
        _, ranking = torch.sort(scores, descending=True)
        isrelevant = ranking == labels[:, None]
        # Reciprocal rank of the first relevant position per row (0 when absent),
        # computed on-device instead of one host transfer per query.
        found = isrelevant.any(dim=1)
        gold_rank = isrelevant.int().argmax(dim=1)
        mrr = (found.double() / (gold_rank.double() + 1.0)).mean().item()

        acc = (argmax_idx == labels).sum() / all_q.size(0)
        acc, _ = dist_utils.weighted_average(acc, all_q.size(0))
        mrr, _ = dist_utils.weighted_average(mrr, all_q.size(0))
        acc = 100 * acc

    if dist_utils.is_main():
        logger.info(" | ".join([f"eval acc: {acc:.2f}%", f"eval mrr: {mrr:.3f}"]))
        if tb_logger is not None:
            tb_logger.add_scalar("eval_acc", acc, step)
            tb_logger.add_scalar("mrr", mrr, step)
def main():
    """Script entry point: parse options, set up distributed mode and run fine-tuning.

    Loads the retriever given by ``opt.model_path`` and wraps it in
    ``inbatch.InBatch``. Every ``torch.nn.Dropout`` layer is set to
    ``opt.dropout``, and the model is wrapped in DistributedDataParallel when a
    process group is initialized. Control then passes to ``finetuning`` starting
    from step 0.
    """
    logger.info("Start")

    options = Options()
    opt = options.parse()

    torch.manual_seed(opt.seed)
    slurm.init_distributed_mode(opt)
    slurm.init_signal_handler()

    # Check before creating the directory, so options are only printed when the
    # output directory did not exist yet.
    directory_exists = os.path.isdir(opt.output_dir)
    if dist.is_initialized():
        dist.barrier()
    os.makedirs(opt.output_dir, exist_ok=True)
    if not directory_exists and dist_utils.is_main():
        options.print_options(opt)
    if dist.is_initialized():
        dist.barrier()
    utils.init_logger(opt)

    step = 0

    retriever, tokenizer, retriever_model_id = contriever.load_retriever(opt.model_path, opt.pooling, opt.random_init)
    opt.retriever_model_id = retriever_model_id
    model = inbatch.InBatch(opt, retriever, tokenizer)
    model = model.cuda()

    optimizer, scheduler = utils.set_optim(opt, model)
    logger.info(utils.get_parameters(model))

    # Override the dropout probability of every Dropout layer with the configured value.
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = opt.dropout

    if dist.is_initialized():
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            find_unused_parameters=False,
        )

    logger.info("Start training")
    finetuning(opt, model, optimizer, scheduler, tokenizer, step)


if __name__ == "__main__":
    main()