cahya committed
Commit
c15b9df
1 Parent(s): 8ee977d

updated the model and script to load local data

Files changed (3)
  1. pytorch_model.bin +1 -1
  2. run_clm_flax.py +6 -1
  3. run_pretraining.sh +3 -2
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:83e1f5e0a435dafa8cc6d555fd5441209ec15ff7c93cffc1079502ffa0b84b93
+oid sha256:f67e392707d4b269ea616f717bb7451e79d8cc0235449e990209b12bb74aad45
 size 1444576537
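
The pytorch_model.bin change swaps only the Git LFS pointer's object hash; the reported size is unchanged, so the weights blob was replaced by a same-sized binary. As an optional sanity check (not part of this commit), a checked-out copy can be compared against the new pointer. A minimal sketch, assuming the LFS object has already been fetched (e.g. via git lfs pull) so pytorch_model.bin is the real ~1.4 GB file rather than the pointer text:

```python
import hashlib
import os

# Values taken from the updated LFS pointer in this commit.
EXPECTED_OID = "f67e392707d4b269ea616f717bb7451e79d8cc0235449e990209b12bb74aad45"
EXPECTED_SIZE = 1444576537

path = "pytorch_model.bin"  # assumes the LFS object was pulled, not just the pointer file

# Stream the file through SHA-256 to avoid holding ~1.4 GB in memory.
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

assert os.path.getsize(path) == EXPECTED_SIZE, "size does not match the LFS pointer"
assert h.hexdigest() == EXPECTED_OID, "sha256 does not match the LFS pointer"
print("pytorch_model.bin matches the committed LFS pointer")
```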
run_clm_flax.py CHANGED
@@ -112,6 +112,9 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
+    dataset_data_dir: Optional[str] = field(
+        default=None, metadata={"help": "The name of the data directory."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -296,19 +299,21 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
+            data_args.dataset_name, data_args.dataset_config_name, data_dir=data_args.dataset_data_dir, cache_dir=model_args.cache_dir, keep_in_memory=False
         )

         if "validation" not in dataset.keys():
             dataset["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
+                data_dir=data_args.dataset_data_dir,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
             dataset["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
+                data_dir=data_args.dataset_data_dir,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
             )
run_pretraining.sh CHANGED
@@ -9,8 +9,9 @@ export WANDB_LOG_MODEL="true"
     --model_type="gpt2" \
     --config_name="${MODEL_DIR}" \
     --tokenizer_name="${MODEL_DIR}" \
-    --dataset_name="oscar" \
-    --dataset_config_name="unshuffled_deduplicated_id" \
+    --dataset_name="./datasets/id_collection" \
+    --dataset_config_name="id_collection" \
+    --dataset_data_dir="/data/collection" \
     --do_train --do_eval \
     --block_size="512" \
     --per_device_train_batch_size="24" \
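
Taken together, the three new flags in run_pretraining.sh point run_clm_flax.py at a local dataset loading script instead of the hub's oscar/unshuffled_deduplicated_id config, and the new --dataset_data_dir value is forwarded to datasets.load_dataset as data_dir. A minimal standalone sketch of the equivalent call, assuming a loading script exists at ./datasets/id_collection and the raw files live under /data/collection (both paths are the ones hard-coded in the script above, local to the training machine and not shipped with this repository):

```python
from datasets import load_dataset

# Mirrors the updated load_dataset call in run_clm_flax.py, using the values
# hard-coded in run_pretraining.sh. The script path and data directory are
# assumptions about the local setup, not files in this repo.
dataset = load_dataset(
    "./datasets/id_collection",    # --dataset_name: local loading script
    "id_collection",               # --dataset_config_name
    data_dir="/data/collection",   # --dataset_data_dir -> data_dir
    keep_in_memory=False,
)

# Roughly what run_clm_flax.py does next when the script defines no validation split:
if "validation" not in dataset.keys():
    val_pct = 5  # placeholder; the script uses its --validation_split_percentage argument
    dataset["validation"] = load_dataset(
        "./datasets/id_collection", "id_collection",
        data_dir="/data/collection", split=f"train[:{val_pct}%]",
    )
    dataset["train"] = load_dataset(
        "./datasets/id_collection", "id_collection",
        data_dir="/data/collection", split=f"train[{val_pct}%:]",
    )
```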