yhavinga's picture
Add pytorch model
522b344
raw history blame
No virus
1.45 kB
#!/usr/bin/env bash
#
# Fine-tune flax-community/t5-base-dutch for summarization with
# run_summarization_flax.py on the cnnuxsum JSON splits.
# Only training (--do_train) is enabled. Evaluation/prediction and other
# generation options are kept, commented out, inside the argument list below.
#
# Usage: run from the directory that contains run_summarization_flax.py
# (the script is invoked via a relative path, and OUTPUT is relative too).
set -euo pipefail

# Expose only GPU index 1 to the training process.
export CUDA_VISIBLE_DEVICES=1

readonly MODEL="flax-community/t5-base-dutch"
readonly OUTPUT="./output"
readonly TRAIN="/home/yeb/cnnuxsum/cnnuxsum_train.json"
readonly VAL="/home/yeb/cnnuxsum/cnnuxsum_val.json"
readonly TEST="/home/yeb/cnnuxsum/cnnuxsum_test.json"

# Print an error message to stderr and abort.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# Fail fast on a missing data file instead of after Python/JAX start-up.
# All three files are passed to the training script, so all three must exist.
for data_file in "${TRAIN}" "${VAL}" "${TEST}"; do
  [[ -f "${data_file}" ]] || die "data file not found: ${data_file}"
done

mkdir -p -- "${OUTPUT}"

# One option (plus its value) per line; comment a line out to disable it.
args=(
  --model_name_or_path "${MODEL}"
  --learning_rate "5e-4"
  --warmup_steps 500
  --do_train
  --train_file "${TRAIN}"
  --validation_file "${VAL}"
  --test_file "${TEST}"
  --max_train_samples 640000
  --max_eval_samples 512
  --max_predict_samples 64
  --text_column "complete_text"
  --summary_column "summary_text"
  --source_prefix "summarize: "
  --max_source_length 1024
  --max_target_length 142
  --output_dir "${OUTPUT}"
  --per_device_train_batch_size=8
  --per_device_eval_batch_size=2
  --overwrite_output_dir
  --num_train_epochs="1"
  --logging_steps="50"
  --save_steps="2000"
  # NOTE(review): presumably set far above the run length so that step-based
  # evaluation never triggers during training — confirm.
  --eval_steps="25000000"
  --num_beams 4

  # Disabled options, kept for reference.
  # NOTE(review): not all of these may be accepted by
  # run_summarization_flax.py; confirm before enabling.
  # --do_predict
  # --do_eval
  # --prediction_debug
  # --predict_with_generate
  # --lr_scheduler_type="constant"
  # --task "summarization"
  # --early_stopping "true"
  # --length_penalty "2.0"
  # --max_length 300
  # --min_length 75
  # --no_repeat_ngram_size 3
  # --prefix "summarize: "
)

python ./run_summarization_flax.py "${args[@]}"