import math
import sys
import json
from typing import Iterable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.utils import accuracy, ModelEma
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from datasets import get_sentencepiece_model_for_beit3

import utils


class TaskHandler(object):
    """Base class for task-specific logic plugged into the generic engine.

    Subclasses implement `train_batch` (returning a dict that must contain
    "loss" and may contain extra scalars to log) and `eval_batch` (updating
    meters on `self.metric_logger`); `after_eval` returns a result dict plus
    the name of the key metric used for model selection.
    """

    def __init__(self) -> None:
        self.metric_logger = None
        self.split = None

    def train_batch(self, model, **kwargs):
        raise NotImplementedError()

    def eval_batch(self, model, **kwargs):
        raise NotImplementedError()

    def before_eval(self, metric_logger, data_loader, **kwargs):
        self.metric_logger = metric_logger
        self.split = data_loader.dataset.split

    def after_eval(self, **kwargs):
        raise NotImplementedError()


class NLVR2Handler(TaskHandler):
    """Binary classification over an image pair plus a sentence (NLVR2)."""

    def __init__(self) -> None:
        super().__init__()
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_batch(self, model, image, image2, language_tokens, padding_mask, label):
        logits = model(
            image_a=image, image_b=image2,
            text_description=language_tokens,
            padding_mask=padding_mask)
        acc = (logits.max(-1)[-1] == label).float().mean()
        return {
            "loss": self.criterion(input=logits, target=label),
            "acc": acc,
        }

    def eval_batch(self, model, image, image2, language_tokens, padding_mask, label):
        logits = model(
            image_a=image, image_b=image2,
            text_description=language_tokens,
            padding_mask=padding_mask)
        batch_size = language_tokens.shape[0]
        acc = (logits.max(-1)[-1] == label).float().sum(0) * 100.0 / batch_size
        self.metric_logger.meters['acc'].update(acc.item(), n=batch_size)

    def after_eval(self, **kwargs):
        print('* Acc {acc.global_avg:.3f}'.format(acc=self.metric_logger.acc))
        return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "acc"


class ImageNetHandler(TaskHandler):
    """Standard image classification with top-1/top-5 accuracy."""

    def __init__(self, args) -> None:
        super().__init__()
        mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
        if mixup_active:
            # Mixup/CutMix produce soft (probabilistic) targets, so label
            # smoothing is folded into the mixup transform rather than the loss.
            self.criterion = SoftTargetCrossEntropy()
        elif args.label_smoothing > 0.:
            self.criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
        else:
            self.criterion = torch.nn.CrossEntropyLoss()

    def train_batch(self, model, image, label):
        logits = model(image=image)
        return {
            "loss": self.criterion(logits, label),
        }

    def eval_batch(self, model, image, label):
        logits = model(image=image)
        batch_size = image.shape[0]
        acc1, acc5 = accuracy(logits, label, topk=(1, 5))
        self.metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        self.metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    def after_eval(self, **kwargs):
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
              .format(top1=self.metric_logger.acc1, top5=self.metric_logger.acc5))
        return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "acc1"


class RetrievalHandler(TaskHandler):
    """Image-text retrieval: accumulate features during eval, score at the end."""

    def __init__(self) -> None:
        super().__init__()
        self.image_feats = []
        self.text_feats = []
        self.image_ids = []
        self.metric_logger = None

    def train_batch(self, model, image, language_tokens, padding_mask, image_id):
        loss, vision_cls, language_cls = model(
            image=image, text_description=language_tokens, padding_mask=padding_mask)
        return {
            "loss": loss,
        }

    def before_eval(self, metric_logger, **kwargs):
        self.image_feats.clear()
        self.text_feats.clear()
        self.image_ids.clear()
        self.metric_logger = metric_logger

    def eval_batch(self, model, image, language_tokens, padding_mask, image_id):
        vision_cls, _ = model(image=image, only_infer=True)
        _, language_cls = model(
            text_description=language_tokens, padding_mask=padding_mask, only_infer=True)

        self.image_feats.append(vision_cls.clone())
        self.text_feats.append(language_cls.clone())
        self.image_ids.append(image_id.clone())

    def after_eval(self, **kwargs):
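        # The eval loader yields one (image, caption) pair per caption, so the
        # same image's features arrive multiple times; dedupe them by image id.
        # scores[i, j] is then the similarity of image i and text j, and
        # recall@K is read off the top-K indices in each direction.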
        image_feats = {}
        for feats, ids in zip(self.image_feats, self.image_ids):
            for i, _idx in enumerate(ids):
                idx = _idx.item()
                if idx not in image_feats:
                    image_feats[idx] = feats[i]

        tiids = torch.cat(self.image_ids, dim=0)
        iids = []
        sorted_tensors = []
        for key in sorted(image_feats.keys()):
            sorted_tensors.append(image_feats[key].view(1, -1))
            iids.append(key)

        image_cls_feats = torch.cat(sorted_tensors, dim=0)
        text_cls_feats = torch.cat(self.text_feats, dim=0)

        scores = image_cls_feats @ text_cls_feats.t()
        iids = torch.LongTensor(iids).to(scores.device)

        print("scores: {}".format(scores.size()))
        print("iids: {}".format(iids.size()))
        print("tiids: {}".format(tiids.size()))

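        # Text retrieval (tr_*): rank all texts for each image (top-k over
        # dim 1) and check whether a ground-truth caption lands in the top K.
        # Image retrieval (ir_*) mirrors this over dim 0: rank all images for
        # each text and look for the ground-truth image.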
        topk10 = scores.topk(10, dim=1)
        topk5 = scores.topk(5, dim=1)
        topk1 = scores.topk(1, dim=1)

        topk10_iids = tiids[topk10.indices]
        topk5_iids = tiids[topk5.indices]
        topk1_iids = tiids[topk1.indices]

        tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
        tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
        tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()

        topk10 = scores.topk(10, dim=0)
        topk5 = scores.topk(5, dim=0)
        topk1 = scores.topk(1, dim=0)
        topk10_iids = iids[topk10.indices]
        topk5_iids = iids[topk5.indices]
        topk1_iids = iids[topk1.indices]

        ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean()
        ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean()
        ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean()

        eval_result = {
            "tr_r10": tr_r10.item() * 100.0,
            "tr_r5": tr_r5.item() * 100.0,
            "tr_r1": tr_r1.item() * 100.0,
            "ir_r10": ir_r10.item() * 100.0,
            "ir_r5": ir_r5.item() * 100.0,
            "ir_r1": ir_r1.item() * 100.0,
            "average_score": 100.0 * (tr_r1 + tr_r5 + tr_r10 + ir_r1 + ir_r5 + ir_r10).item() / 6.0,
        }

        print('* Eval result = %s' % json.dumps(eval_result))
        return eval_result, "average_score"


class VQAHandler(TaskHandler):
    """VQAv2 as multi-label classification over the answer vocabulary."""

    def __init__(self) -> None:
        super().__init__()
        self.predictions = []
        self.criterion = nn.BCEWithLogitsLoss(reduction='mean')
        self.label2ans = None

    def train_batch(self, model, image, language_tokens, padding_mask, labels):
        logits = model(
            image=image, question=language_tokens,
            padding_mask=padding_mask)
        return {
            # BCE with reduction='mean' averages over both batch and answer
            # dimensions; multiplying by the number of answer classes restores
            # a per-sample loss scale.
            "loss": self.criterion(input=logits.float(), target=labels.float()) * labels.shape[1],
        }

    def before_eval(self, metric_logger, data_loader, **kwargs):
        self.predictions.clear()
        self.metric_logger = metric_logger
        self.label2ans = data_loader.dataset.label2ans

    def eval_batch(self, model, image, language_tokens, padding_mask, labels=None, qid=None):
        logits = model(
            image=image, question=language_tokens,
            padding_mask=padding_mask)
        batch_size = language_tokens.shape[0]
        if labels is not None:
            scores = utils.VQAScore()(logits, labels) * 100.0
            self.metric_logger.meters['score'].update(scores.item(), n=batch_size)
        else:
            # No labels (test split): record predicted answers for submission.
            _, preds = logits.max(-1)
            for image_id, pred in zip(qid, preds):
                self.predictions.append({
                    "question_id": image_id.item(),
                    "answer": self.label2ans[pred.item()],
                })

    def after_eval(self, **kwargs):
        if len(self.predictions) == 0:
            print('* Score {score.global_avg:.3f}'.format(score=self.metric_logger.score))
            return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "score"
        else:
            return self.predictions, "prediction"


class CaptioningHandler(TaskHandler):
    """Caption generation trained as masked language modeling over captions."""

    def __init__(self, args) -> None:
        super().__init__()
        self.predictions = []
        self.criterion = utils.BertCaptioningLoss(args.label_smoothing, args.drop_worst_ratio, args.drop_worst_after)
        self.tokenizer = get_sentencepiece_model_for_beit3(args)
        self.num_beams = args.num_beams
        self.max_len = args.num_max_bpe_tokens
        self.length_penalty = args.length_penalty
        self.vocab_size = args.vocab_size

    def train_batch(self, model, image, language_tokens, masked_tokens, language_masked_pos, padding_mask, image_id, global_step):
        logits, _ = model(
            image=image, text_ids=masked_tokens, padding_mask=padding_mask, language_masked_pos=language_masked_pos, image_id=image_id)
        # Loss and accuracy are measured only on the masked caption positions.
        masked_labels = language_tokens[language_masked_pos.bool()]
        score = torch.max(logits, -1)[1].data == masked_labels
        acc = torch.sum(score.float()) / torch.sum(language_masked_pos)
        return {
            "loss": self.criterion(logits, masked_labels, global_step),
            "acc": acc
        }

    def before_eval(self, metric_logger, data_loader, **kwargs):
        self.predictions.clear()
        self.metric_logger = metric_logger

    def eval_batch(self, model, image, image_id=None):
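        # Beam-search decoding. Sequences start as "[CLS] [MASK]"; each step
        # the model predicts the token at the [MASK] slot, the best candidates
        # per beam are kept, and a hypothesis is finalized when it emits [SEP]
        # or reaches max_len. Transformer states are cached in
        # `incremental_state` so earlier positions are not recomputed.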
        cur_len = 2
        num_keep_best = 1
        TOPN_PER_BEAM = 3

        batch_size = image.size(0)
        mask_id = self.tokenizer.mask_token_id
        cls_id = self.tokenizer.cls_token_id
        pad_id = self.tokenizer.pad_token_id
        sep_id = self.tokenizer.sep_token_id
        eos_token_ids = [sep_id]

        cls_ids = torch.full(
            (batch_size, 1), cls_id, dtype=torch.long, device=image.device
        )
        mask_ids = torch.full(
            (batch_size, 1), mask_id, dtype=torch.long, device=image.device
        )
        cur_input_ids = torch.cat([cls_ids, mask_ids], dim=1)
        tmp_ids = torch.full(
            (batch_size, self.max_len - 1), mask_id, dtype=torch.long, device=image.device
        )
        decoding_results = torch.cat([cls_ids, tmp_ids], dim=1)

        # Replicate inputs num_beams times so each beam is a row in the batch.
        cur_input_ids = cur_input_ids.unsqueeze(1).expand(batch_size, self.num_beams, cur_len)
        cur_input_ids = cur_input_ids.contiguous().view(batch_size * self.num_beams, cur_len)
        decoding_results = decoding_results.unsqueeze(1).expand(batch_size, self.num_beams, self.max_len)
        decoding_results = decoding_results.contiguous().view(batch_size * self.num_beams, self.max_len)
        image = image.unsqueeze(1).expand(batch_size, self.num_beams, image.size(-3), image.size(-2), image.size(-1))
        image = image.contiguous().view(batch_size * self.num_beams, image.size(-3), image.size(-2), image.size(-1))

        generated_hyps = [
            utils.BeamHypotheses(
                num_keep_best, self.max_len, length_penalty=self.length_penalty, early_stopping=False
            ) for _ in range(batch_size)
        ]

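        # Only beam 0 starts "live"; the rest get -1e9 so the first top-k step
        # does not pick the same token from num_beams identical beams.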
        beam_scores = torch.zeros((batch_size, self.num_beams), dtype=torch.float, device=cur_input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        done = [False for _ in range(batch_size)]
        incremental_state = {}

        while cur_len <= self.max_len:
            next_token_idx = 1
            padding_masks = torch.full(
                cur_input_ids.shape, 0, dtype=torch.long, device=image.device
            )
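            # The image is only fed on the first step; afterwards its encoder
            # states are reused from `incremental_state`.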
            input_image = image
            if cur_len != 2:
                input_image = None

            outputs, incremental_state_next = model(
                image=input_image, text_ids=cur_input_ids, language_masked_pos=None,
                padding_mask=padding_masks, text_len=cur_len, incremental_state=incremental_state)
            incremental_state = incremental_state_next

            scores = outputs[:, next_token_idx, :]
            scores = F.log_softmax(scores, dim=-1)
            assert scores.size() == (batch_size * self.num_beams, self.vocab_size)

            _scores = scores + beam_scores[:, None].expand_as(scores)
            _scores = _scores.view(batch_size, self.num_beams * self.vocab_size)
            next_scores, next_words = torch.topk(_scores, TOPN_PER_BEAM * self.num_beams, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (batch_size, TOPN_PER_BEAM * self.num_beams)

            next_batch_beam = []

            for batch_ex in range(batch_size):
                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
                if done[batch_ex]:
                    next_batch_beam.extend([(0, pad_id, 0)] * self.num_beams)
                    continue

                next_sent_beam = []
                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
                    # Flattened index -> (beam, token) within this example.
                    beam_id = idx // self.vocab_size
                    word_id = idx % self.vocab_size

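                    # A candidate that emits an end-of-sentence token (or is
                    # forced to stop at max_len) becomes a finished hypothesis;
                    # anything else stays on the beam.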
                    if (word_id.item() in eos_token_ids and cur_len + 1 <= self.max_len) or (cur_len + 1 == self.max_len):
                        generated_hyps[batch_ex].add(
                            decoding_results[batch_ex * self.num_beams + beam_id, :cur_len].clone(), score.item()
                        )
                    else:
                        next_sent_beam.append((score, word_id, batch_ex * self.num_beams + beam_id))

                    if len(next_sent_beam) == self.num_beams:
                        break

                if cur_len + 1 == self.max_len:
                    assert len(next_sent_beam) == 0
                else:
                    assert len(next_sent_beam) == self.num_beams

                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_id, 0)] * self.num_beams
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == self.num_beams * (batch_ex + 1)

            assert len(next_batch_beam) == batch_size * self.num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = cur_input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = cur_input_ids.new([x[2] for x in next_batch_beam])

            cur_input_ids = cur_input_ids[beam_idx, :]
            decoding_results = decoding_results[beam_idx, :]
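            # Reorder the cached states to follow the surviving beams; the last
            # cached position (the current [MASK] slot) is dropped because it
            # is re-encoded with the newly chosen token on the next step.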
            for module in incremental_state:
                for key in incremental_state[module]:
                    result = incremental_state[module][key].index_select(0, beam_idx)
                    incremental_state[module][key] = result[:, :, :-1, :]

            next_ids = torch.full(
                (batch_size * self.num_beams, 1), mask_id, dtype=torch.long, device=image.device
            )
            cur_input_ids = torch.cat([beam_words.unsqueeze(1), next_ids], dim=1)
            decoding_results[:, cur_len - 1] = beam_words

            cur_len = cur_len + 1

            if all(done):
                break

        # Gather the best finished hypothesis per example and pad to max_len.
        tgt_len = torch.ones(batch_size, num_keep_best, dtype=torch.long)
        logprobs = torch.zeros(batch_size, num_keep_best,
                               dtype=torch.float).fill_(-1e5).to(cur_input_ids.device)
        all_best = []

        for i, hypotheses in enumerate(generated_hyps):
            best = []
            hyp_scores = torch.tensor([x[0] for x in hypotheses.hyp])
            _, best_indices = torch.topk(hyp_scores,
                                         min(num_keep_best, len(hyp_scores)), largest=True)
            for best_idx, hyp_idx in enumerate(best_indices):
                conf, best_hyp = hypotheses.hyp[hyp_idx]
                best.append(best_hyp)
                logprobs[i, best_idx] = conf
                tgt_len[i, best_idx] = len(best_hyp) + 1

            all_best.append(best)

        decoded = cur_input_ids.new(batch_size, num_keep_best, self.max_len).fill_(pad_id)
        for batch_idx, best in enumerate(all_best):
            for best_idx, hypo in enumerate(best):
                decoded[batch_idx, best_idx, : tgt_len[batch_idx, best_idx] - 1] = hypo
                decoded[batch_idx, best_idx, tgt_len[batch_idx, best_idx] - 1] = eos_token_ids[0]

        captions = self.tokenizer.batch_decode(decoded.squeeze(1), skip_special_tokens=True)
        for qid, pred in zip(image_id, captions):
            self.predictions.append({
                "image_id": qid.item(),
                "caption": pred,
            })

    def after_eval(self, **kwargs):
        return self.predictions, "prediction"


def get_handler(args):
    if args.task == "nlvr2":
        return NLVR2Handler()
    elif args.task == "vqav2":
        return VQAHandler()
    elif args.task in ("flickr30k", "coco_retrieval"):
        return RetrievalHandler()
    elif args.task in ("coco_captioning", "nocaps"):
        return CaptioningHandler(args)
    elif args.task == "imagenet":
        return ImageNetHandler(args)
    else:
        raise NotImplementedError("Sorry, %s is not supported." % args.task)
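

# Typical wiring (a minimal sketch, not executable as-is; `model`, the data
# loaders, and the schedule values come from the accompanying finetuning
# script):
#
#   handler = get_handler(args)
#   train_stats = train_one_epoch(model, data_loader_train, optimizer, device,
#                                 handler, epoch, start_steps, lr_schedule_values,
#                                 loss_scaler)
#   test_stats, key_metric = evaluate(data_loader_val, model, device, handler)
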
def train_one_epoch(
        model: torch.nn.Module, data_loader: Iterable,
        optimizer: torch.optim.Optimizer, device: torch.device,
        handler: TaskHandler, epoch: int, start_steps: int,
        lr_schedule_values: list, loss_scaler, max_norm: float = 0,
        update_freq: int = 1, model_ema: Optional[ModelEma] = None,
        log_writer: Optional[utils.TensorboardLogger] = None,
        task=None, mixup_fn=None,
):
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

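    # `loss_scaler is None` means the model is wrapped by DeepSpeed, which owns
    # zero_grad/backward/step and fp16 casting; otherwise the standard
    # torch.cuda.amp GradScaler path is used.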
    if loss_scaler is None:
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()

    for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        global_step = start_steps + step  # global training iteration

        # Update the learning rate once per optimizer update, not per micro-batch.
        if lr_schedule_values is not None and data_iter_step % update_freq == 0:
            for i, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = lr_schedule_values[global_step] * param_group["lr_scale"]

        for tensor_key in data.keys():
            data[tensor_key] = data[tensor_key].to(device, non_blocking=True)

            if loss_scaler is None and tensor_key.startswith("image"):
                data[tensor_key] = data[tensor_key].half()

        if mixup_fn is not None:
            data["image"], data["label"] = mixup_fn(data["image"], data["label"])

        if task in ["coco_captioning", "nocaps"]:
            data["global_step"] = global_step

        if loss_scaler is None:
            results = handler.train_batch(model, **data)
        else:
            with torch.cuda.amp.autocast():
                results = handler.train_batch(model, **data)

        loss = results.pop("loss")
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

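        # Gradient accumulation: scale the loss by 1/update_freq and only step
        # the optimizer every update_freq micro-batches.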
        if loss_scaler is None:
            loss /= update_freq
            model.backward(loss)
            model.step()

            if (data_iter_step + 1) % update_freq == 0:
                if model_ema is not None:
                    model_ema.update(model)
            grad_norm = None
            loss_scale_value = utils.get_loss_scale_for_deepspeed(model)
        else:
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
            loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            kwargs = {
                "loss": loss_value,
            }
            for key in results:
                kwargs[key] = results[key]
            log_writer.update(head="train", **kwargs)

            kwargs = {
                "loss_scale": loss_scale_value,
                "lr": max_lr,
                "min_lr": min_lr,
                "weight_decay": weight_decay_value,
                "grad_norm": grad_norm,
            }
            log_writer.update(head="opt", **kwargs)
            log_writer.set_step()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, handler):
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    model.eval()
    handler.before_eval(metric_logger=metric_logger, data_loader=data_loader)

    for data in metric_logger.log_every(data_loader, 10, header):
        for tensor_key in data.keys():
            data[tensor_key] = data[tensor_key].to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            handler.eval_batch(model=model, **data)

    metric_logger.synchronize_between_processes()

    return handler.after_eval()