# Ahma-7B / EasyLM / data.py
import dataclasses
import pprint
import time
from functools import partial
import json
import base64
from multiprocessing import Pool
import h5py
import mlxu
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
from tqdm import tqdm, trange
import numpy as np
from datasets import load_dataset, load_from_disk
class DatasetFactory(object):
""" Datset builder class. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.type = 'huggingface'
config.text_processor = TextProcessor.get_default_config()
config.huggingface_dataset = HuggingfaceDataset.get_default_config()
config.json_dataset = JsonDataset.get_default_config()
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
@classmethod
def load_dataset(cls, config, tokenizer, **kwargs):
config = cls.get_default_config(config)
text_processor = TextProcessor(config.text_processor, tokenizer)
if config.type == 'huggingface':
return HuggingfaceDataset(
config.huggingface_dataset, tokenizer, text_processor, **kwargs
)
elif config.type == 'json':
return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
else:
raise ValueError(f'Unknown dataset type: {config.type}')
def __init__(self):
raise ValueError('DatasetFactory is a static class and should not be instantiated.')
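
# Hedged usage sketch (added comment; the tokenizer checkpoint and data path
# are hypothetical). DatasetFactory.load_dataset() merges the given updates
# into the default config and returns either a HuggingfaceDataset or a
# JsonDataset:
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('gpt2')  # any HF tokenizer works here
#   dataset = DatasetFactory.load_dataset(
#       {'type': 'json',
#        'text_processor': {'fields': 'text'},
#        'json_dataset': {'path': '/path/to/data.jsonl'}},
#       tokenizer,
#   )
#   for batch, metrics in dataset:
#       break  # each batch holds input_tokens, target_tokens and loss_masks
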
class TextProcessor(object):
""" Example processor that converts a dictionary of texts into tokens. """
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.fields_from_example = ''
config.fields = ''
config.subfield_separator = ' '
config.add_bos_token = True
config.add_eos_token = True
config.prepend_text = ''
config.base64_token_dtype = 'i4'
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer):
self.config = self.get_default_config(config)
assert self.config.fields != '' or self.config.fields_from_example != '', (
'Either fields or fields_from_example must be specified.'
)
self.tokenizer = tokenizer
def __call__(self, example, has_aux=False):
if has_aux:
example, *aux = example
else:
aux = tuple()
token_buffer = []
loss_mask_buffer = []
if self.config.add_bos_token:
token_buffer.append(self.tokenizer.bos_token_id)
loss_mask_buffer.append(0.0)
if self.config.fields_from_example != '':
fields = example[self.config.fields_from_example].split(',')
else:
fields = self.config.fields.split(',')
for i, field in enumerate(fields):
if field.startswith('[') and field.endswith(']'):
# No loss for this field.
field = field[1:-1]
mask = 0.0
else:
mask = 1.0
if field.startswith('<|') and field.endswith('|>'):
# Special tokens.
field = field[2:-2]
if field == 'bos':
token_buffer.append(self.tokenizer.bos_token_id)
elif field == 'eos':
token_buffer.append(self.tokenizer.eos_token_id)
else:
# Token ID specified directly.
token_buffer.append(int(field))
loss_mask_buffer.append(mask)
elif field.startswith('{') and field.endswith('}'):
field = field[1:-1]
# Base64 encoded raw tokens.
tokens = np.frombuffer(
base64.b64decode(example[field]),
dtype=self.config.base64_token_dtype
).tolist()
token_buffer.extend(tokens)
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
else:
subfields = field.split('+')
text = self.config.subfield_separator.join(
[example[subfield] for subfield in subfields]
)
if i == 0:
text = self.config.prepend_text + text
tokens = self.tokenizer.encode(text)
token_buffer.extend(tokens)
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
if self.config.add_eos_token:
token_buffer.append(self.tokenizer.eos_token_id)
loss_mask_buffer.append(1.0)
return token_buffer, loss_mask_buffer, *aux
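
# Hedged illustration (added comment). With config.fields = '[prompt],completion',
# the bracketed field is tokenized with loss mask 0.0 and the completion field
# with loss mask 1.0, wrapped in BOS/EOS when add_bos_token/add_eos_token are
# set. A minimal sketch, assuming `tokenizer` is any HuggingFace tokenizer:
#
#   processor = TextProcessor({'fields': '[prompt],completion'}, tokenizer)
#   tokens, loss_masks = processor({'prompt': 'Q: 2+2?', 'completion': ' 4'})
#   assert len(tokens) == len(loss_masks)
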
class HuggingfaceDataset(object):
""" Huggingface dataset, where the dataset is loaded using the huggingface
datasets.load_dataset() function.
"""
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.path = 'c4'
config.name = 'en'
config.split = 'train'
config.streaming = False
config.seq_length = 1024
config.batch_size = 8
config.always_start_with_bos = False
config.start_seek_loc = 0
config.tokens_count_at_start = 0
config.batch_token_dtype = 'i4'
config.reset_dataset_loc = False
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
self.config = self.get_default_config(config)
name = self.config.name if self.config.name != '' else None
split = self.config.split if self.config.split != '' else None
self._tokenizer = tokenizer
self._text_processor = text_processor
self._dataset = load_from_disk(
self.config.path
)[split]
        self._dataset = self._dataset.to_iterable_dataset(
            num_shards=min(len(self._dataset), 128)
        )
self._eval_dataset = eval_dataset
self._train_epochs = 0
self._dataset_loc = self.config.start_seek_loc
self._total_tokens = self.config.tokens_count_at_start
self._index = 0
self.reset_dataset_loc = self.config.reset_dataset_loc
def __iter__(self):
if not self._eval_dataset and self._train_epochs > 0:
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
chunk_size = self.config.batch_size * self.config.seq_length
while True:
token_buffer = []
loss_mask_buffer = []
if not self._eval_dataset and self._train_epochs > 0:
self._dataset.set_epoch(self._train_epochs)
for index, example in enumerate(self._dataset):
self._index = index
if not self._eval_dataset and self._dataset_loc > index:
continue
tokens, loss_masks = self.text_processor(example)
token_buffer.extend(tokens)
loss_mask_buffer.extend(loss_masks)
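                # Cut fixed-size chunks: targets are inputs shifted right by one
                # token, so the buffer must hold chunk_size + 1 tokens before a
                # batch can be emitted.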
while len(token_buffer) > chunk_size + 1:
self._total_tokens += chunk_size
metrics = {
'dataset_example_index': index,
'dataset_total_tokens': self._total_tokens,
'epoch': self._train_epochs,
}
batch = {
'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
self.config.batch_size, -1
),
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
self.config.batch_size, -1
),
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
self.config.batch_size, -1
),
}
if self.config.always_start_with_bos:
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
yield batch, metrics
token_buffer = token_buffer[chunk_size:]
loss_mask_buffer = loss_mask_buffer[chunk_size:]
if self._eval_dataset:
break
else:
if self._train_epochs == 0:
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
self._dataset_loc = 0
self._train_epochs += 1
def get_state_dict(self):
return dict(
config=self.config,
dataset_loc=self._index,
total_tokens=self._total_tokens,
epochs=self._train_epochs,
)
def load_state_dict(self, state_dict):
if 'config' in state_dict:
self.config.update(ConfigDict(state_dict['config']))
self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
self._train_epochs = state_dict.get('epochs', 0)
if self.reset_dataset_loc:
self._dataset_loc = 0
self._train_epochs = 0
@property
def seq_length(self):
return self.config.seq_length
@property
def tokenizer(self):
return self._tokenizer
@property
def text_processor(self):
return self._text_processor
@property
def dataset(self):
return self._dataset
@property
def vocab_size(self):
return len(self._tokenizer)
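
# Hedged usage sketch (added comment; the dataset path is hypothetical). The
# configured path is expected to point at a DatasetDict saved with the
# datasets library's save_to_disk() that contains the configured split. Each
# iteration yields a (batch, metrics) pair whose arrays have shape
# (batch_size, seq_length):
#
#   text_processor = TextProcessor({'fields': 'text'}, tokenizer)
#   dataset = HuggingfaceDataset(
#       {'path': '/path/to/saved_dataset', 'split': 'train',
#        'seq_length': 1024, 'batch_size': 8},
#       tokenizer, text_processor,
#   )
#   batch, metrics = next(iter(dataset))
#   assert batch['input_tokens'].shape == (8, 1024)
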
class JsonDataset(object):
""" JSON dataset, where each line of the data file contains a JSON
dictionary with text fields.
"""
@staticmethod
def get_default_config(updates=None):
config = ConfigDict()
config.path = ''
config.seq_length = 1024
config.batch_size = 8
config.always_start_with_bos = False
config.start_seek_loc = 0
config.example_index_at_start = 0
config.tokens_count_at_start = 0
config.tokenizer_processes = 1
config.tokenizer_parallel_chunk_size = 32
config.tokenizer_parallel_batch_size = 1024
config.throughput_average_window_size = 200
if updates is not None:
config.update(ConfigDict(updates).copy_and_resolve_references())
return config
def __init__(self, config, tokenizer, text_processor):
self.config = self.get_default_config(config)
assert self.config.path != ''
self._tokenizer = tokenizer
self._text_processor = text_processor
self._index = self.config.example_index_at_start
self._file_loc = self.config.start_seek_loc
self._total_tokens = self.config.tokens_count_at_start
def parse_json(self, line):
if not line or line == '\n':
return None
try:
data = json.loads(line)
except json.decoder.JSONDecodeError:
print(f'Error parsing json line:\n{line}')
return None
return data
def json_iterator(self):
with mlxu.open_file(self.config.path, 'r') as fin:
fin.seek(self._file_loc)
while True:
line = fin.readline()
self._file_loc = fin.tell()
if not line: # Reached EOF
self._index = 0
fin.seek(0)
continue
data = self.parse_json(line)
if data is not None:
# JSON parsing succeeded
yield data, self._file_loc, self._index
self._index += 1
def batched(self, iterator, batch_size):
batch = []
for example in iterator:
batch.append(example)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
def parallel_example_iterator(self):
if self.config.tokenizer_processes == 1:
for example, loc, index in self.json_iterator():
yield self.text_processor((example, loc, index), has_aux=True)
else:
process_pool = Pool(self.config.tokenizer_processes)
batched_iterator = self.batched(
self.json_iterator(), self.config.tokenizer_parallel_batch_size
)
with process_pool as pool:
map_fn = partial(self.text_processor, has_aux=True)
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
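                # Double-buffering: the next batch is dispatched to the worker
                # pool before the current one is drained, overlapping
                # tokenization with consumption by the caller.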
while True:
current_batch = next_batch
next_batch = pool.map_async(
map_fn, next(batched_iterator),
chunksize=self.config.tokenizer_parallel_chunk_size
)
for example in current_batch.get():
yield example
def __iter__(self):
chunk_size = self.config.batch_size * self.config.seq_length
token_buffer = []
loss_mask_buffer = []
        last_time = time.time()
step_times = []
start_time = time.time()
start_tokens = self._total_tokens
for tokens, loss_masks, loc, index in self.parallel_example_iterator():
token_buffer.extend(tokens)
loss_mask_buffer.extend(loss_masks)
while len(token_buffer) > chunk_size + 1:
self._total_tokens += chunk_size
step_times.append(time.time() - last_time)
last_time = time.time()
if len(step_times) > self.config.throughput_average_window_size:
step_times = step_times[-self.config.throughput_average_window_size:]
average_throughput = chunk_size / np.mean(step_times)
accumulated_throughput = (
(self._total_tokens - start_tokens) / (time.time() - start_time)
)
metrics = {
'dataset_file_loc': loc,
'dataset_example_index': index,
'dataset_total_tokens': self._total_tokens,
'dataset_accumulated_tps': accumulated_throughput,
'dataset_average_tps': average_throughput,
}
batch = {
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
self.config.batch_size, -1
),
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
self.config.batch_size, -1
),
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
self.config.batch_size, -1
),
}
if self.config.always_start_with_bos:
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
yield batch, metrics
token_buffer = token_buffer[chunk_size:]
loss_mask_buffer = loss_mask_buffer[chunk_size:]
def get_state_dict(self):
return dict(
config=self.config,
index=self._index,
file_loc=self._file_loc,
total_tokens=self._total_tokens,
)
def load_state_dict(self, state_dict):
if 'config' in state_dict:
self.config.update(ConfigDict(state_dict['config']))
self._index = state_dict.get('index', self.config.example_index_at_start)
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
@property
def seq_length(self):
return self.config.seq_length
@property
def tokenizer(self):
return self._tokenizer
@property
def text_processor(self):
return self._text_processor
@property
def vocab_size(self):
return len(self.tokenizer)
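
# Hedged usage sketch (added comment; paths are hypothetical). Each line of the
# JSON file is expected to be a standalone JSON object whose keys match the
# TextProcessor field spec, e.g. {"text": "An example document."}:
#
#   text_processor = TextProcessor({'fields': 'text'}, tokenizer)
#   dataset = JsonDataset(
#       {'path': '/path/to/data.jsonl', 'seq_length': 1024, 'batch_size': 8},
#       tokenizer, text_processor,
#   )
#   for batch, metrics in dataset:
#       break  # json_iterator() loops over the file indefinitely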