Feat: Add dataset loading from S3, GCS (#765)
* Feat: Add dataset loading from S3, GCS
* chore: update docs
* chore: add more info on cloud loading
- README.md +7 -1
- requirements.txt +6 -1
- src/axolotl/utils/data.py +97 -19
README.md
CHANGED

@@ -426,6 +426,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - path: knowrohit07/know_sql
     type: context_qa.load_v2
     train_on_split: validation
+
+  # loading from s3 or gcs
+  # s3 creds will be loaded from the system default and gcs only supports public access
+  dataset:
+    - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
+  ...
 ```
 
 - loading

@@ -520,7 +526,7 @@ float16: true
 
 # A list of one or more datasets to finetune the model with
 datasets:
-  # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
+  # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
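To connect the new README example to the API calls underneath: the loader resolves an s3:// path with an fsspec filesystem and then passes the same storage_options to `datasets`. Below is a minimal sketch of that flow, not part of the diff; the bucket path and the "parquet" file type are hypothetical placeholders, and it assumes `datasets>=2.14.0`, `s3fs`, and `aiobotocore` are installed (see the requirements change below).

```python
# Sketch only (not in the PR): load a single parquet file from S3 the same way
# the patched data.py does. The bucket/key below is a hypothetical placeholder.
import aiobotocore.session
import s3fs
from datasets import load_dataset

# Credentials come from the default profile in ~/.aws/credentials,
# matching the "s3 creds will be loaded from the system default" note above.
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
fs = s3fs.S3FileSystem(**storage_options)

path = "s3://my-bucket/datasets/train.parquet"  # hypothetical path
if fs.isfile(path):
    ds = load_dataset(
        "parquet",                        # type inferred from the file extension
        data_files=path,
        storage_options=storage_options,  # forwarded to fsspec by `datasets`
        streaming=False,
        split=None,
    )
```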
requirements.txt
CHANGED

@@ -11,7 +11,7 @@ deepspeed
 addict
 fire
 PyYAML>=6.0
-datasets
+datasets>=2.14.0
 flash-attn>=2.3.0
 sentencepiece
 wandb

@@ -33,3 +33,8 @@ art
 fschat==0.2.29
 gradio
 tensorboard
+
+# remote filesystems
+s3fs
+gcsfs
+# adlfs
src/axolotl/utils/data.py
CHANGED

@@ -170,30 +170,74 @@ def load_tokenized_prepared_datasets(
             except (FileNotFoundError, ConnectionError):
                 pass
 
+            ds_from_cloud = False
+            storage_options = {}
+            remote_file_system = None
+            if config_dataset.path.startswith("s3://"):
+                try:
+                    import aiobotocore.session  # type: ignore
+                    import s3fs  # type: ignore
+                except ImportError as exc:
+                    raise ImportError(
+                        "s3:// paths require aiobotocore and s3fs to be installed"
+                    ) from exc
+
+                # Takes credentials from ~/.aws/credentials for default profile
+                s3_session = aiobotocore.session.AioSession(profile="default")
+                storage_options = {"session": s3_session}
+                remote_file_system = s3fs.S3FileSystem(**storage_options)
+            elif config_dataset.path.startswith(
+                "gs://"
+            ) or config_dataset.path.startswith("gcs://"):
+                try:
+                    import gcsfs  # type: ignore
+                except ImportError as exc:
+                    raise ImportError(
+                        "gs:// or gcs:// paths require gcsfs to be installed"
+                    ) from exc
+
+                # gcsfs will use default credentials from the environment else anon
+                # https://gcsfs.readthedocs.io/en/latest/#credentials
+                storage_options = {"token": None}
+                remote_file_system = gcsfs.GCSFileSystem(**storage_options)
+            # TODO: Figure out how to get auth creds passed
+            # elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
+            #     try:
+            #         import adlfs
+            #     except ImportError as exc:
+            #         raise ImportError(
+            #             "adl:// or abfs:// paths require adlfs to be installed"
+            #         ) from exc
+
+            #     # Gen 1
+            #     storage_options = {
+            #         "tenant_id": TENANT_ID,
+            #         "client_id": CLIENT_ID,
+            #         "client_secret": CLIENT_SECRET,
+            #     }
+            #     # Gen 2
+            #     storage_options = {
+            #         "account_name": ACCOUNT_NAME,
+            #         "account_key": ACCOUNT_KEY,
+            #     }
+
+            #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
+            try:
+                if remote_file_system and remote_file_system.exists(
+                    config_dataset.path
+                ):
+                    ds_from_cloud = True
+            except (FileNotFoundError, ConnectionError):
+                pass
+
             # prefer local dataset, even if hub exists
             local_path = Path(config_dataset.path)
             if local_path.exists():
                 if local_path.is_dir():
-                    ds = load_dataset(
-                        config_dataset.path,
-                        name=config_dataset.name,
-                        data_files=config_dataset.data_files,
-                        streaming=False,
-                        split=None,
-                    )
+                    ds = load_from_disk(config_dataset.path)
                 elif local_path.is_file():
-                    ds_type = "json"
-                    if config_dataset.ds_type:
-                        ds_type = config_dataset.ds_type
-                    elif ".parquet" in config_dataset.path:
-                        ds_type = "parquet"
-                    elif ".arrow" in config_dataset.path:
-                        ds_type = "arrow"
-                    elif ".csv" in config_dataset.path:
-                        ds_type = "csv"
-                    elif ".txt" in config_dataset.path:
-                        ds_type = "text"
+                    ds_type = get_ds_type(config_dataset)
+
                     ds = load_dataset(
                         ds_type,
                         name=config_dataset.name,

@@ -213,6 +257,22 @@ def load_tokenized_prepared_datasets(
                         data_files=config_dataset.data_files,
                         token=use_auth_token,
                     )
+            elif ds_from_cloud and remote_file_system:
+                if remote_file_system.isdir(config_dataset.path):
+                    ds = load_from_disk(
+                        config_dataset.path,
+                        storage_options=storage_options,
+                    )
+                elif remote_file_system.isfile(config_dataset.path):
+                    ds_type = get_ds_type(config_dataset)
+                    ds = load_dataset(
+                        ds_type,
+                        name=config_dataset.name,
+                        data_files=config_dataset.path,
+                        streaming=False,
+                        split=None,
+                        storage_options=storage_options,
+                    )
             else:
                 if isinstance(config_dataset.data_files, str):
                     fp = hf_hub_download(

@@ -304,6 +364,24 @@ def load_tokenized_prepared_datasets(
     return dataset, prompters
 
 
+def get_ds_type(config_dataset: DictDefault):
+    """
+    Get the dataset type from the path if it's not specified
+    """
+    ds_type = "json"
+    if config_dataset.ds_type:
+        ds_type = config_dataset.ds_type
+    elif ".parquet" in config_dataset.path:
+        ds_type = "parquet"
+    elif ".arrow" in config_dataset.path:
+        ds_type = "arrow"
+    elif ".csv" in config_dataset.path:
+        ds_type = "csv"
+    elif ".txt" in config_dataset.path:
+        ds_type = "text"
+    return ds_type
+
+
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase,
     cfg,
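To make the type-inference behaviour of the new `get_ds_type` helper concrete, here is a small self-contained usage sketch (not part of the diff). `SimpleNamespace` stands in for the `DictDefault` config object, which only needs `.path` and `.ds_type` attributes for this purpose:

```python
# Self-contained copy of the new helper plus two example inputs.
from types import SimpleNamespace

def get_ds_type(config_dataset):
    """Infer the dataset type from the path if it's not specified."""
    ds_type = "json"  # default when nothing else matches
    if config_dataset.ds_type:
        ds_type = config_dataset.ds_type
    elif ".parquet" in config_dataset.path:
        ds_type = "parquet"
    elif ".arrow" in config_dataset.path:
        ds_type = "arrow"
    elif ".csv" in config_dataset.path:
        ds_type = "csv"
    elif ".txt" in config_dataset.path:
        ds_type = "text"
    return ds_type

# The file extension drives the type when ds_type is not set ...
cfg = SimpleNamespace(path="s3://bucket/data/train.parquet", ds_type=None)
assert get_ds_type(cfg) == "parquet"

# ... but an explicit ds_type always wins.
cfg = SimpleNamespace(path="gs://bucket/data/train.txt", ds_type="csv")
assert get_ds_type(cfg) == "csv"
```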
|