Elron committed on
Commit d79bb48
1 Parent(s): cbd0905

Upload loaders.py with huggingface_hub

Files changed (1)
  1. loaders.py +97 -27
loaders.py CHANGED
@@ -1,6 +1,8 @@
+import importlib
 import itertools
-import logging
 import os
+import tempfile
+from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Dict, Mapping, Optional, Sequence, Union

@@ -8,11 +10,14 @@ import pandas as pd
 from datasets import load_dataset as hf_load_dataset
 from tqdm import tqdm

+from .logging_utils import get_logger
 from .operator import SourceOperator
 from .stream import MultiStream, Stream

+logger = get_logger()
 try:
     import ibm_boto3
+
     # from ibm_botocore.client import ClientError

     ibm_boto3_available = True
@@ -40,31 +45,35 @@ class LoadHF(Loader):
         Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
     ] = None
     streaming: bool = True
-    cached = False

     def process(self):
         try:
-            dataset = hf_load_dataset(
-                self.path,
-                name=self.name,
-                data_dir=self.data_dir,
-                data_files=self.data_files,
-                streaming=self.streaming,
-                split=self.split,
-            )
+            with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+                dataset = hf_load_dataset(
+                    self.path,
+                    name=self.name,
+                    data_dir=self.data_dir,
+                    data_files=self.data_files,
+                    streaming=self.streaming,
+                    cache_dir=None if self.streaming else dir_to_be_deleted,
+                    split=self.split,
+                )
             if self.split is not None:
                 dataset = {self.split: dataset}
         except (
             NotImplementedError
         ):  # streaming is not supported for zipped files so we load without streaming
-            dataset = hf_load_dataset(
-                self.path,
-                name=self.name,
-                data_dir=self.data_dir,
-                data_files=self.data_files,
-                streaming=False,
-                split=self.split,
-            )
+            with tempfile.TemporaryDirectory() as dir_to_be_deleted:
+                dataset = hf_load_dataset(
+                    self.path,
+                    name=self.name,
+                    data_dir=self.data_dir,
+                    data_files=self.data_files,
+                    streaming=False,
+                    keep_in_memory=True,
+                    cache_dir=dir_to_be_deleted,
+                    split=self.split,
+                )
         if self.split is None:
             for split in dataset.keys():
                 dataset[split] = dataset[split].to_iterable_dataset()
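For reference, a minimal usage sketch of the revised LoadHF loader (illustrative only; the dataset path/name and the import path are assumptions, and process() is called directly as defined in the hunk above):

# Sketch: with streaming=False, hf_load_dataset caches into a throwaway
# TemporaryDirectory that is removed when process() returns; with
# streaming=True no cache directory is passed (cache_dir=None).
from unitxt.loaders import LoadHF  # import path assumed

loader = LoadHF(path="glue", name="sst2", split="train", streaming=False)
multi_stream = loader.process()          # MultiStream keyed by split name
for instance in multi_stream["train"]:   # Stream of instance dicts
    print(instance)
    break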
@@ -92,16 +101,55 @@ class LoadCSV(Loader):
         )


+class MissingKaggleCredentialsError(ValueError):
+    pass
+
+
+# TODO write how to obtain kaggle credentials
+class LoadFromKaggle(Loader):
+    url: str
+
+    def verify(self):
+        super().verify()
+        if importlib.util.find_spec("opendatasets") is None:
+            raise ImportError(
+                "Please install opendatasets in order to use the LoadFromKaggle loader (using `pip install opendatasets`) "
+            )
+        if not os.path.isfile("kaggle.json"):
+            raise MissingKaggleCredentialsError(
+                "Please obtain kaggle credentials https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/ and save them to local ./kaggle.json file"
+            )
+
+    def prepare(self):
+        super().prepare()
+        from opendatasets import download
+
+        self.downloader = download
+
+    def process(self):
+        with TemporaryDirectory() as temp_directory:
+            self.downloader(self.url, temp_directory)
+            dataset = hf_load_dataset(temp_directory, streaming=False)
+
+        return MultiStream.from_iterables(dataset)
+
+
 class LoadFromIBMCloud(Loader):
     endpoint_url_env: str
     aws_access_key_id_env: str
     aws_secret_access_key_env: str
     bucket_name: str
     data_dir: str = None
-    data_files: Sequence[str]
+
+    # Can be either:
+    # 1. a list of file names, the split of each file is determined by the file name pattern
+    # 2. Mapping: split -> file_name, e.g. {"test" : "test.json", "train": "train.json"}
+    # 3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
+    data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
+    caching: bool = True

     def _download_from_cos(self, cos, bucket_name, item_name, local_file):
-        logging.info(f"Downloading {item_name} from {bucket_name} COS")
+        logger.info(f"Downloading {item_name} from {bucket_name} COS")
         try:
             response = cos.Object(bucket_name, item_name).get()
             size = response["ContentLength"]
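A usage sketch for the new LoadFromKaggle loader; the URL is a placeholder and the import path is an assumption. verify() requires the opendatasets package and a ./kaggle.json credentials file, exactly as encoded above:

# Sketch only: <owner>/<dataset> is a placeholder Kaggle dataset URL.
from unitxt.loaders import LoadFromKaggle  # import path assumed

loader = LoadFromKaggle(url="https://www.kaggle.com/datasets/<owner>/<dataset>")
loader.verify()    # raises ImportError / MissingKaggleCredentialsError if prerequisites are missing
loader.prepare()   # binds opendatasets.download
streams = loader.process()  # downloads to a TemporaryDirectory, then hf_load_dataset(..., streaming=False)

In unitxt, prepare() and verify() are normally invoked when the artifact is constructed; they are spelled out here only to make the flow visible.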
@@ -120,7 +168,7 @@ class LoadFromIBMCloud(Loader):
                     for line in first_lines:
                         downloaded_file.write(line)
                         downloaded_file.write(b"\n")
-                logging.info(
+                logger.info(
                     f"\nDownload successful limited to {self.loader_limit} lines"
                 )
                 return
@@ -134,7 +182,7 @@ class LoadFromIBMCloud(Loader):
             cos.Bucket(bucket_name).download_file(
                 item_name, local_file, Callback=upload_progress
             )
-            logging.info("\nDownload Successful")
+            logger.info("\nDownload Successful")
         except Exception as e:
             raise Exception(
                 f"Unabled to download {item_name} in {bucket_name}", e
@@ -145,6 +193,11 @@ class LoadFromIBMCloud(Loader):
         self.endpoint_url = os.getenv(self.endpoint_url_env)
         self.aws_access_key_id = os.getenv(self.aws_access_key_id_env)
         self.aws_secret_access_key = os.getenv(self.aws_secret_access_key_env)
+        root_dir = os.getenv("UNITXT_IBM_COS_CACHE", None) or os.getcwd()
+        self.cache_dir = os.path.join(root_dir, "ibmcos_datasets")
+
+        if not os.path.exists(self.cache_dir):
+            Path(self.cache_dir).mkdir(parents=True, exist_ok=True)

     def verify(self):
         super().verify()
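The cache root introduced in prepare() is resolved from the UNITXT_IBM_COS_CACHE environment variable, falling back to the current working directory; a small sketch (the path is hypothetical):

# Point the on-disk COS cache at an explicit location before constructing the loader;
# otherwise <cwd>/ibmcos_datasets is used and created on demand.
import os

os.environ["UNITXT_IBM_COS_CACHE"] = "/tmp/unitxt_cos_cache"  # hypothetical path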
@@ -166,9 +219,20 @@ class LoadFromIBMCloud(Loader):
             aws_secret_access_key=self.aws_secret_access_key,
             endpoint_url=self.endpoint_url,
         )
-
-        with TemporaryDirectory() as temp_directory:
-            for data_file in self.data_files:
+        local_dir = os.path.join(self.cache_dir, self.bucket_name, self.data_dir)
+        if not os.path.exists(local_dir):
+            Path(local_dir).mkdir(parents=True, exist_ok=True)
+
+        if isinstance(self.data_files, Mapping):
+            data_files_names = list(self.data_files.values())
+            if not isinstance(data_files_names[0], str):
+                data_files_names = list(itertools.chain(*data_files_names))
+        else:
+            data_files_names = self.data_files
+
+        for data_file in data_files_names:
+            local_file = os.path.join(local_dir, data_file)
+            if not self.caching or not os.path.exists(local_file):
                 # Build object key based on parameters. Slash character is not
                 # allowed to be part of object key in IBM COS.
                 object_key = (
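The Mapping branch above flattens a split-to-files mapping into a single list of file names before checking the local cache; a worked example using the values from the data_files comment:

# Mirrors the flattening logic in the hunk above (no download involved).
import itertools
from typing import Mapping

data_files = {"test": ["test1.json", "test2.json"], "train": ["train.json"]}

if isinstance(data_files, Mapping):
    data_files_names = list(data_files.values())
    if not isinstance(data_files_names[0], str):
        data_files_names = list(itertools.chain(*data_files_names))
else:
    data_files_names = data_files

print(data_files_names)  # ['test1.json', 'test2.json', 'train.json']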
@@ -177,8 +241,14 @@ class LoadFromIBMCloud(Loader):
                     else data_file
                 )
                 self._download_from_cos(
-                    cos, self.bucket_name, object_key, temp_directory + "/" + data_file
+                    cos, self.bucket_name, object_key, local_dir + "/" + data_file
                 )
-            dataset = hf_load_dataset(temp_directory, streaming=False)
+
+        if isinstance(self.data_files, list):
+            dataset = hf_load_dataset(local_dir, streaming=False)
+        else:
+            dataset = hf_load_dataset(
+                local_dir, streaming=False, data_files=self.data_files
+            )

         return MultiStream.from_iterables(dataset)
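A closing sketch of how the updated LoadFromIBMCloud is driven; the environment-variable names, bucket, directory and file names are placeholders, and data_files may take any of the three documented forms:

# Sketch only: all names below are placeholders. Credentials are read in prepare()
# from the environment variables named by the *_env fields.
from unitxt.loaders import LoadFromIBMCloud  # import path assumed

loader = LoadFromIBMCloud(
    endpoint_url_env="COS_ENDPOINT_URL",
    aws_access_key_id_env="COS_ACCESS_KEY_ID",
    aws_secret_access_key_env="COS_SECRET_ACCESS_KEY",
    bucket_name="my-bucket",
    data_dir="my_dataset",
    data_files={"train": "train.json", "test": "test.json"},  # or a list, or split -> [files]
    caching=True,  # reuse files already present in the local cache directory
)
multi_stream = loader.process()  # downloads any missing files, then hf_load_dataset on the local dir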
 