from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

from .text import TextNormalizer


@dataclass
class Dataset:
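    """
    Wrapper around a `datasets` dataset (or local train/validation files) that
    handles loading, optional text normalization, tokenization of captions,
    and sharded batching for training and evaluation.
    """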
    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    dataset_type: str = "dataset"
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_source_length: int = 128
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)

    def __post_init__(self):
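        """Load the dataset and select the requested train/validation splits."""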
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )

    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
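        """
        Prepare the datasets: shuffle early (streaming) or seed an rng for later
        shuffling (non-streaming), optionally normalize the text column, then
        tokenize captions and build labels / decoder_input_ids from the
        pre-computed image encodings.
        """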
        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=1000, seed=self.seed_dataset
                )
        else:
            # prepare rng for later shuffling
            if self.seed_dataset is None:
                self.seed_dataset = np.random.get_state()[1][0]
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )

        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_source_length=self.max_source_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            remove_columns=getattr(self, ds).column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )

    def dataloader(self, split, batch_size, epoch=None):
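        """
        Return an iterator of sharded batches for the requested split.
        Non-streaming training batches are shuffled with `self.rng_dataset`.
        """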
        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            batch_size: int,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Batches are shuffled when `rng` is provided.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))
            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                batch = shard(batch)
                yield batch

        def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            for item in dataset:
                # keep only the model inputs (other columns may remain in streaming mode)
                for k in keys:
                    batch[k].append(item[k])
                if len(batch[keys[0]]) == batch_size:
                    batch = {k: jnp.array(v) for k, v in batch.items()}
                    batch = shard(batch)
                    yield batch
                    batch = {k: [] for k in keys}

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            if split == "train":
                ds.set_epoch(epoch)
            return _dataloader_datasets_streaming(ds, batch_size)
        else:
            input_rng = None
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)

    def length(self):
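        """Return (len_train_dataset, len_eval_dataset); a value is None when unknown."""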
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
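

# Example (hypothetical values): with decoder_start_token_id=0,
# shift_tokens_right(np.array([[5, 7, 9]]), 0) yields [[0, 5, 7]]:
# the start token is prepended and the last token is dropped.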


def normalize_function(example, text_column, text_normalizer):
    example[text_column] = text_normalizer(example[text_column])
    return example


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_source_length,
    decoder_start_token_id,
):
    inputs = examples[text_column]
    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    model_inputs = tokenizer(
        inputs,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # set up targets
    # Note: labels correspond to our target indices
    # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
    labels = examples[encoding_column]
    labels = np.asarray(labels)

    # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
    model_inputs["labels"] = labels

    # In our case, this prepends the bos token and removes the last one
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs
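

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module; `tokenizer`,
# `model`, and the dataset repo name are placeholders):
#
#   dataset = Dataset(
#       dataset_repo_or_path="user/encoded-captions",  # placeholder repo
#       do_train=True,
#       streaming=True,
#   )
#   dataset.preprocess(
#       tokenizer=tokenizer,  # e.g. a pretrained text tokenizer
#       decoder_start_token_id=model.config.decoder_start_token_id,
#       normalize_text=True,
#   )
#   for batch in dataset.dataloader("train", batch_size=64, epoch=0):
#       ...  # sharded input_ids, attention_mask, labels, decoder_input_ids
# ---------------------------------------------------------------------------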