3v324v23's picture
Update weights to checkpoint 140000
e557baa
raw
history blame
1.16 kB
#!/usr/bin/env bash
# Fine-tune flax-community/t5-base-dutch for summarization on the cnnuxsum
# JSON splits with ./run_summarization_flax.py (train, eval and predict).
#
# Usage: <this script> [EXTRA_ARGS...]
#   EXTRA_ARGS are passed through to run_summarization_flax.py after the
#   defaults below. With no arguments the invocation is unchanged.
#   NOTE(review): if the Python script parses flags with argparse
#   (HfArgumentParser), a repeated flag's last value wins, so EXTRA_ARGS can
#   override the defaults (e.g. --learning_rate 1e-4) — confirm.
set -euo pipefail

# Pin the job to GPU index 1 (overrides any inherited value).
export CUDA_VISIBLE_DEVICES="1"

readonly MODEL="flax-community/t5-base-dutch"
readonly OUTPUT="./output"
readonly DATA_DIR="/home/yeb/Developer/data/cnnuxsum"
readonly TRAIN="${DATA_DIR}/cnnuxsum_train.json"
readonly VAL="${DATA_DIR}/cnnuxsum_val.json"
readonly TEST="${DATA_DIR}/cnnuxsum_test.json"

die() { printf '%s\n' "$*" >&2; exit 1; }

# Fail fast on a missing split instead of after model/tokenizer loading.
for split_file in "${TRAIN}" "${VAL}" "${TEST}"; do
  [[ -f "${split_file}" ]] || die "error: dataset file not found: ${split_file}"
done

# Under `set -e` a failure here now aborts the script.
mkdir -p "${OUTPUT}"

# Arguments are kept in an array so each flag can carry its own comment and
# every value stays a single word regardless of spaces.
args=(
  --model_name_or_path "${MODEL}"
  --learning_rate "5e-4"
  --warmup_steps 500
  --do_train
  --do_predict
  --do_eval
  --train_file "${TRAIN}"
  --validation_file "${VAL}"
  --test_file "${TEST}"
  # Upper bound on training examples taken from the train split.
  --max_train_samples 1366592
  # Small eval/predict subsets — presumably for quick sanity checks; confirm.
  --max_eval_samples 32
  --max_predict_samples 8
  --text_column "complete_text"
  --summary_column "summary_text"
  --max_source_length 1024
  --max_target_length 142
  --output_dir "${OUTPUT}"
  --per_device_train_batch_size=8
  --per_device_eval_batch_size=8
  --overwrite_output_dir
  --num_train_epochs="1"
  --logging_steps="100"
  --save_steps="20000"
  --eval_steps="5000"
  --num_beams 4
  --prediction_debug
  --predict_with_generate
  # Disabled task prefix; NOTE(review): enable if this checkpoint expects a
  # T5-style "summarize: " prefix — confirm.
  # --source_prefix "summarize: "
)

python ./run_summarization_flax.py "${args[@]}" "$@"