| | |
| | |
| |
|
| | import torch |
| |
|
| | from internlm.core.context import global_context as gpc |
| |
|
# Dataset-type name -> integer type id. get_dataset_type_id() searches a data path
# for each key as a directory component of the form "/<zero or more 'z'/'_'><key>/".
DATASET_TYPE_IDS_MAP = {
    "vision": 0,
}
| |
|
| |
|
def get_dataset_type_id(path, *, type_ids_map=None):
    """Return the integer dataset-type id encoded in a data path.

    A type name ``key`` matches when ``path`` contains a directory component made of
    zero or more ``z``/``_`` characters followed by ``key``, i.e. ``/<[z_]*><key>/``
    (so both ``/vision/`` and ``/z_vision/`` match the key ``"vision"``).

    Args:
        path: Path string to classify.
        type_ids_map: Mapping from type name to id. Defaults to ``DATASET_TYPE_IDS_MAP``.

    Returns:
        The id of the single type name that matches ``path``.

    Raises:
        AssertionError: If no type name, or more than one, matches ``path``.
    """
    import re

    if type_ids_map is None:
        type_ids_map = DATASET_TYPE_IDS_MAP

    # re.escape: a type name is always matched literally, never as a regex pattern.
    match_idxes = [idx for key, idx in type_ids_map.items() if re.search(rf"/[z_]*{re.escape(key)}/", path)]
    # Explicit raise instead of ``assert`` so the check still runs under ``python -O``;
    # AssertionError is kept so callers see the same exception type as before.
    if len(match_idxes) != 1:
        raise AssertionError(f"{path}, match_idxes should be 1, but got {match_idxes} from {type_ids_map}")
    return match_idxes[0]
| |
|
| |
|
def unpack_data(input_ids, cu_seqlens, *, num_sequence=None, seq_len=None):
    """Unpack packed token rows into a right-zero-padded per-sequence tensor.

    Each row ``i`` of ``input_ids`` holds several sequences concatenated back to back;
    ``cu_seqlens[i]`` holds their cumulative boundaries, so sequence ``j`` of row ``i``
    is ``input_ids[i, cu_seqlens[i][j]:cu_seqlens[i][j + 1]]``.

    Args:
        input_ids: Tensor of shape ``(bsz, packed_length)``.
        cu_seqlens: Indexable by row; ``cu_seqlens[i]`` must provide at least
            ``num_sequence + 1`` boundaries (e.g. a 2-D tensor, or a list of 1-D
            tensors / int lists).
        num_sequence: Number of sequences to extract from each row.
            Defaults to ``gpc.config.data["micro_bsz"]``.
        seq_len: Padded length of each output sequence.
            Defaults to ``gpc.config.data.seq_len``.

    Returns:
        Tensor of shape ``(bsz, num_sequence, seq_len)`` on the same device and with
        the same dtype as ``input_ids``, zero-padded on the right. When ``bsz == 1``
        the batch dimension is squeezed away, giving ``(num_sequence, seq_len)``.
    """
    if num_sequence is None:
        num_sequence = gpc.config.data["micro_bsz"]
    if seq_len is None:
        seq_len = gpc.config.data.seq_len

    bsz = input_ids.shape[0]
    outputs = torch.zeros(bsz, num_sequence, seq_len, device=input_ids.device, dtype=input_ids.dtype)

    for i in range(bsz):
        bounds = cu_seqlens[i]
        for j in range(num_sequence):
            start, end = int(bounds[j]), int(bounds[j + 1])
            # Read from row i (previously row 0) so every batch row unpacks its own tokens.
            outputs[i, j, : end - start] = input_ids[i, start:end]

    if bsz == 1:
        outputs = outputs.squeeze(0)

    return outputs
| |
|