import torch


class LoadTextTokens(object):
    """Tokenize object descriptions into padded token-id / mask tensors.

    Wraps a HuggingFace-style tokenizer (callable returning a dict with
    'input_ids', and exposing `sep_token_id`) and produces per-box text
    tensors suitable for a captioning head.
    """

    def __init__(self, tokenizer, max_text_len=40, padding='do_not_pad'):
        self.tokenizer = tokenizer
        self.max_text_len = max_text_len
        self.padding = padding

    def descriptions_to_text_tokens(self, target, begin_token):
        """Encode one description string.

        Args:
            target: description text to tokenize.
            begin_token: token id prepended to the sequence (e.g. BOS/task token).

        Returns:
            dict with:
              'text_tokens'  — 1-D tensor [begin_token, payload..., sep_token_id]
              'text_lengths' — int length of that sequence
              'need_predict' — 1-D tensor, 0 for the begin token, 1 elsewhere
                               (including the trailing SEP).
        """
        target_encoding = self.tokenizer(
            target,
            padding=self.padding,
            add_special_tokens=False,
            truncation=True,
            max_length=self.max_text_len,
        )

        payload = target_encoding['input_ids']
        need_predict = [1] * len(payload)
        # Reserve two slots: one for begin_token, one for the trailing SEP.
        if len(payload) > self.max_text_len - 2:
            payload = payload[-(self.max_text_len - 2):]
            # BUGFIX: truncate the 0/1 mask itself. The original code did
            # `need_predict = payload[...]`, which filled the mask with
            # token ids instead of 1s whenever truncation occurred.
            need_predict = need_predict[-(self.max_text_len - 2):]

        input_ids = [begin_token] + payload + [self.tokenizer.sep_token_id]
        need_predict = [0] + need_predict + [1]

        return {
            'text_tokens': torch.tensor(input_ids),
            'text_lengths': len(input_ids),
            'need_predict': torch.tensor(need_predict),
        }

    def __call__(self, object_descriptions, box_features, begin_token):
        """Tokenize a batch of descriptions, zero-padding to a common length.

        Args:
            object_descriptions: iterable of description strings.
            box_features: tensor whose `.device` determines output placement
                          (its values are not read).
            begin_token: token id prepended to every sequence.

        Returns:
            dict with 2-D 'text_tokens' and 'need_predict' (batch x max_len)
            and 1-D 'text_lengths' (true length of each sequence).
        """
        text_tokens = []
        text_lengths = []
        need_predict = []
        for description in object_descriptions:
            tokens = self.descriptions_to_text_tokens(description, begin_token)
            text_tokens.append(tokens['text_tokens'])
            text_lengths.append(tokens['text_lengths'])
            need_predict.append(tokens['need_predict'])

        device = box_features.device
        text_tokens = torch.cat(self.collate(text_tokens), dim=0).to(device)
        text_lengths = torch.tensor(text_lengths).to(device)
        need_predict = torch.cat(self.collate(need_predict), dim=0).to(device)

        assert text_tokens.dim() == 2 and need_predict.dim() == 2
        return {
            'text_tokens': text_tokens,
            'text_lengths': text_lengths,
            'need_predict': need_predict,
        }

    def collate(self, batch):
        """Zero-pad a non-empty list of tensors to a shared shape.

        Each tensor gains a leading batch dim of 1 so the results can be
        concatenated with torch.cat along dim 0. Supports 1-D to 3-D
        tensors of matching rank; raises NotImplementedError otherwise.
        """
        if not (len(batch) > 0 and all(isinstance(b, torch.Tensor) for b in batch)):
            raise NotImplementedError

        if all(b.shape == batch[0].shape for b in batch[1:]):
            # Fast path: shapes already agree, just add the batch dim.
            return [b[None, ...] for b in batch]

        # Ranks must match to pad; compute the element-wise max shape.
        assert all(len(b.shape) == len(batch[0].shape) for b in batch[1:])
        shape = torch.tensor([b.shape for b in batch])
        max_shape = tuple(shape.max(dim=0)[0].tolist())
        padded = []
        for b in batch:
            if any(c < m for c, m in zip(b.shape, max_shape)):
                b2 = torch.zeros(max_shape, dtype=b.dtype, device=b.device)
                if b.dim() == 1:
                    b2[:b.shape[0]] = b
                elif b.dim() == 2:
                    b2[:b.shape[0], :b.shape[1]] = b
                elif b.dim() == 3:
                    b2[:b.shape[0], :b.shape[1], :b.shape[2]] = b
                else:
                    raise NotImplementedError
                b = b2
            padded.append(b[None, ...])
        return padded