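"""Unit tests for app.py's get_quantization_recipe and compress_and_upload.

Hub access (HfApi, whoami, ModelCard), model loading (AutoModelForCausalLM),
and compression (oneshot) are all mocked, so the suite runs offline and
without a GPU.
"""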
from unittest.mock import MagicMock, patch

import gradio as gr
import pytest
import torch
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier

from app import get_quantization_recipe, compress_and_upload

# Mock external dependencies for compress_and_upload
@pytest.fixture
def mock_hf_api():
    with patch('app.HfApi') as mock_api:
        mock_api_instance = mock_api.return_value
        mock_api_instance.create_repo.return_value = "https://huggingface.co/test_user/test_model-AWQ"
        yield mock_api_instance
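
# whoami is patched to return a fixed username instead of hitting the Hub.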
@pytest.fixture
def mock_whoami():
    with patch('app.whoami') as mock_whoami_func:
        mock_whoami_func.return_value = {"name": "test_user"}
        yield mock_whoami_func
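
# AutoModelForCausalLM is patched so from_pretrained returns a stub whose
# config reports a LlamaForCausalLM architecture; no weights are downloaded.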
@pytest.fixture
def mock_auto_model_for_causal_lm():
    with patch('app.AutoModelForCausalLM') as mock_model_class:
        mock_model_instance = MagicMock()
        mock_model_instance.config.architectures = ["LlamaForCausalLM"]
        mock_model_class.from_pretrained.return_value = mock_model_instance
        yield mock_model_class
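
# oneshot performs the actual compression; it is patched out entirely.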
@pytest.fixture
def mock_oneshot():
    with patch('app.oneshot') as mock_oneshot_func:
        yield mock_oneshot_func
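
# ModelCard is patched so tests can assert that a card is built and pushed.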
@pytest.fixture
def mock_model_card():
    with patch('app.ModelCard') as mock_card_class:
        mock_card_instance = MagicMock()
        mock_card_class.return_value = mock_card_instance
        yield mock_card_class
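
# A stand-in gr.OAuthToken carrying a dummy token string.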
@pytest.fixture
def mock_gr_oauth_token():
    mock_token = MagicMock(spec=gr.OAuthToken)
    mock_token.token = "test_token"
    return mock_token


# --- Test get_quantization_recipe ---
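# Each method should produce a recipe with exactly one modifier of the matching type.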
def test_get_quantization_recipe_awq():
    recipe = get_quantization_recipe("AWQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], AWQModifier)


def test_get_quantization_recipe_gptq():
    recipe = get_quantization_recipe("GPTQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)


def test_get_quantization_recipe_gptq_mistral():
    recipe = get_quantization_recipe("GPTQ", "MistralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MistralDecoderLayer"]


def test_get_quantization_recipe_gptq_mixtral():
    recipe = get_quantization_recipe("GPTQ", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MixtralDecoderLayer"]


def test_get_quantization_recipe_fp8():
    recipe = get_quantization_recipe("FP8", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert recipe[0].ignore == ["lm_head"]


def test_get_quantization_recipe_fp8_mixtral():
    recipe = get_quantization_recipe("FP8", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert "re:.*block_sparse_moe.gate" in recipe[0].ignore


def test_get_quantization_recipe_unsupported():
    with pytest.raises(ValueError, match="Unsupported quantization method: INVALID"):
        get_quantization_recipe("INVALID", "LlamaForCausalLM")

# --- Test compress_and_upload ---
def test_compress_and_upload_no_model_id(mock_gr_oauth_token):
    with pytest.raises(gr.Error, match="Please select a model from the search bar."):
        compress_and_upload("", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)


def test_compress_and_upload_no_oauth_token():
    with pytest.raises(gr.Error, match="Authentication error. Please log in to continue."):
        compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", None)
def test_compress_and_upload_success(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"
    model_type_selection = "Auto-detect (recommended)"

    result = compress_and_upload(model_id, quant_method, model_type_selection, mock_gr_oauth_token)

    mock_whoami.assert_called_once_with(token="test_token")
    # The device_map and torch_dtype should depend on CUDA availability
    if torch.cuda.is_available():
        expected_torch_dtype = torch.float16
        expected_device_map = "auto"
    else:
        expected_torch_dtype = "auto"
        expected_device_map = "cpu"
    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id,
        torch_dtype=expected_torch_dtype,
        device_map=expected_device_map,
        token="test_token",
        trust_remote_code=True,
    )
    mock_oneshot.assert_called_once()
    assert mock_oneshot.call_args[1]["model"] == mock_auto_model_for_causal_lm.from_pretrained.return_value
    assert mock_oneshot.call_args[1]["recipe"] is not None
    assert mock_oneshot.call_args[1]["output_dir"] == f"test_model-{quant_method}"

    mock_hf_api.create_repo.assert_called_once_with(
        repo_id=f"test_user/test_model-{quant_method}", exist_ok=True
    )
    mock_hf_api.upload_folder.assert_called_once_with(
        folder_path=f"test_model-{quant_method}",
        repo_id=f"test_user/test_model-{quant_method}",
        commit_message=f"Upload {quant_method} compressed model",
    )

    mock_model_card.assert_called_once()
    mock_model_card.return_value.push_to_hub.assert_called_once_with(
        f"test_user/test_model-{quant_method}", token="test_token"
    )

    assert "✅ Success!" in result
    assert "https://huggingface.co/test_user/test_model-AWQ" in result
def test_compress_and_upload_with_trust_remote_code(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"
    model_type_selection = "Auto-detect (recommended)"

    compress_and_upload(model_id, quant_method, model_type_selection, mock_gr_oauth_token)

    # The device_map and torch_dtype should depend on CUDA availability
    if torch.cuda.is_available():
        expected_torch_dtype = torch.float16
        expected_device_map = "auto"
    else:
        expected_torch_dtype = "auto"
        expected_device_map = "cpu"
    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id,
        torch_dtype=expected_torch_dtype,
        device_map=expected_device_map,
        token="test_token",
        trust_remote_code=True,
    )
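
# An empty architectures list should surface as a user-facing gr.Error.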
def test_compress_and_upload_model_no_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = []
    with pytest.raises(gr.Error, match="Could not determine model architecture."):
        compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)
def test_compress_and_upload_generic_exception(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_whoami.side_effect = Exception("Network error")
    result = compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)
    assert "❌ ERROR" in result
    assert "Network error" in result
def test_compress_and_upload_unrecognized_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = ["UnrecognizedArchitecture"]
    result = compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)
    assert "❌ ERROR" in result
    assert "AWQ quantization is only supported for LlamaForCausalLM architectures, got UnrecognizedArchitecture" in result