Spaces:
Runtime error
Runtime error
File size: 1,261 Bytes
75ba0e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import string
import math
import torch
from data import data_utils
def get_symbols_to_strip_from_output(generator):
    """Return the symbols that should be dropped when decoding output.

    If ``generator`` exposes a ``symbols_to_strip_from_output`` attribute,
    that value is returned unchanged; otherwise a set holding
    ``generator.bos`` and ``generator.eos`` is returned.
    """
    # Guard clause: only touch bos/eos when the dedicated attribute is absent.
    if not hasattr(generator, "symbols_to_strip_from_output"):
        return {generator.bos, generator.eos}
    return generator.symbols_to_strip_from_output
def decode_fn(x, tgt_dict, bpe, generator, tokenizer=None):
    """Turn generated token ids into a decoded string.

    ``x`` is converted with ``.int().cpu()`` (presumably a torch tensor of
    token ids -- confirm against callers) and rendered by ``tgt_dict.string``,
    ignoring the symbols reported by ``get_symbols_to_strip_from_output``.
    The result is then passed through ``bpe.decode`` and ``tokenizer.decode``,
    in that order, skipping whichever of the two is ``None``.
    """
    token_ids = x.int().cpu()
    ignored = get_symbols_to_strip_from_output(generator)
    text = tgt_dict.string(token_ids, extra_symbols_to_ignore=ignored)
    # Apply the optional post-processors in a fixed order: BPE, then tokenizer.
    for stage in (bpe, tokenizer):
        if stage is not None:
            text = stage.decode(text)
    return text
def eval_caption(task, generator, models, sample):
    """Run inference on ``sample`` and build one caption record per sample id.

    Returns a ``(results, None)`` tuple. ``results`` is a list of dicts with
    keys ``"image_id"`` (the sample id as a string) and ``"caption"`` (the
    decoded first hypothesis with ASCII punctuation removed and surrounding
    whitespace stripped). The second element is always ``None``.
    """
    hypos = task.inference_step(generator, models, sample)
    # Translation table that deletes every character in string.punctuation.
    strip_punct = str.maketrans("", "", string.punctuation)

    def _caption(index):
        # Only the first hypothesis for each sample is decoded.
        text = decode_fn(hypos[index][0]["tokens"], task.tgt_dict, task.bpe, generator)
        return text.translate(strip_punct).strip()

    results = [
        {"image_id": str(sample_id), "caption": _caption(index)}
        for index, sample_id in enumerate(sample["id"].tolist())
    ]
    return results, None
def eval_step(task, generator, models, sample):
    """Dispatch evaluation according to ``task.cfg._name``.

    Only the ``'caption'`` task is handled, by delegating to ``eval_caption``
    and returning its result.

    Raises:
        NotImplementedError: if ``task.cfg._name`` is anything other than
            ``'caption'``.
    """
    if task.cfg._name != 'caption':
        raise NotImplementedError
    return eval_caption(task, generator, models, sample)
|