cahya committed
Commit
b25d188
1 Parent(s): e5e9f73

reorder "import jax" to avoid getting stuck on importing it; change the dataset

Files changed (2)
  1. run_clm_flax.py +1 -1
  2. run_finetuning.sh +3 -3
run_clm_flax.py CHANGED
@@ -30,11 +30,11 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Callable, Optional
 
+import jax
 import datasets
 from datasets import Dataset, load_dataset
 from tqdm import tqdm
 
-import jax
 import jax.numpy as jnp
 import optax
 import transformers
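For reference, this is the import block of run_clm_flax.py as it stands after the commit, reconstructed from the diff above. The commit message only says the reorder avoids a hang on import; the rationale in the comment below (letting JAX initialize its backend before datasets starts its own setup) is a plausible reading, not something the source states.

# run_clm_flax.py import block after this commit (reconstructed from the diff)
import jax  # moved up: with the old order, importing jax after datasets hung

import datasets
from datasets import Dataset, load_dataset
from tqdm import tqdm

import jax.numpy as jnp
import optax
import transformers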
run_finetuning.sh CHANGED
@@ -5,18 +5,18 @@ export WANDB_LOG_MODEL="true"
 
 ./run_clm_flax.py \
     --model_name_or_path="./flax_model.msgpack" \
-    --output_dir="${MODEL_DIR}/finetuning2" \
+    --output_dir="${MODEL_DIR}/finetuning" \
     --model_type="gpt2" \
     --config_name="${MODEL_DIR}" \
     --tokenizer_name="${MODEL_DIR}" \
     --dataset_name="./text_collection" \
     --dataset_config_name="text_collection" \
-    --dataset_data_dir="/dataset/fiction/story_all" \
+    --dataset_data_dir="/media/storage/datasets/storial/books_txt" \
     --do_train --do_eval \
     --block_size="512" \
     --per_device_train_batch_size="8" \
     --per_device_eval_batch_size="8" \
-    --learning_rate="0.0000001" --warmup_steps="1000" \
+    --learning_rate="0.00005" --warmup_steps="1000" \
     --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
     --overwrite_output_dir \
     --num_train_epochs="20" \
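The three dataset flags changed here presumably end up in a datasets.load_dataset call inside run_clm_flax.py. A minimal sketch of that call under the new values, assuming the script forwards --dataset_data_dir to load_dataset's data_dir parameter (that forwarding is an assumption; --dataset_data_dir is not a flag in the stock Flax example script, so this fork must wire it up itself):

# A minimal sketch, assuming run_clm_flax.py maps its dataset flags onto
# datasets.load_dataset like this; the data_dir mapping is an assumption.
from datasets import load_dataset

raw_datasets = load_dataset(
    "./text_collection",   # --dataset_name: path to a local loading script
    "text_collection",     # --dataset_config_name
    data_dir="/media/storage/datasets/storial/books_txt",  # --dataset_data_dir
)

Separately, note the learning rate moves from 1e-7 to 5e-5, a 500x increase; 5e-5 with warmup is a typical fine-tuning rate for GPT-2-scale models, whereas 1e-7 would make barely any progress over 20 epochs.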