import dataclasses
import pprint
import time
from functools import partial
import json
import base64
from multiprocessing import Pool

import h5py
import mlxu
from ml_collections.config_dict import config_dict
from ml_collections import ConfigDict
from tqdm import tqdm, trange
import numpy as np

from datasets import load_dataset, load_from_disk


class DatasetFactory(object):
    """ Dataset builder class. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.type = 'huggingface'
        config.text_processor = TextProcessor.get_default_config()
        config.huggingface_dataset = HuggingfaceDataset.get_default_config()
        config.json_dataset = JsonDataset.get_default_config()

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def load_dataset(cls, config, tokenizer, **kwargs):
        config = cls.get_default_config(config)
        text_processor = TextProcessor(config.text_processor, tokenizer)
        if config.type == 'huggingface':
            return HuggingfaceDataset(
                config.huggingface_dataset, tokenizer, text_processor, **kwargs
            )
        elif config.type == 'json':
            return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
        else:
            raise ValueError(f'Unknown dataset type: {config.type}')

    def __init__(self):
        raise ValueError('DatasetFactory is a static class and should not be instantiated.')


class TextProcessor(object):
    """ Example processor that converts a dictionary of texts into tokens. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.fields_from_example = ''
        config.fields = ''
        config.subfield_separator = ' '
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ''
        config.base64_token_dtype = 'i4'
        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer):
        self.config = self.get_default_config(config)
        assert self.config.fields != '' or self.config.fields_from_example != '', (
            'Either fields or fields_from_example must be specified.'
        )
        self.tokenizer = tokenizer

    def __call__(self, example, has_aux=False):
        if has_aux:
            example, *aux = example
        else:
            aux = tuple()
        token_buffer = []
        loss_mask_buffer = []

        if self.config.add_bos_token:
            token_buffer.append(self.tokenizer.bos_token_id)
            loss_mask_buffer.append(0.0)

        if self.config.fields_from_example != '':
            fields = example[self.config.fields_from_example].split(',')
        else:
            fields = self.config.fields.split(',')

        for i, field in enumerate(fields):
            if field.startswith('[') and field.endswith(']'):
                # No loss for this field.
                field = field[1:-1]
                mask = 0.0
            else:
                mask = 1.0

            if field.startswith('<|') and field.endswith('|>'):
                # Special tokens.
                field = field[2:-2]
                if field == 'bos':
                    token_buffer.append(self.tokenizer.bos_token_id)
                elif field == 'eos':
                    token_buffer.append(self.tokenizer.eos_token_id)
                else:
                    # Token ID specified directly.
                    token_buffer.append(int(field))
                loss_mask_buffer.append(mask)
            elif field.startswith('{') and field.endswith('}'):
                field = field[1:-1]
                # Base64 encoded raw tokens.
                tokens = np.frombuffer(
                    base64.b64decode(example[field]),
                    dtype=self.config.base64_token_dtype
                ).tolist()
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask for _ in range(len(tokens))])
            else:
                subfields = field.split('+')
                text = self.config.subfield_separator.join(
                    [example[subfield] for subfield in subfields]
                )
                if i == 0:
                    text = self.config.prepend_text + text
                tokens = self.tokenizer.encode(text)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend([mask for _ in range(len(tokens))])

        if self.config.add_eos_token:
            token_buffer.append(self.tokenizer.eos_token_id)
            loss_mask_buffer.append(1.0)

        return token_buffer, loss_mask_buffer, *aux
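

# Example `fields` specifications understood by TextProcessor (an illustrative
# sketch; the field names below are hypothetical keys of the example dict, not
# required by this module):
#
#   fields='text'                       tokenize example['text'], loss on every token
#   fields='[prompt],completion'        bracketed fields get loss mask 0.0
#   fields='[question+context],answer'  '+' joins subfields with subfield_separator
#   fields='<|bos|>,text,<|eos|>'       <|...|> inserts special tokens or a literal token id
#   fields='{packed}'                   example['packed'] holds base64-encoded token ids
#                                       of dtype base64_token_dtype
#
# Setting fields_from_example instead reads the comma-separated specification
# from the named key of each example.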


class HuggingfaceDataset(object):
    """ Huggingface dataset, loaded from disk with datasets.load_from_disk()
        and converted to an iterable dataset for streaming.
    """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.path = 'c4'
        config.name = 'en'
        config.split = 'train'
        config.streaming = False
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.tokens_count_at_start = 0
        config.batch_token_dtype = 'i4'

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
        self.config = self.get_default_config(config)
        name = self.config.name if self.config.name != '' else None
        split = self.config.split if self.config.split != '' else None
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._dataset = load_from_disk(self.config.path)[split]
        self._dataset = self._dataset.to_iterable_dataset(
            num_shards=128 if len(self._dataset) > 128 else len(self._dataset)
        )
        self._eval_dataset = eval_dataset
        self._train_epochs = 0
        self._dataset_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start
        self._index = 0

    def __iter__(self):
        if not self._eval_dataset and self._train_epochs > 0:
            self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
        chunk_size = self.config.batch_size * self.config.seq_length
        while True:
            token_buffer = []
            loss_mask_buffer = []
            if not self._eval_dataset and self._train_epochs > 0:
                self._dataset.set_epoch(self._train_epochs)
            for index, example in enumerate(self._dataset):
                self._index = index
                if not self._eval_dataset and self._dataset_loc > index:
                    continue
                tokens, loss_masks = self.text_processor(example)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend(loss_masks)
                while len(token_buffer) > chunk_size + 1:
                    self._total_tokens += chunk_size
                    metrics = {
                        'dataset_example_index': index,
                        'dataset_total_tokens': self._total_tokens,
                        'epoch': self._train_epochs,
                    }
                    batch = {
                        'input_tokens': np.array(
                            token_buffer[:chunk_size],
                            dtype=self.config.batch_token_dtype
                        ).reshape(self.config.batch_size, -1),
                        'target_tokens': np.array(
                            token_buffer[1:chunk_size + 1],
                            dtype=self.config.batch_token_dtype
                        ).reshape(self.config.batch_size, -1),
                        'loss_masks': np.array(
                            loss_mask_buffer[1:chunk_size + 1],
                            dtype=np.float32
                        ).reshape(self.config.batch_size, -1),
                    }
                    if self.config.always_start_with_bos:
                        batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                    yield batch, metrics
                    token_buffer = token_buffer[chunk_size:]
                    loss_mask_buffer = loss_mask_buffer[chunk_size:]

            if self._eval_dataset:
                break
            else:
                if self._train_epochs == 0:
                    self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
                self._dataset_loc = 0
                self._train_epochs += 1

    def get_state_dict(self):
        return dict(
            config=self.config,
            dataset_loc=self._index,
            total_tokens=self._total_tokens,
            epochs=self._train_epochs,
        )

    def load_state_dict(self, state_dict):
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
        self._train_epochs = state_dict.get('epochs', 0)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def dataset(self):
        return self._dataset

    @property
    def vocab_size(self):
        return len(self._tokenizer)
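

# Both dataset classes yield (batch, metrics) pairs. A batch packs
# batch_size * seq_length + 1 consecutive tokens: 'input_tokens' and
# 'target_tokens' are the same token stream shifted by one position, and
# 'loss_masks' is aligned with the targets. get_state_dict() / load_state_dict()
# record the example index (or file offset), the running token count and, for
# HuggingfaceDataset, the epoch, so a resumed run can skip data it has already
# consumed.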


class JsonDataset(object):
    """ JSON dataset, where each line of the data file contains a JSON
        dictionary with text fields.
    """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.path = ''
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor):
        self.config = self.get_default_config(config)
        assert self.config.path != ''
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start

    def parse_json(self, line):
        if not line or line == '\n':
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f'Error parsing json line:\n{line}')
            return None
        return data

    def json_iterator(self):
        with mlxu.open_file(self.config.path, 'r') as fin:
            fin.seek(self._file_loc)
            while True:
                line = fin.readline()
                self._file_loc = fin.tell()
                if not line:
                    # Reached EOF
                    self._index = 0
                    fin.seek(0)
                    continue

                data = self.parse_json(line)
                if data is not None:
                    # JSON parsing succeeded
                    yield data, self._file_loc, self._index
                self._index += 1

    def batched(self, iterator, batch_size):
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def parallel_example_iterator(self):
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                yield self.text_processor((example, loc, index), has_aux=True)
        else:
            process_pool = Pool(self.config.tokenizer_processes)
            batched_iterator = self.batched(
                self.json_iterator(), self.config.tokenizer_parallel_batch_size
            )
            with process_pool as pool:
                map_fn = partial(self.text_processor, has_aux=True)
                next_batch = pool.map_async(
                    map_fn, next(batched_iterator),
                    chunksize=self.config.tokenizer_parallel_chunk_size
                )
                while True:
                    current_batch = next_batch
                    next_batch = pool.map_async(
                        map_fn, next(batched_iterator),
                        chunksize=self.config.tokenizer_parallel_chunk_size
                    )
                    for example in current_batch.get():
                        yield example
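
    # Note: parallel_example_iterator keeps the worker pool one batch ahead of
    # the consumer: while the current map_async result is being drained, the
    # next batch of JSON lines is already being tokenized. json_iterator wraps
    # around to the start of the file on EOF, so the stream never ends on its own.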

    def __iter__(self):
        chunk_size = self.config.batch_size * self.config.seq_length
        token_buffer = []
        loss_mask_buffer = []
        last_time = time.time()  # start timing from now so the first window is meaningful
        step_times = []
        start_time = time.time()
        start_tokens = self._total_tokens
        for tokens, loss_masks, loc, index in self.parallel_example_iterator():
            token_buffer.extend(tokens)
            loss_mask_buffer.extend(loss_masks)
            while len(token_buffer) > chunk_size + 1:
                self._total_tokens += chunk_size
                step_times.append(time.time() - last_time)
                last_time = time.time()

                if len(step_times) > self.config.throughput_average_window_size:
                    step_times = step_times[-self.config.throughput_average_window_size:]

                average_throughput = chunk_size / np.mean(step_times)
                accumulated_throughput = (
                    (self._total_tokens - start_tokens) / (time.time() - start_time)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }
                batch = {
                    'input_tokens': np.array(
                        token_buffer[:chunk_size], dtype=np.int32
                    ).reshape(self.config.batch_size, -1),
                    'target_tokens': np.array(
                        token_buffer[1:chunk_size + 1], dtype=np.int32
                    ).reshape(self.config.batch_size, -1),
                    'loss_masks': np.array(
                        loss_mask_buffer[1:chunk_size + 1], dtype=np.float32
                    ).reshape(self.config.batch_size, -1),
                }
                if self.config.always_start_with_bos:
                    batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
                yield batch, metrics
                token_buffer = token_buffer[chunk_size:]
                loss_mask_buffer = loss_mask_buffer[chunk_size:]

    def get_state_dict(self):
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._index = state_dict.get('index', self.config.example_index_at_start)
        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def vocab_size(self):
        return len(self.tokenizer)
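

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the library proper). The
    # tokenizer name and data path below are placeholder assumptions: any
    # HuggingFace-style tokenizer and any newline-delimited JSON file with a
    # 'text' field would do.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    config = DatasetFactory.get_default_config()
    config.type = 'json'
    config.json_dataset.path = 'train.jsonl'  # hypothetical path
    config.json_dataset.seq_length = 128
    config.json_dataset.batch_size = 2
    config.text_processor.fields = 'text'     # assumes each JSON record has a 'text' key

    dataset = DatasetFactory.load_dataset(config, tokenizer)
    for batch, metrics in dataset:
        print(batch['input_tokens'].shape, metrics['dataset_total_tokens'])
        break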