Spaces:
Running
Running
#!/usr/bin/env python3
"""
Test script for quantization functionality
"""
import os | |
import sys | |
import tempfile | |
import shutil | |
from pathlib import Path | |
import logging | |
# Make the parent of this file's directory importable so that the
# `scripts` package imported below can be resolved.
project_root = Path(__file__).parent.parent
sys.path.append(os.fspath(project_root))
from scripts.model_tonic.quantize_model import ModelQuantizer | |
def test_quantization_imports():
    """Check that the third-party quantization stack can be imported.

    Imports ``torch``, the ``transformers`` model, tokenizer and
    ``TorchAoConfig`` classes, the ``torchao.quantization`` config classes
    and ``torchao.dtypes.Int4CPULayout``. Nothing is done with the imported
    names; only importability is verified.

    Returns:
        bool: ``True`` if every import succeeds, ``False`` if any of them
        raises ``ImportError`` (the error is printed).
    """
    try:
        # Imported purely to prove availability, hence the unused names.
        import torch  # noqa: F401
        from transformers import (  # noqa: F401
            AutoModelForCausalLM,
            AutoTokenizer,
            TorchAoConfig,
        )
        from torchao.quantization import (  # noqa: F401
            Int8WeightOnlyConfig,
            Int4WeightOnlyConfig,
            Int8DynamicActivationInt8WeightConfig,
        )
        from torchao.dtypes import Int4CPULayout  # noqa: F401
    except ImportError as e:
        print(f"❌ Import error: {e}")
        return False

    print("✅ All quantization imports successful")
    return True
def test_quantizer_initialization():
    """Check that a ``ModelQuantizer`` can be constructed for a local model.

    A temporary model directory containing a minimal ``config.json`` and a
    placeholder ``pytorch_model.bin`` is created, and a quantizer is
    instantiated against it with a dummy repo name and token.

    Returns:
        bool: ``True`` if construction succeeds, ``False`` if any exception
        is raised (the error is printed).
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create a dummy model directory with minimal model files
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            # Only construction is under test; the instance itself is unused.
            ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            print("✅ Quantizer initialization successful")
            return True
    except Exception as e:
        print(f"❌ Quantizer initialization failed: {e}")
        return False
def test_quantization_config_creation():
    """Check that quantization configs can be created without raising.

    Builds a ``ModelQuantizer`` over a temporary dummy model directory and
    calls ``create_quantization_config`` for the ``int8_weight_only`` and
    ``int4_weight_only`` quantization types. The returned configs are not
    inspected; the test only fails if a call raises.

    Returns:
        bool: ``True`` if both calls complete, ``False`` if any exception
        is raised (the error is printed).
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            # (label used in the message, quantization type passed to the API)
            cases = (
                ("int8", "int8_weight_only"),
                ("int4", "int4_weight_only"),
            )
            for label, quant_type in cases:
                # NOTE(review): 128 is presumably the group size -- confirm
                # against ModelQuantizer.create_quantization_config.
                quantizer.create_quantization_config(quant_type, 128)
                print(f"✅ {label} config creation successful")
            return True
    except Exception as e:
        print(f"❌ Config creation failed: {e}")
        return False
def test_model_validation():
    """Check ``ModelQuantizer.validate_model_path`` on good and bad inputs.

    Two directories are created under one temporary root: ``valid_model``
    with a minimal ``config.json`` and placeholder ``pytorch_model.bin``,
    and an empty ``invalid_model``. Validation is expected to be truthy for
    the first and falsy for the second.

    Returns:
        bool: ``True`` if both expectations hold, ``False`` if either check
        fails or any exception is raised (the reason is printed).
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Valid model: directory with the minimal files present
            model_dir = Path(temp_dir) / "valid_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if not quantizer.validate_model_path():
                print("❌ Valid model validation failed")
                return False
            print("✅ Valid model validation successful")

            # Invalid model: empty directory, required files are missing
            invalid_dir = Path(temp_dir) / "invalid_model"
            invalid_dir.mkdir()

            quantizer_invalid = ModelQuantizer(
                model_path=str(invalid_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if quantizer_invalid.validate_model_path():
                print("❌ Invalid model validation failed")
                return False
            print("✅ Invalid model validation successful")
            return True
    except Exception as e:
        print(f"❌ Model validation test failed: {e}")
        return False
def test_quantized_model_card_creation():
    """Check the text produced by ``create_quantized_model_card``.

    Builds a ``ModelQuantizer`` over a temporary dummy model directory and
    generates a model card for each quantization type with ``"test/model"``
    as the second argument. The int8 card must contain ``"int8_weight_only"``
    and ``"GPU"``; the int4 card must contain ``"int4_weight_only"`` and
    ``"CPU"``.

    Returns:
        bool: ``True`` if both cards contain the expected substrings,
        ``False`` on the first mismatch or if any exception is raised.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            # (message label, quantization type, substring the card must contain)
            cases = (
                ("int8", "int8_weight_only", "GPU"),
                ("int4", "int4_weight_only", "CPU"),
            )
            for label, quant_type, required_text in cases:
                card = quantizer.create_quantized_model_card(quant_type, "test/model")
                if quant_type in card and required_text in card:
                    print(f"✅ {label} model card creation successful")
                else:
                    print(f"❌ {label} model card creation failed")
                    return False
            return True
    except Exception as e:
        print(f"❌ Model card creation test failed: {e}")
        return False
def test_quantized_readme_creation():
    """Check the text produced by ``create_quantized_readme``.

    Builds a ``ModelQuantizer`` over a temporary dummy model directory and
    generates a README for each quantization type with ``"test/model"`` as
    the second argument. The int8 README must contain ``"int8_weight_only"``
    and ``"GPU optimized"``; the int4 README must contain
    ``"int4_weight_only"`` and ``"CPU optimized"``.

    Returns:
        bool: ``True`` if both READMEs contain the expected substrings,
        ``False`` on the first mismatch or if any exception is raised.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            # (message label, quantization type, substring the README must contain)
            cases = (
                ("int8", "int8_weight_only", "GPU optimized"),
                ("int4", "int4_weight_only", "CPU optimized"),
            )
            for label, quant_type, required_text in cases:
                readme = quantizer.create_quantized_readme(quant_type, "test/model")
                if quant_type in readme and required_text in readme:
                    print(f"✅ {label} README creation successful")
                else:
                    print(f"❌ {label} README creation failed")
                    return False
            return True
    except Exception as e:
        print(f"❌ README creation test failed: {e}")
        return False
def main():
    """Run all quantization tests and print a summary.

    Each test function is called in turn. A truthy return counts as a pass;
    a falsy return or a raised exception counts as a failure and does not
    stop the remaining tests.

    Returns:
        int: ``0`` if every test passed, ``1`` otherwise.
    """
    print("🧪 Running Quantization Tests")
    print("=" * 40)

    tests = [
        ("Import Test", test_quantization_imports),
        ("Initialization Test", test_quantizer_initialization),
        ("Config Creation Test", test_quantization_config_creation),
        ("Model Validation Test", test_model_validation),
        ("Model Card Test", test_quantized_model_card_creation),
        ("README Test", test_quantized_readme_creation),
    ]

    passed = 0
    total = len(tests)
    for test_name, test_func in tests:
        print(f"\n🔍 Running {test_name}...")
        try:
            succeeded = test_func()
        except Exception as e:
            # A crashing test is reported and counted as a failure.
            print(f"❌ {test_name} failed with exception: {e}")
            continue
        if succeeded:
            passed += 1
            print(f"✅ {test_name} passed")
        else:
            print(f"❌ {test_name} failed")

    print("\n" + "=" * 40)
    print(f"📊 Test Results: {passed}/{total} tests passed")
    if passed == total:
        print("🎉 All quantization tests passed!")
        return 0
    print("⚠️ Some tests failed. Check the output above.")
    return 1
if __name__ == "__main__":
    # Setup logging before any test runs
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    # Use sys.exit rather than the site-provided exit(): the latter is meant
    # for the interactive shell and is absent when the site module is not
    # loaded (e.g. `python -S`).
    sys.exit(main())