# Source: duplicated from TencentARC/VLog (commit 6f6830f).
import torch
class LoadTextTokens(object):
    """Tokenize object descriptions into padded token-id / loss-mask tensors.

    Wraps a HuggingFace-style tokenizer. For each description it produces the
    token ids framed by a begin token and the tokenizer's sep token, the
    resulting length, and a 0/1 ``need_predict`` mask marking which positions
    contribute to the loss (everything except the begin token).
    """

    def __init__(self, tokenizer, max_text_len=40, padding='do_not_pad'):
        # tokenizer: callable returning {'input_ids': [...]} (HF-style),
        #            with a `sep_token_id` attribute.
        # max_text_len: total capacity INCLUDING the begin and sep tokens.
        # padding: padding strategy forwarded to the tokenizer.
        self.tokenizer = tokenizer
        self.max_text_len = max_text_len
        self.padding = padding

    def descriptions_to_text_tokens(self, target, begin_token):
        """Tokenize a single description string.

        Args:
            target: description text to tokenize.
            begin_token: id of the begin-of-sequence token to prepend.

        Returns:
            dict with
              'text_tokens':  1-D tensor [begin, tokens..., sep]
              'text_lengths': int, token count including begin/sep
              'need_predict': 1-D tensor, 0 for the begin token, 1 for every
                              other position (sep included, so the model
                              learns when to stop).
        """
        target_encoding = self.tokenizer(
            target, padding=self.padding,
            add_special_tokens=False,
            truncation=True, max_length=self.max_text_len)

        payload = target_encoding['input_ids']
        need_predict = [1] * len(payload)
        # Reserve 2 slots for the begin and sep tokens; keep the most recent
        # tokens when the description is too long.
        if len(payload) > self.max_text_len - 2:
            payload = payload[-(self.max_text_len - 2):]
            # BUG FIX: the original sliced `payload` (token ids) here, so the
            # mask contained raw token ids instead of 0/1 flags whenever
            # truncation kicked in. Slice the mask itself.
            need_predict = need_predict[-(self.max_text_len - 2):]

        input_ids = [begin_token] + payload + [self.tokenizer.sep_token_id]
        need_predict = [0] + need_predict + [1]
        return {
            'text_tokens': torch.tensor(input_ids),
            'text_lengths': len(input_ids),
            'need_predict': torch.tensor(need_predict),
        }

    def __call__(self, object_descriptions, box_features, begin_token):
        """Tokenize a batch of descriptions and pad to a common length.

        Args:
            object_descriptions: non-empty iterable of description strings.
            box_features: tensor whose ``.device`` the outputs are moved to
                (the feature values themselves are not used).
            begin_token: id of the begin-of-sequence token.

        Returns:
            dict of batched tensors on ``box_features.device``:
            'text_tokens' (B, L), 'text_lengths' (B,), 'need_predict' (B, L);
            shorter rows are zero-padded on the right.
        """
        text_tokens = []
        text_lengths = []
        need_predict = []
        for description in object_descriptions:
            tokens = self.descriptions_to_text_tokens(description, begin_token)
            text_tokens.append(tokens['text_tokens'])
            text_lengths.append(tokens['text_lengths'])
            need_predict.append(tokens['need_predict'])

        device = box_features.device
        text_tokens = torch.cat(self.collate(text_tokens), dim=0).to(device)
        text_lengths = torch.tensor(text_lengths).to(device)
        need_predict = torch.cat(self.collate(need_predict), dim=0).to(device)
        assert text_tokens.dim() == 2 and need_predict.dim() == 2
        return {'text_tokens': text_tokens,
                'text_lengths': text_lengths,
                'need_predict': need_predict}

    def collate(self, batch):
        """Zero-pad a non-empty list of same-rank tensors to a common shape.

        Each tensor gains a leading batch dimension of size 1; the result is a
        list ready for ``torch.cat`` along dim 0. Padding (zeros) is appended
        at the high end of every dimension. Works for tensors of any rank
        (the original supported only 1-D to 3-D).

        Raises:
            NotImplementedError: empty batch or non-tensor elements.
        """
        if not batch or not all(isinstance(b, torch.Tensor) for b in batch):
            raise NotImplementedError

        if all(b.shape == batch[0].shape for b in batch[1:]):
            # Uniform shapes: no padding required.
            return [b[None, ...] for b in batch]

        # Mixed shapes: ranks must agree so a per-dim max is well defined.
        assert all(len(b.shape) == len(batch[0].shape) for b in batch[1:])
        shapes = torch.tensor([b.shape for b in batch])
        max_shape = tuple(shapes.max(dim=0)[0].tolist())

        padded = []
        for b in batch:
            if any(c < m for c, m in zip(b.shape, max_shape)):
                b2 = torch.zeros(max_shape, dtype=b.dtype, device=b.device)
                # Copy b into the low corner of the zero tensor, rank-agnostic.
                b2[tuple(slice(0, s) for s in b.shape)] = b
                b = b2
            padded.append(b[None, ...])
        return padded