Elron committed on
Commit
dbcdf0f
1 Parent(s): 3c5feb8

Upload stream.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stream.py +23 -12
stream.py CHANGED
@@ -1,4 +1,3 @@
1
- from copy import deepcopy
2
  from typing import Dict, Iterable
3
 
4
  from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
@@ -31,11 +30,11 @@ class Stream(Dataclass):
31
  """
32
  if self.caching:
33
  return Dataset.from_generator
34
- else:
35
- if self.copying:
36
- return CopyingReusableGenerator
37
- else:
38
- return ReusableGenerator
39
 
40
  def _get_stream(self):
41
  """Private method to get the stream based on the initiator function.
@@ -102,12 +101,20 @@ class MultiStream(dict):
102
 
103
  def to_dataset(self) -> DatasetDict:
104
  return DatasetDict(
105
- {key: Dataset.from_generator(self.get_generator, gen_kwargs={"key": key}) for key in self.keys()}
 
 
 
106
  )
107
 
108
  def to_iterable_dataset(self) -> IterableDatasetDict:
109
  return IterableDatasetDict(
110
- {key: IterableDataset.from_generator(self.get_generator, gen_kwargs={"key": key}) for key in self.keys()}
 
 
 
 
 
111
  )
112
 
113
  def __setitem__(self, key, value):
@@ -116,17 +123,19 @@ class MultiStream(dict):
116
  super().__setitem__(key, value)
117
 
118
  @classmethod
119
- def from_generators(cls, generators: Dict[str, ReusableGenerator], caching=False, copying=False):
 
 
120
  """Creates a MultiStream from a dictionary of ReusableGenerators.
121
 
122
  Args:
123
  generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
124
  caching (bool, optional): Whether the data should be cached or not. Defaults to False.
 
125
 
126
  Returns:
127
  MultiStream: A MultiStream object.
128
  """
129
-
130
  assert all(isinstance(v, ReusableGenerator) for v in generators.values())
131
  return cls(
132
  {
@@ -141,17 +150,19 @@ class MultiStream(dict):
141
  )
142
 
143
  @classmethod
144
- def from_iterables(cls, iterables: Dict[str, Iterable], caching=False, copying=False):
 
 
145
  """Creates a MultiStream from a dictionary of iterables.
146
 
147
  Args:
148
  iterables (Dict[str, Iterable]): A dictionary of iterables.
149
  caching (bool, optional): Whether the data should be cached or not. Defaults to False.
 
150
 
151
  Returns:
152
  MultiStream: A MultiStream object.
153
  """
154
-
155
  return cls(
156
  {
157
  key: Stream(
 
 
1
  from typing import Dict, Iterable
2
 
3
  from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 
30
  """
31
  if self.caching:
32
  return Dataset.from_generator
33
+
34
+ if self.copying:
35
+ return CopyingReusableGenerator
36
+
37
+ return ReusableGenerator
38
 
39
  def _get_stream(self):
40
  """Private method to get the stream based on the initiator function.
 
101
 
102
  def to_dataset(self) -> DatasetDict:
103
  return DatasetDict(
104
+ {
105
+ key: Dataset.from_generator(self.get_generator, gen_kwargs={"key": key})
106
+ for key in self.keys()
107
+ }
108
  )
109
 
110
  def to_iterable_dataset(self) -> IterableDatasetDict:
111
  return IterableDatasetDict(
112
+ {
113
+ key: IterableDataset.from_generator(
114
+ self.get_generator, gen_kwargs={"key": key}
115
+ )
116
+ for key in self.keys()
117
+ }
118
  )
119
 
120
  def __setitem__(self, key, value):
 
123
  super().__setitem__(key, value)
124
 
125
  @classmethod
126
+ def from_generators(
127
+ cls, generators: Dict[str, ReusableGenerator], caching=False, copying=False
128
+ ):
129
  """Creates a MultiStream from a dictionary of ReusableGenerators.
130
 
131
  Args:
132
  generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
133
  caching (bool, optional): Whether the data should be cached or not. Defaults to False.
134
+ copying (bool, optional): Whether the data should be copied or not. Defaults to False.
135
 
136
  Returns:
137
  MultiStream: A MultiStream object.
138
  """
 
139
  assert all(isinstance(v, ReusableGenerator) for v in generators.values())
140
  return cls(
141
  {
 
150
  )
151
 
152
  @classmethod
153
+ def from_iterables(
154
+ cls, iterables: Dict[str, Iterable], caching=False, copying=False
155
+ ):
156
  """Creates a MultiStream from a dictionary of iterables.
157
 
158
  Args:
159
  iterables (Dict[str, Iterable]): A dictionary of iterables.
160
  caching (bool, optional): Whether the data should be cached or not. Defaults to False.
161
+ copying (bool, optional): Whether the data should be copied or not. Defaults to False.
162
 
163
  Returns:
164
  MultiStream: A MultiStream object.
165
  """
 
166
  return cls(
167
  {
168
  key: Stream(