Elron committed on
Commit
803d9a3
·
verified ·
1 Parent(s): c23528f

Upload stream.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stream.py +19 -11
stream.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Dict, Iterable
2
 
3
  from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
@@ -12,9 +13,9 @@ class Stream(Dataclass):
12
  This class provides methods for generating, caching, and manipulating streaming data.
13
 
14
  Attributes:
15
- generator (function): A generator function for streaming data.
16
- gen_kwargs (dict, optional): A dictionary of keyword arguments for the generator function.
17
- caching (bool): Whether the data is cached or not.
18
  """
19
 
20
  generator: callable
@@ -47,7 +48,7 @@ class Stream(Dataclass):
47
  def __iter__(self):
48
  return iter(self._get_stream())
49
 
50
- def peak(self):
51
  return next(iter(self))
52
 
53
  def take(self, n):
@@ -99,13 +100,20 @@ class MultiStream(dict):
99
  for stream in self.values():
100
  stream.copying = copying
101
 
102
- def to_dataset(self) -> DatasetDict:
103
- return DatasetDict(
104
- {
105
- key: Dataset.from_generator(self.get_generator, gen_kwargs={"key": key})
106
- for key in self.keys()
107
- }
108
- )
 
 
 
 
 
 
 
109
 
110
  def to_iterable_dataset(self) -> IterableDatasetDict:
111
  return IterableDatasetDict(
 
1
+ import tempfile
2
  from typing import Dict, Iterable
3
 
4
  from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 
13
  This class provides methods for generating, caching, and manipulating streaming data.
14
 
15
  Attributes:
16
+ generator (function): A generator function for streaming data. :no-index:
17
+ gen_kwargs (dict, optional): A dictionary of keyword arguments for the generator function. :no-index:
18
+ caching (bool): Whether the data is cached or not. :no-index:
19
  """
20
 
21
  generator: callable
 
48
  def __iter__(self):
49
  return iter(self._get_stream())
50
 
51
+ def peek(self):
52
  return next(iter(self))
53
 
54
  def take(self, n):
 
100
  for stream in self.values():
101
  stream.copying = copying
102
 
103
+ def to_dataset(self, disable_cache=True, cache_dir=None) -> DatasetDict:
104
+ with tempfile.TemporaryDirectory() as dir_to_be_deleted:
105
+ cache_dir = dir_to_be_deleted if disable_cache else cache_dir
106
+ return DatasetDict(
107
+ {
108
+ key: Dataset.from_generator(
109
+ self.get_generator,
110
+ keep_in_memory=disable_cache,
111
+ cache_dir=cache_dir,
112
+ gen_kwargs={"key": key},
113
+ )
114
+ for key in self.keys()
115
+ }
116
+ )
117
 
118
  def to_iterable_dataset(self) -> IterableDatasetDict:
119
  return IterableDatasetDict(