# Spaces: Runtime error
# Fail on any unhandled error, including a failure in the middle of a pipeline
# (the content hash of the checkout is computed by a multi-stage pipeline).
set -e -o pipefail

# Defaults, overridable by the flags parsed below.
origin=CarperAI/trlx  # GitHub "<owner>/<repo>" to clone
branch=main
entity=null           # spliced raw into the JSON args: null or a JSON string
only_hash=false
only_tiny=false

# --origin/--branch take a value; ${2:?...} aborts with a message when it is
# missing or empty instead of letting the trailing `shift` fail silently.
while [[ "$#" -gt 0 ]]; do
  case "$1" in
    --origin) origin=${2:?"--origin requires a value"}; shift ;;
    --branch) branch=${2:?"--branch requires a value"}; shift ;;
    --public) entity='"CarperAI"' ;;
    --only_hash) only_hash=true ;;
    --only_tiny) only_tiny=true ;;
    *) echo "Unknown parameter passed: $1" >&2; exit 1 ;;
  esac
  shift
done
# Clone the requested branch into a fresh temporary directory created under the
# current directory; it is removed on exit. The path is absolute so the cleanup
# trap does not depend on the working directory at exit time.
if ! dir=$(mktemp -d -p "$PWD") || [[ ! -d "$dir" ]]; then
  echo "Couldn't create a temporary directory, aborting" >&2
  exit 1
fi
# Install the trap before cd so the directory is cleaned up even if cd fails.
trap 'rm -rf -- "$dir"' EXIT
cd -- "$dir"
git clone --depth 1 --single-branch -b "$branch" "https://github.com/$origin" .

# Content hash of the checkout: sha1 of the "sha1  path" lines produced for
# every regular file outside .git, excluding *.md, in sorted path order.
# LC_ALL=C makes the sort (and therefore the hash) independent of the locale.
hash=$(find . -not \( -path ./.git -prune \) -not -name "*.md" -type f -print0 \
  | LC_ALL=C sort -z \
  | xargs -0 sha1sum \
  | sha1sum \
  | cut -f1 -d" ")
# "<abbreviated sha>/<subject>/<author date YYYY-MM-DD>" of the cloned HEAD.
git_hash=$(git log --format=%h/%s/%as -n1)
# --only_hash: report the checkout's content hash and the HEAD summary, one per
# line on stdout, and stop before creating the virtualenv or running anything.
if [[ $only_hash == true ]]; then
  printf '%s\n' "$hash" "$git_hash"
  exit 0
fi
# Build an isolated virtualenv ("venv" in the current directory, i.e. the
# clone), upgrade its pip, then install the requirements and the package itself
# in editable mode.
python -m venv venv
source venv/bin/activate
python -m pip install --upgrade pip
pip install --requirement requirements.txt
pip install --editable .
# Quick random-walk runs (ILQL, then PPO). Each script gets one JSON argument
# with a "train" section that sets project_name, entity_name (raw JSON value:
# null or a quoted string) and a tag holding the content hash.
args=$(printf '{"train": {"project_name": "trlx-references", "entity_name": %s, "tags": ["%s"]}}' "$entity" "$hash")
for script in ilql_randomwalks ppo_randomwalks; do
  python "examples/randomwalks/$script.py" "$args"
done
# --only_tiny: stop after the random-walk runs.
if [[ $only_tiny == true ]]; then
  exit 0
fi
# Full benchmark. First, the four sentiment examples run in parallel. Each one
# is pinned to its own GPU index via CUDA_VISIBLE_DEVICES and uses its own
# accelerate main-process port (8880 + index). Their output goes to
# ../benchmark_logs/<example>.log, i.e. next to the temporary clone. After that,
# examples/hh/ppo_hh.py runs with CONFIG_NAME=6B on 7 processes.
rm -rf ../benchmark_logs && mkdir ../benchmark_logs
sentiment_examples=(ppo_sentiments sft_sentiments ilql_sentiments ppo_sentiments_t5)
sentiment_pids=()
for gpu in "${!sentiment_examples[@]}"; do
  example=${sentiment_examples[gpu]}
  CUDA_VISIBLE_DEVICES=$gpu accelerate launch --num_processes 1 \
    --config_file configs/accelerate/zero2-bf16.yaml \
    --main_process_port "$((8880 + gpu))" \
    "examples/$example.py" "$args" >"../benchmark_logs/$example.log" 2>&1 &
  sentiment_pids+=("$!")
done

# Wait on each PID individually: a bare `wait` returns 0 even when background
# jobs fail, and set -e does not see background failures at all.
failed_examples=()
for i in "${!sentiment_examples[@]}"; do
  wait "${sentiment_pids[i]}" || failed_examples+=("${sentiment_examples[i]}")
done
if ((${#failed_examples[@]})); then
  printf 'Sentiment benchmark(s) failed: %s (see ../benchmark_logs/)\n' \
    "${failed_examples[*]}" >&2
fi

args=$(printf '{"train": {"total_steps": 1500, "seq_length": 512, "project_name": "trlx-references", "entity_name": %s, "tags": ["%s"]}}' "$entity" "$hash")
CONFIG_NAME=6B accelerate launch --num_processes 7 --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ppo_hh.py "$args"

# The HH run still goes ahead after a sentiment failure, but the overall exit
# status reports it.
if ((${#failed_examples[@]})); then
  exit 1
fi