import sys

import pytest
import torch


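# Helpers that copy weights from the lit-llama model into the reference implementation
# so that the two models can be compared output-for-output.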
def copy_mlp(llama_mlp, orig_llama_mlp) -> None:
    orig_llama_mlp.w1.weight.copy_(llama_mlp.c_fc1.weight)
    orig_llama_mlp.w3.weight.copy_(llama_mlp.c_fc2.weight)
    orig_llama_mlp.w2.weight.copy_(llama_mlp.c_proj.weight)


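# lit-llama stores the query, key and value projections fused in a single c_attn matrix;
# split it back into the reference model's separate wq/wk/wv weights.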
def copy_attention(llama_attn, orig_llama_attn) -> None:
    n_embd = llama_attn.c_attn.weight.shape[1]
    orig_llama_attn.wq.weight.copy_(llama_attn.c_attn.weight[:n_embd])
    orig_llama_attn.wk.weight.copy_(llama_attn.c_attn.weight[n_embd:-n_embd])
    orig_llama_attn.wv.weight.copy_(llama_attn.c_attn.weight[-n_embd:])
    orig_llama_attn.wo.weight.copy_(llama_attn.c_proj.weight)


def copy_block(llama_block, orig_llama_block) -> None:
    orig_llama_block.attention_norm.weight.copy_(llama_block.rms_1.scale)
    copy_attention(llama_block.attn, orig_llama_block.attention)
    orig_llama_block.ffn_norm.weight.copy_(llama_block.rms_2.scale)
    copy_mlp(llama_block.mlp, orig_llama_block.feed_forward)


def copy_weights(llama_model, orig_llama_model) -> None:
    orig_llama_model.tok_embeddings.weight.copy_(llama_model.transformer.wte.weight)
    for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers):
        copy_block(llama_block, orig_llama_block)
    orig_llama_model.norm.weight.copy_(llama_model.transformer.ln_f.scale)
    orig_llama_model.output.weight.copy_(llama_model.lm_head.weight)


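# Parity test against the reference LLaMA implementation: embeddings, a single transformer
# block (with and without a KV cache), and the full forward pass must all match.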
@torch.no_grad()
@pytest.mark.parametrize("kv_cache", (False, True))
def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None:
    block_size = 64
    vocab_size = 32000
    n_layer = 16
    n_head = 16
    n_embd = 32
    batch_size = 3

    llama_config = lit_llama.LLaMAConfig(
        block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd
    )
    orig_llama_config = orig_llama.ModelArgs(
        dim=n_embd,
        n_layers=n_layer,
        n_heads=n_head,
        vocab_size=vocab_size,
        norm_eps=1e-5,
        max_seq_len=block_size,
        max_batch_size=batch_size,
    )

    seq_len = orig_llama_config.max_seq_len
    token_sample = torch.randint(0, orig_llama_config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)

    llama_model = lit_llama.LLaMA(llama_config)
    llama_model.apply(llama_model._init_weights)
    orig_llama_model = orig_llama.Transformer(orig_llama_config)

    copy_weights(llama_model, orig_llama_model)

    # The token embeddings must already match before any transformer blocks are compared.
    orig_llama_embed = orig_llama_model.tok_embeddings(token_sample)
    llama_embed = llama_model.transformer.wte(token_sample)
    assert torch.allclose(orig_llama_embed, llama_embed)

    # Compare a single transformer block, with and without a KV cache.
    llama_rope = llama_model.build_rope_cache(token_sample)
    llama_mask = llama_model.build_mask_cache(token_sample)
    orig_llama_mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
    orig_llama_mask = torch.triu(orig_llama_mask, diagonal=1)
    if kv_cache:
        orig_llama_block_out = orig_llama_model.layers[0](
            orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask
        )
        theirs_k_cache = orig_llama_model.layers[0].attention.cache_k
        theirs_v_cache = orig_llama_model.layers[0].attention.cache_v
        head_size = n_embd // n_head
        kv_cache_shape = (batch_size, n_head, block_size, head_size)
        ours_kv_cache = torch.zeros(kv_cache_shape), torch.zeros(kv_cache_shape)
        (llama_block_out, ours_kv_cache) = llama_model.transformer.h[0](
            llama_embed, llama_rope, llama_mask, seq_len, torch.arange(block_size), ours_kv_cache
        )
        # Permute our (batch, n_head, seq, head_size) cache into the reference layout before comparing.
        ours_k_cache = ours_kv_cache[0].permute(0, 2, 1, 3)
        ours_v_cache = ours_kv_cache[1].permute(0, 2, 1, 3)
        torch.testing.assert_close(ours_k_cache, theirs_k_cache)
        torch.testing.assert_close(ours_v_cache, theirs_v_cache)
    else:
        orig_llama_block_out = orig_llama_model.layers[0](
            orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask
        )
        (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, llama_mask, seq_len)
    assert torch.allclose(orig_llama_block_out, llama_block_out)

    # Finally, compare the full forward pass.
    expected = orig_llama_model(token_sample, 0)
    out = llama_model(token_sample)
    assert torch.allclose(out, expected)


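# Loading the fp32 weights into a bfloat16 copy of the model should reproduce the fp32
# outputs within a loose tolerance.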
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
@torch.no_grad()
def test_bfloat16_llama_init(lit_llama, orig_llama) -> None:
    from lit_llama.utils import EmptyInitOnDevice

    block_size = 64
    vocab_size = 32000
    n_layer = 16
    n_head = 16
    n_embd = 32

    llama_config = lit_llama.LLaMAConfig(
        block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd
    )
    llama_model = lit_llama.LLaMA(llama_config)
    llama_model.apply(llama_model._init_weights)

    batch_size = 3

    token_sample = torch.randint(0, vocab_size, size=(batch_size, block_size), dtype=torch.int64)

    expected = llama_model(token_sample)

    # Re-create the model with empty bfloat16 weights on the GPU, then load the fp32 state dict into it.
    with EmptyInitOnDevice(device="cuda", dtype=torch.bfloat16):
        llama_model2 = lit_llama.LLaMA(llama_config)
        llama_model2.load_state_dict(llama_model.state_dict(keep_vars=True))

    out = llama_model2(token_sample.cuda()).float().cpu()
    torch.testing.assert_close(out, expected, atol=5e-3, rtol=1e-3)


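# Adapter-specific weights (gating factors and adapter prompt embeddings) are copied
# separately from the base transformer weights.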
def copy_adapter_weights(llama_model, orig_llama_model) -> None:
    # Copy the per-layer gating factors for the blocks that have them.
    for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers):
        if hasattr(llama_block.attn, "gating_factor"):
            llama_block.attn.gating_factor.copy_(orig_llama_block.attention.gate)

    # The reference implementation keeps the adapter prompts for all layers in a single
    # embedding table; reshape it to (adapter_layer, adapter_len, dim) and copy it layer by layer.
    orig_adapter_wte = orig_llama_model.adapter_query.weight.reshape(
        orig_llama_model.params.adapter_layer, orig_llama_model.params.adapter_len, orig_llama_model.params.dim
    )

    index = 0
    for llama_block in llama_model.transformer.h:
        if hasattr(llama_block.attn, "adapter_wte"):
            llama_block.attn.adapter_wte.weight.copy_(orig_adapter_wte[index])
            index += 1


def enable_gate(model):
    for name, param in model.named_parameters():
        if "gating_factor" in name or "gate" in name:
            param.fill_(1)


@torch.no_grad()
def test_adapter_parity(orig_llama_adapter):
    """Test parity between our implementation of LLaMA-Adapter and the reference code."""
    import lit_llama.adapter as lit_llama

    orig_llama = orig_llama_adapter

    block_size = 32
    vocab_size = 100
    n_layer = 2
    n_head = 4
    n_embd = 16
    adapter_prompt_length: int = 10
    adapter_start_layer: int = 0

    llama_config = lit_llama.LLaMAConfig(
        block_size=block_size,
        vocab_size=vocab_size,
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        adapter_prompt_length=adapter_prompt_length,
        adapter_start_layer=adapter_start_layer,
    )
    orig_llama_config = orig_llama.ModelArgs(
        dim=n_embd,
        n_layers=n_layer,
        n_heads=n_head,
        vocab_size=vocab_size,
        norm_eps=1e-5,
        max_seq_len=block_size,
        adapter_len=adapter_prompt_length,
        adapter_layer=(n_layer - adapter_start_layer),
    )

    batch_size = 3
    token_sample = torch.randint(
        0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64
    )

    llama_model = lit_llama.LLaMA(llama_config)
    llama_model.apply(llama_model._init_weights)
    orig_llama_model = orig_llama.Transformer(orig_llama_config)

    copy_weights(llama_model, orig_llama_model)
    copy_adapter_weights(llama_model, orig_llama_model)

    # Turn the adapter gates on in both models so the adapter path actually contributes to the output.
    enable_gate(llama_model)
    enable_gate(orig_llama_model)

    expected = orig_llama_model(token_sample, 0)
    out = llama_model(token_sample)
    assert torch.allclose(out, expected)


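# Smoke test: the model should run under torch.compile without raising.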
@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="torch.compile not supported on this platform")
def test_model_compile(lit_llama):
    llama_config = lit_llama.LLaMAConfig(block_size=8, vocab_size=8, n_layer=2, n_head=2, n_embd=4)
    model = lit_llama.LLaMA(llama_config)
    model.apply(model._init_weights)

    model = torch.compile(model)

    sample = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)
    for _ in range(3):
        _ = model(sample)