Elron committed on
Commit 382d4f4
1 Parent(s): be9158e

Upload loaders.py with huggingface_hub

Files changed (1)
  1. loaders.py +99 -15
loaders.py CHANGED
@@ -34,6 +34,7 @@ import pandas as pd
 from datasets import load_dataset as hf_load_dataset
 from tqdm import tqdm
 
+from .dataclass import InternalField
 from .logging_utils import get_logger
 from .operator import SourceOperator
 from .settings_utils import get_settings
@@ -45,8 +46,6 @@ settings = get_settings()
 try:
     import ibm_boto3
 
-    # from ibm_botocore.client import ClientError
-
     ibm_boto3_available = True
 except ImportError:
     ibm_boto3_available = False
@@ -62,6 +61,27 @@ class Loader(SourceOperator):
     loader_limit: int = None
     streaming: bool = False
 
+    def get_limit(self):
+        if settings.global_loader_limit is not None and self.loader_limit is not None:
+            return min(int(settings.global_loader_limit), self.loader_limit)
+        if settings.global_loader_limit is not None:
+            return int(settings.global_loader_limit)
+        return self.loader_limit
+
+    def get_limiter(self):
+        if settings.global_loader_limit is not None and self.loader_limit is not None:
+            if int(settings.global_loader_limit) > self.loader_limit:
+                return f"{self.__class__.__name__}.loader_limit"
+            return "unitxt.settings.global_loader_limit"
+        if settings.global_loader_limit is not None:
+            return "unitxt.settings.global_loader_limit"
+        return f"{self.__class__.__name__}.loader_limit"
+
+    def log_limited_loading(self):
+        logger.info(
+            f"\nLoading limited to {self.get_limit()} instances by setting {self.get_limiter()};"
+        )
+
 
 class LoadHF(Loader):
     path: str
@@ -71,10 +91,11 @@ class LoadHF(Loader):
     data_files: Optional[
         Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
     ] = None
-    streaming: bool = False
+    streaming: bool = True
+    _cache: dict = InternalField(default=None)
 
-    def process(self):
-        try:
+    def stream_dataset(self):
+        if self._cache is None:
             with tempfile.TemporaryDirectory() as dir_to_be_deleted:
                 try:
                     dataset = hf_load_dataset(
@@ -92,11 +113,18 @@ class LoadHF(Loader):
                         raise ValueError(
                             f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment vairable: UNITXT_ALLOW_UNVERIFIED_CODE."
                         ) from e
+
             if self.split is not None:
                 dataset = {self.split: dataset}
-        except (
-            NotImplementedError
-        ):  # streaming is not supported for zipped files so we load without streaming
+
+            self._cache = dataset
+        else:
+            dataset = self._cache
+
+        return dataset
+
+    def load_dataset(self):
+        if self._cache is None:
             with tempfile.TemporaryDirectory() as dir_to_be_deleted:
                 try:
                     dataset = hf_load_dataset(
@@ -121,17 +149,73 @@ class LoadHF(Loader):
             else:
                 dataset = {self.split: dataset}
 
+            self._cache = dataset
+        else:
+            dataset = self._cache
+
+        return dataset
+
+    def split_limited_load(self, split_name):
+        yield from itertools.islice(self._cache[split_name], self.get_limit())
+
+    def limited_load(self):
+        self.log_limited_loading()
+        return MultiStream(
+            {
+                name: Stream(
+                    generator=self.split_limited_load, gen_kwargs={"split_name": name}
+                )
+                for name in self._cache.keys()
+            }
+        )
+
+    def process(self):
+        try:
+            dataset = self.stream_dataset()
+        except (
+            NotImplementedError
+        ):  # streaming is not supported for zipped files so we load without streaming
+            dataset = self.load_dataset()
+
+        if self.get_limit() is not None:
+            return self.limited_load()
+
         return MultiStream.from_iterables(dataset)
 
 
 class LoadCSV(Loader):
     files: Dict[str, str]
     chunksize: int = 1000
+    _cache = InternalField(default_factory=dict)
+    loader_limit: int = None
+    streaming: bool = True
 
     def stream_csv(self, file):
-        for chunk in pd.read_csv(file, chunksize=self.chunksize):
-            for _index, row in chunk.iterrows():
+        if self.get_limit() is not None:
+            self.log_limited_loading()
+            chunksize = min(self.get_limit(), self.chunksize)
+        else:
+            chunksize = self.chunksize
+
+        row_count = 0
+        for chunk in pd.read_csv(file, chunksize=chunksize):
+            for _, row in chunk.iterrows():
+                if self.get_limit() is not None and row_count >= self.get_limit():
+                    return
                 yield row.to_dict()
+                row_count += 1
+
+    def load_csv(self, file):
+        if file not in self._cache:
+            if self.get_limit() is not None:
+                self.log_limited_loading()
+                self._cache[file] = pd.read_csv(file, nrows=self.get_limit()).to_dict(
+                    "records"
+                )
+            else:
+                self._cache[file] = pd.read_csv(file).to_dict("records")
+
+        yield from self._cache[file]
 
     def process(self):
         if self.streaming:
@@ -144,7 +228,7 @@ class LoadCSV(Loader):
 
         return MultiStream(
            {
-                name: pd.read_csv(file).to_dict("records")
+                name: Stream(generator=self.load_csv, gen_kwargs={"file": file})
                 for name, file in self.files.items()
             }
         )
@@ -211,17 +295,17 @@ class LoadFromIBMCloud(Loader):
                 f"Unabled to access {item_name} in {bucket_name} in COS", e
             ) from e
 
-        if self.loader_limit is not None:
+        if self.get_limit() is not None:
             if item_name.endswith(".jsonl"):
                 first_lines = list(
-                    itertools.islice(body.iter_lines(), self.loader_limit)
+                    itertools.islice(body.iter_lines(), self.get_limit())
                 )
                 with open(local_file, "wb") as downloaded_file:
                     for line in first_lines:
                         downloaded_file.write(line)
                         downloaded_file.write(b"\n")
                 logger.info(
-                    f"\nDownload successful limited to {self.loader_limit} lines"
+                    f"\nDownload successful limited to {self.get_limit()} lines"
                 )
                 return
 
@@ -277,7 +361,7 @@ class LoadFromIBMCloud(Loader):
             self.cache_dir,
             self.bucket_name,
             self.data_dir,
-            f"loader_limit_{self.loader_limit}",
+            f"loader_limit_{self.get_limit()}",
         )
         if not os.path.exists(local_dir):
             Path(local_dir).mkdir(parents=True, exist_ok=True)
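
The limiting logic introduced above resolves the effective limit as the minimum of unitxt.settings.global_loader_limit and the loader's own loader_limit, and get_limiter() reports which of the two actually applied. A minimal standalone sketch of that precedence, with plain arguments standing in for the settings object and the loader instance (resolve_limit is illustrative and not part of loaders.py):

def resolve_limit(global_loader_limit, loader_limit):
    # Mirrors Loader.get_limit()/get_limiter() from the diff: return the
    # effective limit and the name of the setting that imposed it.
    # "Loader.loader_limit" stands in for f"{self.__class__.__name__}.loader_limit".
    if global_loader_limit is not None and loader_limit is not None:
        if int(global_loader_limit) > loader_limit:
            return loader_limit, "Loader.loader_limit"
        return int(global_loader_limit), "unitxt.settings.global_loader_limit"
    if global_loader_limit is not None:
        return int(global_loader_limit), "unitxt.settings.global_loader_limit"
    return loader_limit, "Loader.loader_limit"

assert resolve_limit(None, None) == (None, "Loader.loader_limit")
assert resolve_limit("100", None) == (100, "unitxt.settings.global_loader_limit")
assert resolve_limit(None, 50) == (50, "Loader.loader_limit")
assert resolve_limit("100", 50) == (50, "Loader.loader_limit")
assert resolve_limit("10", 50) == (10, "unitxt.settings.global_loader_limit")

And a hypothetical usage sketch of the limited CSV loading, assuming this loaders.py is importable as unitxt.loaders, that no global_loader_limit is set in the environment, and that a split of the returned MultiStream can be iterated to yield row dicts, as the Stream-based process() above suggests:

import pandas as pd
from unitxt.loaders import LoadCSV  # assumed import path

# Write a small CSV, then load it with a per-loader limit of 3 rows.
pd.DataFrame({"text": [f"example {i}" for i in range(10)]}).to_csv("train.csv", index=False)

loader = LoadCSV(files={"train": "train.csv"}, loader_limit=3)
rows = list(loader.process()["train"])
print(len(rows))  # expected: 3, since stream_csv stops after get_limit() rows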