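"""Data utilities for supervised fine-tuning of chat models.

Defines a JSON Lines-backed ChatDataset, a padding collator, and a
ChatDataModule that bundles the two together.
"""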
import json
from dataclasses import dataclass
from typing import Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from tqdm import tqdm


class ChatDataset(Dataset):
    """Dataset of tokenized chat conversations read from a JSON Lines file."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int):
        super().__init__()
        data = []
        with open(data_path, "r") as file:
            for line in file:
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print("json processing exception", e)
                    continue
        data_dict = preprocess(data, tokenizer, conversation_template, max_tokens)
        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForChatDataset:
    """
    Collate examples for supervised fine-tuning.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Gather per-example input_ids and labels from the batch.
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        # -100 is the ignore_index of the cross-entropy loss used by transformers models.
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
        return dict(
            input_ids=input_ids,
            labels=labels,
            # Note: this also masks genuine occurrences of pad_token_id
            # (e.g. when eos doubles as the pad token).
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


class ChatDataModule:
    """Bundles the tokenized dataset with its collator."""

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template: str, max_tokens: int):
        self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens)
        self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer)


def preprocess(conversations: Sequence[Dict], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict:
    """
    Preprocess the data by applying the chat template and tokenizing.
    """
    all_input_ids = []
    tokenizer.use_default_system_prompt = False
    print("Tokenizing dataset...")
    for conv in tqdm(conversations):
        current_conv = conv["messages"]
        tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True)
        all_input_ids.append(torch.LongTensor(tokenized_conv))
    # Labels mirror input_ids, so the loss covers every token in the conversation
    # (prompt and response alike); padding positions are masked later by the collator.
    return dict(input_ids=all_input_ids, labels=all_input_ids)
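

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module API): the model
# name "gpt2", the file "train.jsonl", and the Jinja template below are
# placeholder assumptions showing how the pieces fit into a transformers
# Trainer run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # A bare-bones chat template; real chat models usually ship their own
    # (available as tokenizer.chat_template).
    SIMPLE_CHAT_TEMPLATE = (
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # Many causal LM tokenizers define no pad token; reuse eos for padding.
        tokenizer.pad_token = tokenizer.eos_token

    data_module = ChatDataModule(
        tokenizer=tokenizer,
        data_path="train.jsonl",  # assumed: one {"messages": [...]} object per line
        conversation_template=SIMPLE_CHAT_TEMPLATE,
        max_tokens=512,
    )

    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    trainer = transformers.Trainer(
        model=model,
        args=transformers.TrainingArguments(output_dir="out", per_device_train_batch_size=2),
        train_dataset=data_module.dataset,
        data_collator=data_module.data_collator,
    )
    trainer.train()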