model111 / larm /data /utils /dynamic_padding.py
LCZZZZ's picture
Upload MemGen code and data
e34b94f verified
import torch
from transformers import PreTrainedTokenizerBase
from typing import Dict, List, Any
class DynamicPaddingDataCollater:
    """Collate completion-style samples into separately padded prompt/label tensors.

    Each feature is split by its ``completion_mask`` into prompt tokens
    (mask value falsy) and label/completion tokens (mask value truthy).
    Prompts are left-padded to the longest prompt in the batch; labels are
    right-padded to the longest label sequence. Both use the tokenizer's
    ``pad_token_id`` as the padding value, or 0 if the tokenizer has none.
    The accompanying attention masks mark real tokens with 1 and padding with 0.
    """

    def __init__(self, tokenizer: "PreTrainedTokenizerBase"):
        """Store the tokenizer and resolve the padding values.

        Args:
            tokenizer: Tokenizer whose ``pad_token_id`` attribute is read to
                choose the padding value.
        """
        self.tokenizer = tokenizer
        if tokenizer.pad_token_id is None:
            print("Warning: Tokenizer does not have a pad_token_id. Using 0 for input_ids and attention_mask padding.")
            self.padding_value_input = 0
        else:
            self.padding_value_input = tokenizer.pad_token_id
        # Padding value for labels. Reuse the resolved input padding value so the
        # 0 fallback applies here too; reading ``tokenizer.pad_token_id`` directly
        # would put ``None`` into the label lists and make ``torch.tensor`` fail.
        self.padding_value_label = self.padding_value_input

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Split, pad, and tensorize a batch of features.

        Args:
            features: Samples that each carry ``"input_ids"`` and
                ``"completion_mask"`` sequences. The two are paired
                element-wise with ``zip``, so if their lengths differ the
                shorter one decides how many tokens are used
                (NOTE(review): a length check may be worth adding).

        Returns:
            A dict with ``torch.long`` tensors ``prompt_ids`` and
            ``prompt_attention_mask`` of shape (batch, max_prompt_len),
            ``label_ids`` and ``label_attention_mask`` of shape
            (batch, max_label_len), plus ``raw_samples``: a list of the
            original feature dicts in batch order.

        Raises:
            ValueError: If ``features`` is empty.
        """
        # Materialize once so any iterable works and ``raw_samples`` is a fresh list.
        features = list(features)
        if not features:
            raise ValueError("DynamicPaddingDataCollater received an empty batch.")

        prompts: List[List[Any]] = []
        labels: List[List[Any]] = []
        for feature in features:
            prompt_ids: List[Any] = []
            label_ids: List[Any] = []
            for token, is_completion in zip(feature["input_ids"], feature["completion_mask"]):
                (label_ids if is_completion else prompt_ids).append(token)
            prompts.append(prompt_ids)
            labels.append(label_ids)

        max_prompt_len = max(map(len, prompts))
        max_label_len = max(map(len, labels))

        padded_prompt_ids: List[List[Any]] = []
        prompt_attention_mask: List[List[int]] = []
        padded_label_ids: List[List[Any]] = []
        label_attention_mask: List[List[int]] = []
        for prompt_ids, label_ids in zip(prompts, labels):
            # Prompts are left-padded, so real prompt tokens end at the right edge
            # (presumably so a model can continue directly from them — confirm).
            num_prompt_pads = max_prompt_len - len(prompt_ids)
            padded_prompt_ids.append([self.padding_value_input] * num_prompt_pads + prompt_ids)
            prompt_attention_mask.append([0] * num_prompt_pads + [1] * len(prompt_ids))

            # Labels are right-padded.
            num_label_pads = max_label_len - len(label_ids)
            padded_label_ids.append(label_ids + [self.padding_value_label] * num_label_pads)
            label_attention_mask.append([1] * len(label_ids) + [0] * num_label_pads)

        return {
            "prompt_ids": torch.tensor(padded_prompt_ids, dtype=torch.long),
            "prompt_attention_mask": torch.tensor(prompt_attention_mask, dtype=torch.long),
            "label_ids": torch.tensor(padded_label_ids, dtype=torch.long),
            "label_attention_mask": torch.tensor(label_attention_mask, dtype=torch.long),
            "raw_samples": features,
        }