# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Sequence

import torch
from transformers import DataCollatorForSeq2Seq


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature["{}_input_ids".format(key)],
                    "attention_mask": feature["{}_attention_mask".format(key)],
                    "labels": feature["{}_labels".format(key)],
                }
                if "pixel_values" in feature:  # image inputs are shared by the chosen and rejected copies
                    target_feature["pixel_values"] = feature["pixel_values"]

                if "{}_token_type_ids".format(key) in feature:
                    target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]

                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)
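

# A minimal usage sketch (illustration only, not part of the original module).
# It assumes a pad-capable Hugging Face tokenizer such as "gpt2"; the toy
# feature values below are hypothetical. For n input features, the collator
# returns a padded batch of 2 * n rows, with the chosen rows first:
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default
#     collator = PairwiseDataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
#     batch = collator([
#         {
#             "chosen_input_ids": [1, 2, 3],
#             "chosen_attention_mask": [1, 1, 1],
#             "chosen_labels": [-100, 2, 3],
#             "rejected_input_ids": [1, 2],
#             "rejected_attention_mask": [1, 1],
#             "rejected_labels": [-100, 2],
#         }
#     ])
#     # batch["input_ids"] has shape (2, 3): row 0 is chosen, row 1 is rejected (padded).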


@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        Each input feature carries a target example and a KL example; both are padded
        separately and merged into a single batch along with the boolean KTO tags.
        """
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
            }
            if "pixel_values" in feature:
                target_feature["pixel_values"] = feature["pixel_values"]

            if "token_type_ids" in feature:
                target_feature["token_type_ids"] = feature["token_type_ids"]
                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]

            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
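

if __name__ == "__main__":
    # A minimal smoke test, sketched for illustration and not part of the
    # original module. It assumes the "gpt2" tokenizer is available and uses
    # hypothetical toy features; the target and KL sequences are padded
    # independently before being merged into one batch.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default
    collator = KTODataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
    batch = collator(
        [
            {
                "input_ids": [1, 2, 3],
                "attention_mask": [1, 1, 1],
                "labels": [-100, 2, 3],
                "kl_input_ids": [4, 5],
                "kl_attention_mask": [1, 1],
                "kl_labels": [-100, 5],
                "kto_tags": True,
            }
        ]
    )
    # Expected: input_ids of shape (1, 3), kl_input_ids of shape (1, 2),
    # and kto_tags == tensor([True]).
    print({key: value.shape for key, value in batch.items()})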