--- title: "Optimizing LLM Performance Using Triton" format: revealjs: theme: dark transition: slide slide-number: true author: "Matej Sirovatka" date: "2025-02-22" --- ## `whoami` - My name is Matej - I'm a Master's student at Brno University of Technology - I'm currently working on distributed training at Hugging Face 🤗 ## `What is Triton?` - open-source programming language for GPU kernels by Open AI - Designed for AI/ML workloads - Simplifies GPU programming compared to CUDA ![](media/optim_scale.png){.center fig-align="center"} ## `Why Optimize with Triton?` - Simple yet effective - Less headache than CUDA - GPUs go `brrrrrrr` 🚀 - Feel cool when your kernel is faster than PyTorch 😎 ## `Example Problem: KL Divergence` - commonly used in LLMs for knowledge distillation - for probability distributions $P$ and $Q$, the Kullback-Leibler divergence is defined as: $$ D_{KL}(P \| Q) = \sum_{i} P_i \log\left(\frac{P_i}{Q_i}\right) $$ ```python import torch from torch.nn.functional import kl_div def kl_div_torch(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor: return kl_div(p, q) ``` ## `How about Triton?` ```python import triton import triton.language as tl @triton.jit def kl_div_triton( p_ptr, q_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr ): pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements p = tl.load(p_ptr + offsets, mask=mask) q = tl.load(q_ptr + offsets, mask=mask) output = p * (tl.log(p) - tl.log(q)) tl.store(output_ptr + offsets, output, mask=mask) ``` ## `How to integrate with PyTorch?` - How to use our custom kernel with PyTorch autograd? ```python import torch class KlDiv(torch.autograd.Function): @staticmethod def forward(ctx, p, q): ctx.save_for_backward(q) output = torch.empty_like(p) grid = (len(p) + 512 - 1) // 512 kl_div_triton[grid](p, q, output, len(p), BLOCK_SIZE=512) return output @staticmethod def backward(ctx, grad_output): q = ctx.saved_tensors[0] # Calculate gradients (another triton kernel) return ... ``` ## `Some benchmarks` - A KL Divergence kernel that is currently used in [Liger Kernel](https://github.com/linkedin/liger-kernel) written by @me :::: {.columns} ::: {.column width="50%"} ![](media/kl_mem.png){.center fig-align="center"} ::: ::: {.column width="50%"} ![](media/kl_speed.png){.center fig-align="center"} ::: :::: ## `Do I have to write everything?` - TLDR: No - Many cool projects already using Triton - Better Integration with PyTorch and even Hugging Face 🤗 - Liger Kernel, Unsloth AI, etc. :::: {.columns} ::: {.column width="50%"} ![](media/unsloth.png){.center fig-align="center"} ::: ::: {.column width="50%"} ![](media/liger.png){.center fig-align="center"} ::: :::: ## `So how can I use this in my LLM? 🚀` - Liger Kernel is a great example, providing examples of how to integrate with Hugging Face 🤗 Trainer ```diff - from transformers import AutoModelForCausalLM + from liger_kernel.transformers import AutoLigerKernelForCausalLM model_path = "meta-llama/Meta-Llama-3-8B-Instruct" - model = AutoModelForCausalLM.from_pretrained(model_path) + model = AutoLigerKernelForCausalLM.from_pretrained(model_path) # training/inference logic... 
## `Some benchmarks`

- A KL divergence kernel that is currently used in [Liger Kernel](https://github.com/linkedin/liger-kernel), written by @me

:::: {.columns}

::: {.column width="50%"}
![](media/kl_mem.png){.center fig-align="center"}
:::

::: {.column width="50%"}
![](media/kl_speed.png){.center fig-align="center"}
:::

::::

## `Do I have to write everything?`

- TL;DR: No
- Many cool projects are already using Triton
- Better integration with PyTorch and even Hugging Face 🤗
- Liger Kernel, Unsloth AI, etc.

:::: {.columns}

::: {.column width="50%"}
![](media/unsloth.png){.center fig-align="center"}
:::

::: {.column width="50%"}
![](media/liger.png){.center fig-align="center"}
:::

::::

## `So how can I use this in my LLM? 🚀`

- Liger Kernel is a great example, showing how to integrate custom kernels with the Hugging Face 🤗 Trainer

```diff
- from transformers import AutoModelForCausalLM
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"

- model = AutoModelForCausalLM.from_pretrained(model_path)
+ model = AutoLigerKernelForCausalLM.from_pretrained(model_path)

# training/inference logic...
```

## `Key Optimization Techniques adopted by Liger Kernel`

- Kernel fusion (see the appendix sketch at the end)
- Domain-specific optimizations
- Memory access patterns
- Preemptive memory freeing

## `Aaand some more benchmarks 🚀`

- Saving memory is key to running bigger batch sizes on smaller GPUs

:::: {.columns}

::: {.column width="50%"}
![](media/PMA.png){fig-align="center"}
:::

::: {.column width="50%"}
![](media/PMR.png){fig-align="center"}
:::

::::

## `Last benchmark, I promise...`

- But is it faster? Yes, it is!

![](media/TPS.png){fig-align="center" height=50% width=50%}

:::: {.columns}

::: {.column width="60%"}
*Attention is all you need, so I thank you for yours!* 🤗
:::

::: {.column width="40%"}
![](media/qr.png){height=25% width=25% fig-align="center"}
:::

::::
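## `Appendix: Kernel fusion sketch`

- To make the kernel fusion bullet concrete: in distillation you typically hold logits, so softmax and KL can be fused into one kernel and the intermediate probabilities never touch global memory
- A minimal sketch of the idea, assuming a whole row fits in one block (`BLOCK_SIZE >= n_cols`); this is my illustration, not Liger Kernel's actual implementation

```python
import triton
import triton.language as tl

@triton.jit
def fused_softmax_kl_triton(  # hypothetical fused kernel, for illustration
    p_logits_ptr, q_logits_ptr, loss_ptr, n_cols, BLOCK_SIZE: tl.constexpr
):
    row = tl.program_id(0)  # one program per row of logits
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols

    p_logits = tl.load(p_logits_ptr + row * n_cols + cols,
                       mask=mask, other=-float("inf"))
    q_logits = tl.load(q_logits_ptr + row * n_cols + cols,
                       mask=mask, other=-float("inf"))

    # numerically stable log-softmax for both rows, entirely in registers
    p_log = p_logits - tl.max(p_logits, axis=0)
    p_log = p_log - tl.log(tl.sum(tl.exp(p_log), axis=0))
    q_log = q_logits - tl.max(q_logits, axis=0)
    q_log = q_log - tl.log(tl.sum(tl.exp(q_log), axis=0))

    # KL reduced to one scalar per row; no intermediate tensor is materialized
    p = tl.exp(p_log)
    kl = tl.sum(tl.where(mask, p * (p_log - q_log), 0.0), axis=0)
    tl.store(loss_ptr + row, kl)
```

- An unfused version would launch a softmax kernel, write the probabilities out, and read them back in the KL kernel; fusing removes both the extra memory traffic and the extra launch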