import torch

from transformers import PreTrainedTokenizerBase

from typing import Dict, List, Any


class DynamicPaddingDataCollater:
    """Collates prompt/completion examples, padding each batch to its longest sequence."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

        # Fall back to 0 when the tokenizer has no dedicated pad token.
        if tokenizer.pad_token_id is None:
            print("Warning: Tokenizer does not have a pad_token_id. Using 0 for input_ids and attention_mask padding.")
            self.padding_value_input = 0
        else:
            self.padding_value_input = tokenizer.pad_token_id

        # Reuse the same padding value for labels so it is never None;
        # padded label positions are identified by the label attention mask.
        self.padding_value_label = self.padding_value_input

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Split each example into prompt tokens and completion (label) tokens
        # using the per-token completion_mask.
        processed_features = []
        for feature in features:
            input_ids = feature["input_ids"]
            completion_mask = feature["completion_mask"]

            prompt_ids = [token for token, is_completion in zip(input_ids, completion_mask) if not is_completion]
            label_ids = [token for token, is_completion in zip(input_ids, completion_mask) if is_completion]

            processed_features.append({
                "prompt_ids": prompt_ids,
                "label_ids": label_ids,
                "original": feature,
            })

        # Pad dynamically to the longest prompt and longest label in this batch.
        max_prompt_len = max(len(f["prompt_ids"]) for f in processed_features)
        max_label_len = max(len(f["label_ids"]) for f in processed_features)

        padded_prompt_ids = []
        padded_input_attention_mask = []
        padded_label_ids = []
        padded_labels_attention_mask = []

        for feature in processed_features:
            prompt_ids = feature["prompt_ids"]
            label_ids = feature["label_ids"]

            # Left-pad prompts so the real tokens sit at the end of the sequence.
            num_input_pads = max_prompt_len - len(prompt_ids)
            padded_prompt_ids.append([self.padding_value_input] * num_input_pads + prompt_ids)

            input_attention_mask = [1] * len(prompt_ids)
            padded_input_attention_mask.append([0] * num_input_pads + input_attention_mask)

            # Right-pad labels and mask out the padded positions.
            num_label_pads = max_label_len - len(label_ids)
            padded_label_ids.append(label_ids + [self.padding_value_label] * num_label_pads)

            labels_attention_mask = [1] * len(label_ids)
            padded_labels_attention_mask.append(labels_attention_mask + [0] * num_label_pads)

        batch = {
            "prompt_ids": torch.tensor(padded_prompt_ids, dtype=torch.long),
            "prompt_attention_mask": torch.tensor(padded_input_attention_mask, dtype=torch.long),
            "label_ids": torch.tensor(padded_label_ids, dtype=torch.long),
            "label_attention_mask": torch.tensor(padded_labels_attention_mask, dtype=torch.long),
        }

        # Keep the original, unpadded samples alongside the tensors for downstream use.
        batch["raw_samples"] = [f["original"] for f in processed_features]

        return batch
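

# Minimal usage sketch. Assumptions not taken from the module above: the "gpt2"
# tokenizer and the toy feature dicts are illustrative only; the collator just needs
# features that provide "input_ids" and a per-token "completion_mask".
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # GPT-2 has no pad token, so the 0 fallback applies
    collator = DynamicPaddingDataCollater(tokenizer)

    # completion_mask marks completion tokens with 1 and prompt tokens with 0.
    features = [
        {"input_ids": [10, 11, 12, 13], "completion_mask": [0, 0, 1, 1]},
        {"input_ids": [20, 21, 22, 23, 24, 25], "completion_mask": [0, 0, 0, 1, 1, 1]},
    ]

    batch = collator(features)
    print(batch["prompt_ids"].shape)       # torch.Size([2, 3]) -- longest prompt has 3 tokens
    print(batch["label_ids"].shape)        # torch.Size([2, 3]) -- longest completion has 3 tokens
    print(batch["prompt_attention_mask"])  # left-padded: zeros appear before the ones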