欧卫
'add_app_files'
58627fa
raw
history blame
No virus
5.8 kB
import time
import torch
import random
import torch.nn as nn
import numpy as np
from transformers import AdamW, get_linear_schedule_with_warmup
from colbert.infra import ColBERTConfig
from colbert.training.rerank_batcher import RerankBatcher
from colbert.utils.amp import MixedPrecisionManager
from colbert.training.lazy_batcher import LazyBatcher
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.modeling.reranker.electra import ElectraReranker
from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints
def train(config: ColBERTConfig, triples, queries=None, collection=None):
    """Train a ColBERT model (or an Electra reranker) and return the last checkpoint path.

    Seeds the Python/NumPy/PyTorch RNGs, divides ``config.bsize`` across
    ``config.nranks`` processes, builds a batcher over the training data, wraps
    the model in ``DistributedDataParallel`` and runs at most
    ``config.maxsteps`` optimisation steps (fewer if the batcher is exhausted
    first).

    Args:
        config: Run configuration. Mutated in place: ``checkpoint`` falls back
            to ``'bert-base-uncased'`` when falsy and ``bsize`` is replaced by
            the per-process batch size.
        triples: Training triples, forwarded to the batcher.
        queries: Queries, forwarded to the batcher.
        collection: Passage collection, forwarded to the batcher. Required.

    Returns:
        On ranks where ``config.rank < 1``, whatever the final
        ``manage_checkpoints`` call returns; ``None`` on every other rank.

    Raises:
        AssertionError: If ``config.bsize`` is not divisible by ``config.nranks``.
        NotImplementedError: If ``collection`` is ``None``.
    """
    config.checkpoint = config.checkpoint or 'bert-base-uncased'

    if config.rank < 1:
        config.help()

    # Fixed seeds so runs are repeatable.
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed_all(12345)

    # From here on config.bsize is the per-process batch size.
    assert config.bsize % config.nranks == 0, (config.bsize, config.nranks)
    config.bsize = config.bsize // config.nranks

    print("Using config.bsize =", config.bsize, "(per process) and config.accumsteps =", config.accumsteps)

    if collection is None:
        raise NotImplementedError()

    # A rank of -1 is mapped to 0 for the batcher.
    reader_rank = 0 if config.rank == -1 else config.rank
    reader_cls = RerankBatcher if config.reranker else LazyBatcher
    reader = reader_cls(config, triples, queries, collection, reader_rank, config.nranks)

    if not config.reranker:
        colbert = ColBERT(name=config.checkpoint, colbert_config=config)
    else:
        colbert = ElectraReranker.from_pretrained(config.checkpoint)

    colbert = colbert.to(DEVICE)
    colbert.train()

    colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[config.rank],
                                                        output_device=config.rank,
                                                        find_unused_parameters=True)

    # The optimizer is built *before* any BERT freezing below, so parameters
    # that start out frozen are still registered with it.
    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, eps=1e-8)
    optimizer.zero_grad()

    scheduler = None
    if config.warmup is not None:
        print(f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps.")
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup,
                                                    num_training_steps=config.maxsteps)

    # When set, the BERT parameters stay frozen until step `warmup_bert`.
    warmup_bert = config.warmup_bert
    if warmup_bert is not None:
        set_bert_grad(colbert, False)

    amp = MixedPrecisionManager(config.amp)

    # Cross-entropy target is index 0 for every row of the (-1, nway) score matrix.
    labels = torch.zeros(config.bsize, dtype=torch.long, device=DEVICE)

    # Loss modules are stateless, so build them once instead of per micro-batch.
    cross_entropy = nn.CrossEntropyLoss()
    kl_divergence = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)

    train_loss = None
    train_loss_mu = 0.999  # smoothing factor of the running-average loss

    start_batch_idx = 0

    # if config.resume:
    #     assert config.checkpoint is not None
    #     start_batch_idx = checkpoint['batch']
    #     reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])

    for batch_idx, batch_steps in zip(range(start_batch_idx, config.maxsteps), reader):
        if (warmup_bert is not None) and warmup_bert <= batch_idx:
            set_bert_grad(colbert, True)
            warmup_bert = None

        this_batch_loss = 0.0

        # Each element of `batch_steps` is one micro-batch; gradients are
        # accumulated over all of them before the optimizer step below.
        for batch in batch_steps:
            with amp.context():
                # Batches come either as (queries, passages, target_scores) or
                # as (encoding, target_scores). Only an unpacking failure should
                # select the second form, so catch just what unpacking raises.
                try:
                    query_batch, passage_batch, target_scores = batch
                except (ValueError, TypeError):
                    encoding, target_scores = batch
                    encoding = [encoding.to(DEVICE)]
                else:
                    encoding = [query_batch, passage_batch]

                scores = colbert(*encoding)

                if config.use_ib_negatives:
                    scores, ib_loss = scores

                scores = scores.view(-1, config.nway)

                if len(target_scores) and not config.ignore_scores:
                    # Distillation: KL divergence between the model's and the
                    # (scaled) target score distributions, both in log space.
                    target_scores = torch.tensor(target_scores).view(-1, config.nway).to(DEVICE)
                    target_scores = target_scores * config.distillation_alpha
                    target_scores = torch.nn.functional.log_softmax(target_scores, dim=-1)

                    log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
                    loss = kl_divergence(log_scores, target_scores)
                else:
                    loss = cross_entropy(scores, labels[:scores.size(0)])

                if config.use_ib_negatives:
                    if config.rank < 1:
                        print('\t\t\t\t', loss.item(), ib_loss.item())

                    loss += ib_loss

                # Scale so the accumulated gradient matches one full batch.
                loss = loss / config.accumsteps

            if config.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            this_batch_loss += loss.item()

        # Exponential moving average of the loss, seeded with the first step's loss.
        train_loss = this_batch_loss if train_loss is None else train_loss
        train_loss = train_loss_mu * train_loss + (1 - train_loss_mu) * this_batch_loss

        amp.step(colbert, optimizer, scheduler)

        if config.rank < 1:
            print_message(batch_idx, train_loss)
            manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None)

    if config.rank < 1:
        print_message("#> Done with all triples!")
        # NOTE(review): `batch_idx` is unbound here if the loop above never ran
        # (empty reader or config.maxsteps <= start_batch_idx) -- confirm the
        # desired behaviour for that case.
        ckpt_path = manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None, consumed_all_triples=True)

        return ckpt_path  # TODO: This should validate and return the best checkpoint, not just the last one.
def set_bert_grad(colbert, value):
    """Set ``requires_grad`` to ``value`` on every parameter of ``colbert.bert``.

    If ``colbert`` has no ``bert`` attribute (e.g. it is a wrapper such as
    ``DistributedDataParallel``), the call is retried on ``colbert.module``.

    Args:
        colbert: Model exposing ``.bert``, or a wrapper exposing ``.module``.
        value: New ``requires_grad`` flag; each parameter is expected to
            currently hold the opposite value.

    Raises:
        AssertionError: If a parameter's ``requires_grad`` is already ``value``.
        AttributeError: If neither ``.bert`` nor ``.module`` is available, or if
            iterating the parameters itself fails.
    """
    # Only the attribute lookup is guarded: an AttributeError raised while
    # walking the parameters is a genuine error and must not be mistaken for
    # "this is a wrapper", which would recurse into `.module`.
    try:
        bert = colbert.bert
    except AttributeError:
        set_bert_grad(colbert.module, value)
        return

    for param in bert.parameters():
        # Every toggle is expected to flip the flag; a no-op signals misuse.
        assert param.requires_grad is (not value)
        param.requires_grad = value