| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
|
|
| import pytest |
| import torch |
| from transformers import AutoConfig, AutoModelForCausalLM |
|
|
| from llamafactory.model.model_utils.misc import find_expanded_modules |
|
|
|
|
# Hugging Face access token taken from the environment (None when unset);
# the gated-model test is skipped when no token is available.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_expanded_modules():
    """Check which attention projections ``find_expanded_modules`` selects.

    The Llama-3-8B-Instruct architecture is built from its config on the
    ``meta`` device, so no real weight storage is allocated. The test then
    asks for the ``q_proj``/``v_proj`` modules with ``num_layer_trainable=4``.
    The expected layer indices (7, 15, 23, 31) are evenly spaced. This
    presumably assumes a 32-layer model; confirm against the config if the
    checkpoint changes.

    Skipped unless ``HF_TOKEN`` is set, because the checkpoint is gated.
    """
    llama_config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    with torch.device("meta"):
        meta_model = AutoModelForCausalLM.from_config(llama_config)

    found = find_expanded_modules(meta_model, ["q_proj", "v_proj"], num_layer_trainable=4)

    # Layer index is the outer loop and projection name the inner one, which
    # reproduces the ordering q_proj, v_proj per layer.
    expected = [
        f"model.layers.{layer_idx}.self_attn.{proj_name}"
        for layer_idx in (7, 15, 23, 31)
        for proj_name in ("q_proj", "v_proj")
    ]
    assert found == expected
|
|