from typing import Dict, Sequence

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def internvl_collate_fn(instances: Sequence[Dict],
                        pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                        return_hf_format: bool = False,
                        use_varlen_attn: bool = False):
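    """Collate a list of InternVL samples into a single training batch.

    Pads ``input_ids``/``labels`` to the longest sample, builds the matching
    ``attention_mask`` and ``position_ids``, optionally prepares inputs for
    varlen attention and sequence parallelism, and collects ``pixel_values``
    when images are present. Returns a plain dict if ``return_hf_format`` is
    True, otherwise the ``{'data': ..., 'data_samples': None}`` wrapper.
    """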
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(
        inst.get('pixel_values') is not None for inst in instances)

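    # Varlen attention packs all samples into a single row, so it requires
    # batch size 1 and is currently restricted to text-only data.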
    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, (
            'Currently, it is not configured to accommodate the use of '
            'varlen attention in multimodal training')

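    # Per-sample image tensors plus the number of frames/tiles each sample
    # contributes (assumed here to be the leading dim of ``pixel_values``).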
    if has_image:
        pixel_values = []
        frames_per_batch = []

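    # Collect per-sample tensors; padding and stacking happen after the loop.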
    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
            # Assumed: the leading dim of ``pixel_values`` is the number of
            # frames/tiles this sample contributes.
            frames_per_batch.append(example['pixel_values'].shape[0])

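    # Right-pad to the longest sample; labels are padded with IGNORE_INDEX so
    # the padded positions are excluded from the loss.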
    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

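    # With varlen attention no mask is needed (``cumulative_len`` marks the
    # packed sequence boundaries); otherwise mask out the padded tail.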
    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

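    # Sequence parallelism splits the sequence dimension across ranks, so pad
    # every per-token tensor until its length is divisible by the world size.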
    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

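    # Assemble the model inputs; for varlen attention, ``max_seqlen`` is the
    # length of the longest packed sub-sequence.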
    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] - cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

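    # Stack ``pixel_values`` only when every sample has the same shape (e.g.
    # the same number of frames/tiles); otherwise keep the per-sample list.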
    if has_image:
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
            data_dict['frames_per_batch'] = frames_per_batch
        data_dict['pixel_values'] = pixel_values

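    # ``return_hf_format`` yields a plain dict for HuggingFace-style models;
    # otherwise wrap it in the mmengine-style structure xtuner models expect.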
    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': None}