import sys

import pytest
import torch


def copy_mlp(llama_mlp, orig_llama_mlp) -> None:
    """Copy lit-llama MLP weights into the reference implementation's feed-forward module.

    Mapping (ours -> reference): ``c_fc1 -> w1``, ``c_fc2 -> w3``, ``c_proj -> w2``.
    Writes into the reference parameters in place, so callers run it under
    ``torch.no_grad()`` (as every test in this file does).
    """
    orig_llama_mlp.w1.weight.copy_(llama_mlp.c_fc1.weight)
    orig_llama_mlp.w3.weight.copy_(llama_mlp.c_fc2.weight)
    orig_llama_mlp.w2.weight.copy_(llama_mlp.c_proj.weight)


def copy_attention(llama_attn, orig_llama_attn) -> None:
    """Copy lit-llama attention weights into the reference attention module.

    lit-llama stores the query, key and value projections in one fused ``c_attn``
    weight whose rows are stacked in q, k, v order. The reference keeps them as
    separate ``wq`` / ``wk`` / ``wv`` layers. ``c_proj`` maps to ``wo``.
    """
    fused_weight = llama_attn.c_attn.weight
    n_embd = fused_weight.shape[1]
    # Unpacking into exactly three names raises ValueError unless the fused weight has
    # 3 * n_embd rows, so a layout mismatch fails loudly instead of copying a wrong slice.
    q_weight, k_weight, v_weight = fused_weight.split(n_embd, dim=0)
    orig_llama_attn.wq.weight.copy_(q_weight)
    orig_llama_attn.wk.weight.copy_(k_weight)
    orig_llama_attn.wv.weight.copy_(v_weight)
    orig_llama_attn.wo.weight.copy_(llama_attn.c_proj.weight)


def copy_block(llama_block, orig_llama_block) -> None:
    """Copy one lit-llama transformer block into the matching reference block.

    The two RMSNorm scales map as ``rms_1 -> attention_norm`` and ``rms_2 -> ffn_norm``.
    """
    orig_llama_block.attention_norm.weight.copy_(llama_block.rms_1.scale)
    copy_attention(llama_block.attn, orig_llama_block.attention)
    orig_llama_block.ffn_norm.weight.copy_(llama_block.rms_2.scale)
    copy_mlp(llama_block.mlp, orig_llama_block.feed_forward)


def copy_weights(llama_model, orig_llama_model) -> None:
    """Copy every weight of a lit-llama model into a reference LLaMA ``Transformer``.

    Covers the token embedding, each transformer block, the final norm and the
    output head. Blocks are paired with ``zip``, so any surplus layers on either
    side are left untouched.
    """
    orig_llama_model.tok_embeddings.weight.copy_(llama_model.transformer.wte.weight)
    for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers):
        copy_block(llama_block, orig_llama_block)
    orig_llama_model.norm.weight.copy_(llama_model.transformer.ln_f.scale)
    orig_llama_model.output.weight.copy_(llama_model.lm_head.weight)


@torch.no_grad()
@pytest.mark.parametrize("kv_cache", (False, True))
def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None:
    """Check lit-llama against the reference LLaMA after copying the weights across.

    Compares the token embeddings, the output of the first transformer block and
    the full-model output. When ``kv_cache`` is set, the first block is also run
    with an explicit key/value cache, and the cache it fills is compared with the
    reference block's cache.
    """
    block_size = 64
    vocab_size = 32000
    n_layer = 16
    n_head = 16
    n_embd = 32
    batch_size = 3

    llama_config = lit_llama.LLaMAConfig(
        block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd
    )
    orig_llama_config = orig_llama.ModelArgs(
        dim=n_embd,
        n_layers=n_layer,
        n_heads=n_head,
        vocab_size=vocab_size,
        norm_eps=1e-5,
        max_seq_len=block_size,
        max_batch_size=batch_size,
    )

    seq_len = orig_llama_config.max_seq_len
    token_sample = torch.randint(
        0, orig_llama_config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64
    )

    llama_model = lit_llama.LLaMA(llama_config)
    llama_model.apply(llama_model._init_weights)
    orig_llama_model = orig_llama.Transformer(orig_llama_config)

    copy_weights(llama_model, orig_llama_model)

    orig_llama_embed = orig_llama_model.tok_embeddings(token_sample)
    llama_embed = llama_model.transformer.wte(token_sample)
    assert torch.allclose(orig_llama_embed, llama_embed)

    llama_rope = llama_model.build_rope_cache(token_sample)
    llama_mask = llama_model.build_mask_cache(token_sample)
    # Additive causal mask for the reference block: -inf strictly above the diagonal, 0 elsewhere.
    orig_llama_mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
    orig_llama_mask = torch.triu(orig_llama_mask, diagonal=1)

    # The reference first block is called identically in both modes; only our side differs.
    orig_llama_block_out = orig_llama_model.layers[0](
        orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask
    )

    if kv_cache:
        theirs_k_cache = orig_llama_model.layers[0].attention.cache_k
        theirs_v_cache = orig_llama_model.layers[0].attention.cache_v

        head_size = n_embd // n_head
        kv_cache_shape = (batch_size, n_head, block_size, head_size)
        ours_kv_cache = torch.zeros(kv_cache_shape), torch.zeros(kv_cache_shape)
        (llama_block_out, ours_kv_cache) = llama_model.transformer.h[0](
            llama_embed, llama_rope, llama_mask, seq_len, torch.arange(block_size), ours_kv_cache
        )
        # Our cache is laid out as kv_cache_shape; swapping dims 1 and 2 is assumed to match
        # the reference layout (presumably (batch, seq, n_head, head_size) -- TODO confirm).
        ours_k_cache = ours_kv_cache[0].permute(0, 2, 1, 3)
        ours_v_cache = ours_kv_cache[1].permute(0, 2, 1, 3)
        torch.testing.assert_close(ours_k_cache, theirs_k_cache)
        torch.testing.assert_close(ours_v_cache, theirs_v_cache)
    else:
        (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, llama_mask, seq_len)

    assert torch.allclose(orig_llama_block_out, llama_block_out)

    expected = orig_llama_model(token_sample, 0)
    out = llama_model(token_sample)
    assert torch.allclose(out, expected)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
@torch.no_grad()
def test_bfloat16_llama_init(lit_llama, orig_llama) -> None:
    """Check that a bfloat16 CUDA copy of a model matches the float32 CPU original.

    The float32 model is built and initialised normally on CPU. The second model is
    constructed inside ``EmptyInitOnDevice(device="cuda", dtype=torch.bfloat16)`` and
    then loaded from the first model's state dict. Outputs are compared with loose
    tolerances.
    """
    # NOTE(review): the ``orig_llama`` fixture is requested but never used in this test.
    from lit_llama.utils import EmptyInitOnDevice

    block_size = 64
    vocab_size = 32000
    n_layer = 16
    n_head = 16
    n_embd = 32

    llama_config = lit_llama.LLaMAConfig(
        block_size=block_size, vocab_size=vocab_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd
    )
    llama_model = lit_llama.LLaMA(llama_config)
    llama_model.apply(llama_model._init_weights)

    batch_size = 3

    token_sample = torch.randint(0, vocab_size, size=(batch_size, block_size), dtype=torch.int64)

    expected = llama_model(token_sample)

    with EmptyInitOnDevice(device="cuda", dtype=torch.bfloat16):
        llama_model2 = lit_llama.LLaMA(llama_config)
    llama_model2.load_state_dict(llama_model.state_dict(keep_vars=True))

    out = llama_model2(token_sample.cuda()).float().cpu()
    # Tolerances are loose because the second model runs in bfloat16.
    torch.testing.assert_close(out, expected, atol=5e-3, rtol=1e-3)


def copy_adapter_weights(llama_model, orig_llama_model) -> None:
    """Copy LLaMA-Adapter parameters from the reference model into the lit-llama model.

    Note that the direction is reference -> ours, the opposite of ``copy_weights``.
    Only blocks whose attention has ``gating_factor`` / ``adapter_wte`` are touched.
    Writes in place, so callers run it under ``torch.no_grad()``.

    Raises:
        IndexError: if more of our blocks carry ``adapter_wte`` than the reference
            has adapter layers.
    """
    # copy the gating parameter
    for llama_block, orig_llama_block in zip(llama_model.transformer.h, orig_llama_model.layers):
        if hasattr(llama_block.attn, "gating_factor"):
            llama_block.attn.gating_factor.copy_(orig_llama_block.attention.gate)

    # In the original model, there is one embedding layer for all blocks combined
    orig_adapter_wte = orig_llama_model.adapter_query.weight.reshape(
        orig_llama_model.params.adapter_layer, orig_llama_model.params.adapter_len, orig_llama_model.params.dim
    )

    # In ours, the embedding layer is split across the individual attention layers:
    # the i-th adapter-carrying block receives the i-th slice of the reference embedding.
    adapter_blocks = [block for block in llama_model.transformer.h if hasattr(block.attn, "adapter_wte")]
    for index, llama_block in enumerate(adapter_blocks):
        llama_block.attn.adapter_wte.weight.copy_(orig_adapter_wte[index])


def enable_gate(model) -> None:
    """Set every adapter gate parameter of ``model`` to 1 in place.

    Matches parameters whose name contains ``"gating_factor"`` (lit-llama) or
    ``"gate"`` (reference). Must run under ``torch.no_grad()`` because it fills
    parameters in place.
    """
    for name, param in model.named_parameters():
        if "gating_factor" in name or "gate" in name:
            param.fill_(1)


@torch.no_grad()
def test_adapter_parity(orig_llama_adapter):
    """Test parity between our implementation of LLaMA-Adapter and the reference code."""
    import lit_llama.adapter as lit_llama

    orig_llama = orig_llama_adapter

    block_size = 32
    vocab_size = 100
    n_layer = 2
    n_head = 4
    n_embd = 16
    adapter_prompt_length: int = 10
    adapter_start_layer: int = 0

    llama_config = lit_llama.LLaMAConfig(
        block_size=block_size,
        vocab_size=vocab_size,
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
        adapter_prompt_length=adapter_prompt_length,
        adapter_start_layer=adapter_start_layer,
    )
    orig_llama_config = orig_llama.ModelArgs(
        dim=n_embd,
        n_layers=n_layer,
        n_heads=n_head,
        vocab_size=vocab_size,
        norm_eps=1e-5,
        max_seq_len=block_size,
        adapter_len=adapter_prompt_length,
        # The reference takes a count of adapter layers; ours takes the first adapter layer.
        adapter_layer=(n_layer - adapter_start_layer),
    )

    batch_size = 3
    token_sample = torch.randint(
        0, orig_llama_config.vocab_size, size=(batch_size, orig_llama_config.max_seq_len), dtype=torch.int64
    )

    llama_model = lit_llama.LLaMA(llama_config)
    llama_model.apply(llama_model._init_weights)
    orig_llama_model = orig_llama.Transformer(orig_llama_config)

    copy_weights(llama_model, orig_llama_model)
    copy_adapter_weights(llama_model, orig_llama_model)

    # make the gate non-zero, otherwise the adapter is disabled and the model
    # identical to regular LLaMA
    enable_gate(llama_model)
    enable_gate(orig_llama_model)

    expected = orig_llama_model(token_sample, 0)
    out = llama_model(token_sample)
    assert torch.allclose(out, expected)


@pytest.mark.skipif(sys.platform in ("win32", "darwin"), reason="torch.compile not supported on this platform")
def test_model_compile(lit_llama):
    """Smoke-test that a tiny LLaMA model can be wrapped by ``torch.compile`` and called repeatedly."""
    llama_config = lit_llama.LLaMAConfig(block_size=8, vocab_size=8, n_layer=2, n_head=2, n_embd=4)
    model = lit_llama.LLaMA(llama_config)
    model.apply(model._init_weights)

    model = torch.compile(model)

    sample = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)
    # Several calls with the same input; presumably meant to exercise reuse of the compiled graph.
    for _ in range(3):
        _ = model(sample)