TPU support
Lit-LLaMA uses lightning.Fabric under the hood, which itself supports TPUs (via PyTorch XLA).
The following commands will allow you to set up a Google Cloud
instance with a TPU v4 VM:
gcloud compute tpus tpu-vm create lit-llama --version=tpu-vm-v4-pt-2.0 --accelerator-type=v4-8 --zone=us-central2-b
gcloud compute tpus tpu-vm ssh lit-llama --zone=us-central2-b
Now that you are on the machine, clone the repository and install the dependencies:
git clone https://github.com/Lightning-AI/lit-llama
cd lit-llama
pip install -r requirements.txt
By default, computations will run using the new (and experimental) PjRT runtime. Still, it's recommended that you set the following environment variables explicitly:
export PJRT_DEVICE=TPU
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
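Before launching a long job, it can help to confirm that both variables are visible to the Python process. The following is a minimal sketch, not part of the repository: the helper name `find_env_mismatches` is hypothetical, and the expected values are simply the ones from the exports above.

```python
import os

# Variables named in the text above, with the values it recommends.
RECOMMENDED_ENV = {
    "PJRT_DEVICE": "TPU",
    "ALLOW_MULTIPLE_LIBTPU_LOAD": "1",
}


def find_env_mismatches(env=None):
    """Return {name: (expected, actual)} for variables that are unset or differ.

    Hypothetical helper for illustration; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    mismatches = {}
    for name, expected in RECOMMENDED_ENV.items():
        actual = env.get(name)  # None when the variable is not set
        if actual != expected:
            mismatches[name] = (expected, actual)
    return mismatches


if __name__ == "__main__":
    for name, (expected, actual) in find_env_mismatches().items():
        print(f"{name}: expected {expected!r}, found {actual!r}")
```

An empty result means the current shell exported both variables as recommended.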
Note: You can find an extensive guide on how to get set up, along with all the available options, here.
Since you created a new machine, you'll probably need to download the weights. You can copy them onto the machine with gcloud compute tpus tpu-vm scp, or you can follow the steps described in our downloading guide.
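As a rough illustration of the scp route, the sketch below only assembles the command line; it does not run anything. The helper name `build_scp_command` and both paths are hypothetical placeholders. The instance name and zone are the ones used when creating the VM above, and `--recurse` is assumed to be wanted because weights are usually kept in a directory.

```python
import shlex


def build_scp_command(local_path, remote_path, instance="lit-llama", zone="us-central2-b"):
    """Build the argument list for copying a local path to the TPU VM.

    Hypothetical helper for illustration; it returns the command, it does not execute it.
    """
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "scp",
        "--recurse",                  # copy directories recursively
        local_path,                   # source on your local machine
        f"{instance}:{remote_path}",  # destination on the TPU VM
        f"--zone={zone}",
    ]


if __name__ == "__main__":
    # Both paths are placeholders; adjust them to where your weights actually live.
    cmd = build_scp_command("checkpoints", "lit-llama/")
    print(shlex.join(cmd))
```

The printed line can be pasted into a local shell, or the list can be passed to `subprocess.run` once the paths have been checked.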
Inference
Generation works out-of-the-box with TPUs:
python3 generate.py --prompt "Hello, my name is" --num_samples 3
The first generation takes ~20s because XLA needs to compile the graph. Afterwards, generation times drop to ~5s.
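The gap between the first and later generations is the usual compile-once, run-many pattern. The sketch below shows one way to measure it, assuming nothing about the internals of generate.py: `make_fake_generate` is a hypothetical stand-in that sleeps to imitate a one-time compilation, and the durations are made up for the demo. To measure a real workload, you would pass your own callable to `time_calls`.

```python
import time


def time_calls(fn, repeats=3):
    """Call fn() `repeats` times and return each call's wall-clock duration in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def make_fake_generate(compile_s=0.2, run_s=0.05):
    """Hypothetical stand-in for a generation call: slow on first use, fast afterwards."""
    state = {"compiled": False}

    def generate():
        if not state["compiled"]:
            time.sleep(compile_s)  # imitates one-time graph compilation
            state["compiled"] = True
        time.sleep(run_s)  # imitates running the already-compiled graph

    return generate


if __name__ == "__main__":
    first, *rest = time_calls(make_fake_generate())
    print(f"first call: {first:.2f}s, later calls: {[round(t, 2) for t in rest]}")
```

When benchmarking, it is common to discard the first measurement for this reason and report only the steady-state times.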
Finetuning
Coming soon.
Warning: When you are done, remember to delete your instance:
gcloud compute tpus tpu-vm delete lit-llama --zone=us-central2-b