python medalpaca/train.py \ | |
--model PATH_TO_LLAMA_WEIGHTS \ | |
--data_path medical_meadow_small.json \ | |
--output_dir 'output' \ | |
--train_in_8bit True \ | |
--bf16 True \ | |
--tf32 False \ | |
--fp16 False \ | |
--global_batch_size 128 \ | |
--per_device_batch_size 8 |