import torch
import copy
import argparse
from dataclasses import dataclass

import transformers
import math
from torch.utils.data import Sampler
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, T5Tokenizer, T5Config, T5ForConditionalGeneration


class Collator(object):
    """Training collator: tokenizes full prompt + response texts and builds the label tensor."""

    def __init__(self, args, tokenizer):
        self.args = args
        self.only_train_response = args.only_train_response
        self.tokenizer = tokenizer
        # Some tokenizers (e.g. LlamaTokenizer) ship without a pad token; fall back to the unk token.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

    def __call__(self, batch):
        # Each sample carries the prompt text under "input_ids" and the full
        # prompt + response text under "labels"; the response is closed with EOS.
        input_texts = [d["input_ids"] for d in batch]
        full_texts = [d["labels"] + self.tokenizer.eos_token for d in batch]

        inputs = self.tokenizer(
            text=full_texts,
            text_target=input_texts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )
        labels = copy.deepcopy(inputs["input_ids"])
        if self.only_train_response:
            # Mask padding positions so they do not contribute to the loss.
            labels[labels == self.tokenizer.pad_token_id] = -100
            # Mask the prompt positions as well: the tokenized prompt comes back
            # under inputs["labels"] (it was passed as text_target), so every
            # non-pad position there corresponds to a prompt token.
            labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100

        inputs["labels"] = labels

        return inputs


class TestCollator(object):
    """Inference collator: tokenizes prompts only and returns the raw target strings."""

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0

        # Decoder-only models must be padded on the left so that generation
        # continues directly from the end of the prompt.
        if isinstance(self.tokenizer, LlamaTokenizer):
            self.tokenizer.padding_side = "left"

    def __call__(self, batch):
        input_texts = [d["input_ids"] for d in batch]
        targets = [d["labels"] for d in batch]
        inputs = self.tokenizer(
            text=input_texts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )

        return (inputs, targets)
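

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline. It assumes each
    # dataset sample is a dict with string fields "input_ids" (prompt only) and
    # "labels" (prompt + response), which is the format both collators expect.
    # The "t5-small" checkpoint and the example texts are placeholders; any
    # tokenizer with pad/eos tokens would work the same way.
    from torch.utils.data import DataLoader

    args = argparse.Namespace(only_train_response=True)
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    samples = [
        {"input_ids": "Recommend the next item: ", "labels": "Recommend the next item: item_42"},
        {"input_ids": "Recommend the next item: ", "labels": "Recommend the next item: item_7"},
    ]

    # Training batches: input_ids/attention_mask plus labels with the prompt and
    # padding positions masked to -100.
    train_loader = DataLoader(samples, batch_size=2, collate_fn=Collator(args, tokenizer))
    batch = next(iter(train_loader))
    print(batch["input_ids"].shape, batch["labels"].shape)

    # Inference batches: tokenized prompts plus the raw target strings.
    test_loader = DataLoader(samples, batch_size=2, collate_fn=TestCollator(args, tokenizer))
    inputs, targets = next(iter(test_loader))
    print(inputs["input_ids"].shape, targets)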