Elron commited on
Commit
a254196
·
1 Parent(s): c6d1c21

Upload stream.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stream.py +185 -0
stream.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .generator_utils import ReusableGenerator
2
+
3
+ from typing import Iterable, Dict
4
+
5
+ from datasets import IterableDatasetDict, IterableDataset, DatasetDict, Dataset
6
+
7
+
8
+ class Stream:
9
+ """A class for handling streaming data in a customizable way.
10
+
11
+ This class provides methods for generating, caching, and manipulating streaming data.
12
+
13
+ Attributes:
14
+ generator (function): A generator function for streaming data.
15
+ gen_kwargs (dict, optional): A dictionary of keyword arguments for the generator function.
16
+ streaming (bool): Whether the data is streaming or not.
17
+ caching (bool): Whether the data is cached or not.
18
+ """
19
+
20
+ def __init__(self, generator, gen_kwargs=None, streaming=True, caching=False):
21
+ """Initializes the Stream with the provided parameters.
22
+
23
+ Args:
24
+ generator (function): A generator function for streaming data.
25
+ gen_kwargs (dict, optional): A dictionary of keyword arguments for the generator function. Defaults to None.
26
+ streaming (bool, optional): Whether the data is streaming or not. Defaults to True.
27
+ caching (bool, optional): Whether the data is cached or not. Defaults to False.
28
+ """
29
+
30
+ self.generator = generator
31
+ self.gen_kwargs = gen_kwargs if gen_kwargs is not None else {}
32
+ self.streaming = streaming
33
+ self.caching = caching
34
+
35
+ def _get_initator(self):
36
+ """Private method to get the correct initiator based on the streaming and caching attributes.
37
+
38
+ Returns:
39
+ function: The correct initiator function.
40
+ """
41
+ if self.streaming:
42
+ if self.caching:
43
+ return IterableDataset.from_generator
44
+ else:
45
+ return ReusableGenerator
46
+ else:
47
+ if self.caching:
48
+ return Dataset.from_generator
49
+ else:
50
+ raise ValueError("Cannot create non-streaming non-caching stream")
51
+
52
+ def _get_stream(self):
53
+ """Private method to get the stream based on the initiator function.
54
+
55
+ Returns:
56
+ object: The stream object.
57
+ """
58
+ return self._get_initator()(self.generator, gen_kwargs=self.gen_kwargs)
59
+
60
+ def set_caching(self, caching):
61
+ self.caching = caching
62
+
63
+ def set_streaming(self, streaming):
64
+ self.streaming = streaming
65
+
66
+ def __iter__(self):
67
+ return iter(self._get_stream())
68
+
69
+ def unwrap(self):
70
+ return self._get_stream()
71
+
72
+ def peak(self):
73
+ return next(iter(self))
74
+
75
+ def take(self, n):
76
+ for i, instance in enumerate(self):
77
+ if i >= n:
78
+ break
79
+ yield instance
80
+
81
+ def __repr__(self):
82
+ return f"{self.__class__.__name__}(generator={self.generator.__name__}, gen_kwargs={self.gen_kwargs}, streaming={self.streaming}, caching={self.caching})"
83
+
84
+
85
+ def is_stream(obj):
86
+ return isinstance(obj, IterableDataset) or isinstance(obj, Stream) or isinstance(obj, Dataset)
87
+
88
+
89
+ class MultiStream(dict):
90
+ """A class for handling multiple streams of data in a dictionary-like format.
91
+
92
+ This class extends dict and its values should be instances of the Stream class.
93
+
94
+ Attributes:
95
+ data (dict): A dictionary of Stream objects.
96
+ """
97
+
98
+ def __init__(self, data=None):
99
+ """Initializes the MultiStream with the provided data.
100
+
101
+ Args:
102
+ data (dict, optional): A dictionary of Stream objects. Defaults to None.
103
+
104
+ Raises:
105
+ AssertionError: If the values are not instances of Stream or keys are not strings.
106
+ """
107
+ for key, value in data.items():
108
+ isinstance(value, Stream), "MultiStream values must be Stream"
109
+ isinstance(key, str), "MultiStream keys must be strings"
110
+ super().__init__(data)
111
+
112
+ def get_generator(self, key):
113
+ """Gets a generator for a specified key.
114
+
115
+ Args:
116
+ key (str): The key for the generator.
117
+
118
+ Yields:
119
+ object: The next value in the stream.
120
+ """
121
+ yield from self[key]
122
+
123
+ def unwrap(self, cls):
124
+ return cls({key: value.unwrap() for key, value in self.items()})
125
+
126
+ def to_dataset(self) -> DatasetDict:
127
+ return DatasetDict(
128
+ {key: Dataset.from_generator(self.get_generator, gen_kwargs={"key": key}) for key in self.keys()}
129
+ )
130
+
131
+ def to_iterable_dataset(self) -> IterableDatasetDict:
132
+ return IterableDatasetDict(
133
+ {key: IterableDataset.from_generator(self.get_generator, gen_kwargs={"key": key}) for key in self.keys()}
134
+ )
135
+
136
+ def __setitem__(self, key, value):
137
+ assert isinstance(value, Stream), "StreamDict values must be Stream"
138
+ assert isinstance(key, str), "StreamDict keys must be strings"
139
+ super().__setitem__(key, value)
140
+
141
+ @classmethod
142
+ def from_generators(cls, generators: Dict[str, ReusableGenerator], streaming=True, caching=False):
143
+ """Creates a MultiStream from a dictionary of ReusableGenerators.
144
+
145
+ Args:
146
+ generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
147
+ streaming (bool, optional): Whether the data should be streaming or not. Defaults to True.
148
+ caching (bool, optional): Whether the data should be cached or not. Defaults to False.
149
+
150
+ Returns:
151
+ MultiStream: A MultiStream object.
152
+ """
153
+
154
+ assert all(isinstance(v, ReusableGenerator) for v in generators.values())
155
+ return cls(
156
+ {
157
+ key: Stream(
158
+ generator.get_generator(),
159
+ gen_kwargs=generator.get_gen_kwargs(),
160
+ streaming=streaming,
161
+ caching=caching,
162
+ )
163
+ for key, generator in generators.items()
164
+ }
165
+ )
166
+
167
+ @classmethod
168
+ def from_iterables(cls, iterables: Dict[str, Iterable], streaming=True, caching=False):
169
+ """Creates a MultiStream from a dictionary of iterables.
170
+
171
+ Args:
172
+ iterables (Dict[str, Iterable]): A dictionary of iterables.
173
+ streaming (bool, optional): Whether the data should be streaming or not. Defaults to True.
174
+ caching (bool, optional): Whether the data should be cached or not. Defaults to False.
175
+
176
+ Returns:
177
+ MultiStream: A MultiStream object.
178
+ """
179
+
180
+ return cls(
181
+ {
182
+ key: Stream(iterable.__iter__, gen_kwargs={}, streaming=streaming, caching=caching)
183
+ for key, iterable in iterables.items()
184
+ }
185
+ )