boris commited on
Commit
5ee6e60
1 Parent(s): 53dade7

feat(data): accept braceexpand notation

Browse files
Files changed (1) hide show
  1. dalle_mini/data.py +6 -0
dalle_mini/data.py CHANGED
@@ -6,6 +6,7 @@ import jax.numpy as jnp
6
  import numpy as np
7
  from datasets import Dataset, load_dataset
8
  from flax.training.common_utils import shard
 
9
 
10
  from .text import TextNormalizer
11
 
@@ -33,6 +34,11 @@ class Dataset:
33
  def __post_init__(self):
34
  # define data_files
35
  if self.train_file is not None or self.validation_file is not None:
 
 
 
 
 
36
  data_files = {
37
  "train": self.train_file,
38
  "validation": self.validation_file,
 
6
  import numpy as np
7
  from datasets import Dataset, load_dataset
8
  from flax.training.common_utils import shard
9
+ from braceexpand import braceexpand
10
 
11
  from .text import TextNormalizer
12
 
 
34
  def __post_init__(self):
35
  # define data_files
36
  if self.train_file is not None or self.validation_file is not None:
37
+ # accept braceexpand notation
38
+ for k in ["train_file", "validation_file"]:
39
+ f = getattr(self, k)
40
+ if isinstance(f, str):
41
+ setattr(self, k, list(braceexpand(f)))
42
  data_files = {
43
  "train": self.train_file,
44
  "validation": self.validation_file,