import os

import torch

from colbert.parameters import SAVED_CHECKPOINTS
from colbert.utils.runs import Run
from colbert.utils.utils import print_message, save_checkpoint


def print_progress(scores):
    """Print the mean of the first two score columns and their difference.

    Column 0 is reported as the "positive" average and column 1 as the
    "negative" average (naming taken from the local variables; the meaning
    of the columns is decided by the caller). Both means are rounded to two
    decimals before the difference is computed.

    Args:
        scores: An object indexable as ``scores[:, k]`` whose column slices
            support ``.mean().item()`` -- presumably a 2-D torch tensor with
            at least two columns; confirm against the caller.
    """
    positive_avg = round(scores[:, 0].mean().item(), 2)
    negative_avg = round(scores[:, 1].mean().item(), 2)
    print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', positive_avg - negative_avg)


def manage_checkpoints(args, colbert, optimizer, batch_idx):
    """Save training checkpoints for the current batch index when one is due.

    Two independent rules are applied:

    * every 2000th batch (including ``batch_idx == 0``) the rolling
      checkpoint ``colbert.dnn`` is written, replacing the previous one;
    * when ``batch_idx`` is listed in ``SAVED_CHECKPOINTS`` a separate
      ``colbert-<batch_idx>.dnn`` file is written.

    Both rules can fire for the same ``batch_idx``.

    Args:
        args: Object exposing ``input_arguments``; that object's ``__dict__``
            is passed along to ``save_checkpoint``.
        colbert: The model object handed to ``save_checkpoint``.
        optimizer: The optimizer object handed to ``save_checkpoint``.
        batch_idx: Integer index of the current training batch.
    """
    arguments = args.input_arguments.__dict__

    path = os.path.join(Run.path, 'checkpoints')

    # makedirs(exist_ok=True) replaces the former exists()/mkdir() pair: it
    # cannot fail when another process creates the directory between the
    # check and the creation, and it also creates missing parent directories.
    os.makedirs(path, exist_ok=True)

    # NOTE(review): the literal 0 is the second positional argument of
    # save_checkpoint -- presumably the epoch number; confirm in
    # colbert.utils.utils.
    if batch_idx % 2000 == 0:
        name = os.path.join(path, "colbert.dnn")
        save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)

    if batch_idx in SAVED_CHECKPOINTS:
        name = os.path.join(path, "colbert-{}.dnn".format(batch_idx))
        save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)