# streamlit_app.py
# Compute-Optimal LLM Training Estimator (Chinchilla-style)
# ---------------------------------------------------------
# Usage: `streamlit run streamlit_app.py`
# This tool helps estimate total FLOPs, steps, wall-clock time, and rough cost
# for LLM pretraining given model parameters, token budget, and hardware.
import math
import streamlit as st
st.set_page_config(page_title="LLM Compute Estimator", page_icon="🧮", layout="centered")
st.title("🧮 LLM Compute-Optimal Estimator")
st.caption("Estimate total FLOPs, wall-clock time, steps, and cost for pretraining — with a Chinchilla-style token rule.")
# --- Sidebar: assumptions ---
with st.sidebar:
    st.logo('./static/logo_light.png')
    st.header("Assumptions & Notes")
    st.markdown(
        """
**Formulas**
- **Total FLOPs** ≈ `c * N_params * N_tokens`, with default **c = 6** (≈ 2 for the forward pass + ≈ 4 for the backward pass; raise it to fold in optimizer and other overhead).
- **Compute-optimal tokens** (rule of thumb): `N_tokens ≈ k * N_params`, default **k = 20**.
- **Effective compute** = `GPU_count * (peak TFLOPs * 1e12) * efficiency`.

**Disclaimers**
- This is a *back-of-the-envelope* estimator. Real training efficiency depends on the data pipeline, parallelism strategy, sequence length, kernel fusion, optimizer, etc.
- Preset TFLOPs are **approximate** and depend on precision (FP8/BF16), sparsity, clocks, and vendor kernels.
"""
    )
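# Illustrative check of the formulas above (not used by the app): a 7B-parameter model
# at k = 20 targets ~140B tokens, so total FLOPs ≈ 6 * 7e9 * 140e9 ≈ 5.9e21. On
# 8 GPUs at 989 peak TFLOPs and 50% MFU, effective compute ≈ 8 * 989e12 * 0.5 ≈
# 4.0e15 FLOP/s, giving ≈ 5.9e21 / 4.0e15 ≈ 1.5e6 s, i.e. about 17 days.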
# --- 1) Model size & tokens ---
st.subheader("1) Model & Token Budget")
col1, col2, col3 = st.columns([1.2, 1, 1])
with col1:
    model_params_b = st.number_input("Model size (Billions of parameters)", min_value=0.05, value=4.0, step=0.5, format="%.2f")
with col2:
    c_overhead = st.number_input("c (FLOPs constant)", min_value=4.0, value=6.0, step=0.5)
with col3:
    k_tokens_per_param = st.number_input("k (tokens per param for compute-optimal)", min_value=5.0, value=20.0, step=1.0)
use_compute_optimal = st.toggle("Use compute‑optimal tokens (k × params)", value=True)
if use_compute_optimal:
    tokens_b = model_params_b * k_tokens_per_param
    st.info(f"Compute‑optimal token budget ≈ **{tokens_b:,.2f} B** (k = {k_tokens_per_param:g})")
else:
    tokens_b = st.number_input("Token budget (Billions)", min_value=1.0, value=80.0, step=5.0, format="%.2f")
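# With the defaults above (4B params, k = 20) the compute-optimal budget is
# 4 * 20 = 80B tokens, which is also the default offered for the manual input.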
# --- 2) Hardware (moved before batch to define gpu_count first) ---
st.subheader("2) Hardware")
col6, col7 = st.columns(2)
with col6:
    gpu_preset = st.selectbox(
        "GPU preset (approx peak TFLOPs per GPU)",
        (
            "Custom",
            "A100 80GB BF16 ≈ 312",
            "H100 SXM BF16 ≈ 989",
            "B200 (FP8-ish) ≈ 20000",
        ),
        index=0,
        help="Values are back-of-the-envelope. Choose 'Custom' to enter your own.",
    )
preset_map = {
    "A100 80GB BF16 ≈ 312": 312.0,
    "H100 SXM BF16 ≈ 989": 989.0,
    "B200 (FP8-ish) ≈ 20000": 20000.0,
}
with col7:
    if gpu_preset == "Custom":
        peak_tflops = st.number_input("Peak TFLOPs per GPU (approx)", min_value=10.0, value=20000.0, step=100.0)
    else:
        peak_tflops = preset_map[gpu_preset]
        st.number_input("Peak TFLOPs per GPU (approx)", value=peak_tflops, disabled=True)
col8, col9, col10 = st.columns(3)
with col8:
    gpu_count = st.number_input("GPU count", min_value=1, value=8, step=1)
with col9:
    efficiency = st.slider("Training efficiency (MFU, %)", min_value=10, max_value=95, value=50, step=1)
with col10:
    price_per_gpu_hour = st.number_input("Price per GPU·hour (USD)", min_value=0.0, value=25.0, step=1.0)
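# MFU (model FLOPs utilization) is applied as a linear discount on peak FLOPs in the
# effective-compute line below; e.g. 50% MFU on a GPU with a 989-TFLOP peak
# corresponds to ~495 TFLOPs of useful throughput.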
# --- 3) Batch & Sequence Settings (tokens_per_step computed from gpu_count) ---
st.subheader("3) Batch & Sequence Settings")
col4, col5 = st.columns(2)
with col4:
    micro_batch = st.number_input("Micro batch size per GPU", min_value=1, value=4, step=1, help="Number of sequences per GPU per optimizer step.")
with col5:
    seq_len = st.number_input("Sequence length (tokens)", min_value=128, value=2048, step=128)
tokens_per_step = micro_batch * seq_len * gpu_count
st.info(f"Tokens per optimization step ≈ {tokens_per_step:,} (with {gpu_count} GPUs)")
# --- Compute ---
N_params = model_params_b * 1e9
N_tokens = tokens_b * 1e9
c = c_overhead
# Total FLOPs (scalar)
flops_total = c * N_params * N_tokens # in FLOPs
# Effective machine compute per second
effective_flops_per_s = gpu_count * (peak_tflops * 1e12) * (efficiency / 100.0)
# Time estimate
seconds = flops_total / effective_flops_per_s if effective_flops_per_s > 0 else float('inf')
hours = seconds / 3600
days = hours / 24
# Steps
steps = N_tokens / tokens_per_step if tokens_per_step > 0 else float('inf')
# Throughput
throughput_tokens_per_s = N_tokens / seconds if seconds > 0 else float('inf')
# Cost
cost = price_per_gpu_hour * gpu_count * hours
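# Worked example with the defaults (4B params, 80B tokens, c = 6, 8 GPUs at a
# 20,000-TFLOP custom peak, 50% MFU): flops_total ≈ 1.92e21, effective compute ≈
# 8e16 FLOP/s, time ≈ 2.4e4 s ≈ 6.7 h, steps ≈ 80e9 / 65,536 ≈ 1.2e6, and
# cost ≈ $25 * 8 * 6.7 ≈ $1,300.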
# --- Display ---
st.divider()
st.subheader("Results")
colA, colB = st.columns(2)
with colA:
    st.metric("Total FLOPs", f"{flops_total:,.2e} FLOPs")
    st.metric("Effective compute", f"{effective_flops_per_s:,.2e} FLOPs/s")
    st.metric("Steps (est)", f"{0 if steps == float('inf') else steps:,.0f}")
with colB:
    st.metric("Wall‑clock time", f"{hours:,.1f} h (~{days:,.2f} d)")
    st.metric("Throughput", f"{0 if throughput_tokens_per_s == float('inf') else throughput_tokens_per_s:,.0f} tok/s")
    st.metric("Projected cost", f"${0 if cost == float('inf') else cost:,.0f}")
st.divider()
st.markdown(
f"""
**Summary**
- Params: **{model_params_b:,.2f}B** · Tokens: **{tokens_b:,.2f}B** (compute‑optimal: {use_compute_optimal})
- Constant **c = {c:g}** → Total ≈ **{flops_total:,.2e} FLOPs**.
- Hardware: **{gpu_count}× GPU**, peak **{peak_tflops:g} TFLOPs/GPU**, MFU **{efficiency}%** → Effective ≈ **{effective_flops_per_s:,.2e} FLOPs/s**.
- Time ≈ **{hours:,.1f} hours** (≈ {days:,.2f} days). Steps ≈ **{0 if steps == float('inf') else steps:,.0f}** (@ {tokens_per_step:,} tok/step).
- Rough cost ≈ **${0 if cost == float('inf') else cost:,.0f}** (@ ${price_per_gpu_hour:g}/GPU·h).
"""
)
with st.expander("What is the Chinchilla rule? Is it 1 epoch?"):
    st.markdown(
        """
**Chinchilla scaling** is a *compute‑optimal* rule of thumb: for a fixed compute budget, scale
the **training tokens** roughly in proportion to the **model parameters** (commonly ~20× tokens per parameter).
It is **not** about training for exactly one epoch. In web‑scale pretraining, datasets are often sampled with
replacement or mixed, so some data may be seen several times and other data less than once. The rule speaks to
the *total number of tokens* the model should process for the best use of compute, not to the number of dataset passes.
"""
    )
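# For scale: the original Chinchilla model had ~70B parameters and was trained on
# ~1.4T tokens, i.e. roughly 20 tokens per parameter.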
st.success("Ready. Tweak inputs on the left to explore different scenarios.")