# SmolFactory / tests / test_quantization.py
# Tonic's picture
# adds sft, quantization, better readmes
# 40fd629 verified
#!/usr/bin/env python3
"""
Test script for quantization functionality
"""
import os
import sys
import tempfile
import shutil
from pathlib import Path
import logging
# Add the project root to the path.
# The root is taken to be the parent of the directory containing this file;
# it is appended to sys.path before the `scripts...` import below, presumably
# so that import resolves when this file is run directly as a script.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from scripts.model_tonic.quantize_model import ModelQuantizer
def test_quantization_imports():
    """Check that the quantization dependencies can be imported.

    Imports torch, the transformers model/tokenizer/config classes and the
    torchao quantization config and layout classes. None of the imported
    names is used; only the success of the imports matters.

    Returns:
        bool: True when every import succeeds, False when any of them raises
        ImportError (the error is printed).
    """
    try:
        # Availability probe only: the imported names are intentionally unused.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
        from torchao.quantization import (
            Int8WeightOnlyConfig,
            Int4WeightOnlyConfig,
            Int8DynamicActivationInt8WeightConfig,
        )
        from torchao.dtypes import Int4CPULayout
    except ImportError as import_error:
        print(f"❌ Import error: {import_error}")
        return False
    else:
        print("βœ… All quantization imports successful")
        return True
def test_quantizer_initialization():
    """Check that a ModelQuantizer can be constructed for a dummy model.

    A temporary directory is populated with a placeholder ``config.json`` and
    ``pytorch_model.bin`` and passed to ``ModelQuantizer`` together with a
    dummy repository name and token.

    Returns:
        bool: True when construction completes, False when any exception is
        raised along the way (the error is printed).
    """
    placeholder_files = (
        ("config.json", '{"model_type": "test"}'),
        ("pytorch_model.bin", 'dummy'),
    )
    try:
        with tempfile.TemporaryDirectory() as scratch:
            # Build the minimal on-disk layout of a model directory.
            dummy_dir = Path(scratch) / "dummy_model"
            dummy_dir.mkdir()
            for file_name, content in placeholder_files:
                dummy_dir.joinpath(file_name).write_text(content)

            # Only construction is under test; the instance is not used further.
            instance = ModelQuantizer(
                model_path=str(dummy_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            print("βœ… Quantizer initialization successful")
            return True
    except Exception as error:
        print(f"❌ Quantizer initialization failed: {error}")
        return False
def test_quantization_config_creation():
    """Check that quantization configs can be created for int8 and int4.

    Builds a ModelQuantizer over a dummy model directory and calls
    ``create_quantization_config`` once with ``"int8_weight_only"`` and once
    with ``"int4_weight_only"``, each time with ``128`` as second argument.
    The returned configs are not inspected; the test only requires that the
    calls do not raise.

    Returns:
        bool: True when both calls complete, False when any exception is
        raised (the error is printed).
    """
    # (quantization type, message printed once its config has been created)
    steps = (
        ("int8_weight_only", "βœ… int8 config creation successful"),
        ("int4_weight_only", "βœ… int4 config creation successful"),
    )
    try:
        with tempfile.TemporaryDirectory() as scratch:
            dummy_dir = Path(scratch) / "dummy_model"
            dummy_dir.mkdir()
            dummy_dir.joinpath("config.json").write_text('{"model_type": "test"}')
            dummy_dir.joinpath("pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(dummy_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            for quant_type, success_message in steps:
                quantizer.create_quantization_config(quant_type, 128)
                print(success_message)
            return True
    except Exception as error:
        print(f"❌ Config creation failed: {error}")
        return False
def test_model_validation():
    """Check ``validate_model_path`` on a populated and an empty directory.

    Two directories are created under one temporary root: ``valid_model``,
    holding a placeholder ``config.json`` and ``pytorch_model.bin``, and
    ``invalid_model``, left empty. The quantizer built over the first must
    validate truthy and the one built over the second must validate falsy.

    Returns:
        bool: True when both expectations hold, False when either fails or
        any exception is raised (the reason is printed).
    """
    try:
        with tempfile.TemporaryDirectory() as scratch:
            root = Path(scratch)

            # Case 1: a directory containing the placeholder model files.
            populated_dir = root / "valid_model"
            populated_dir.mkdir()
            populated_dir.joinpath("config.json").write_text('{"model_type": "test"}')
            populated_dir.joinpath("pytorch_model.bin").write_text('dummy')
            valid_quantizer = ModelQuantizer(
                model_path=str(populated_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if not valid_quantizer.validate_model_path():
                print("❌ Valid model validation failed")
                return False
            print("βœ… Valid model validation successful")

            # Case 2: an empty directory, i.e. the model files are missing.
            empty_dir = root / "invalid_model"
            empty_dir.mkdir()
            invalid_quantizer = ModelQuantizer(
                model_path=str(empty_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if invalid_quantizer.validate_model_path():
                print("❌ Invalid model validation failed")
                return False
            print("βœ… Invalid model validation successful")
            return True
    except Exception as error:
        print(f"❌ Model validation test failed: {error}")
        return False
def test_quantized_model_card_creation():
    """Check the text produced by ``create_quantized_model_card``.

    For ``"int8_weight_only"`` the card must contain that type name and the
    substring ``"GPU"``; for ``"int4_weight_only"`` it must contain that type
    name and the substring ``"CPU"``. Both cards are requested for the base
    model ``"test/model"``.

    Returns:
        bool: True when both cards satisfy their checks, False on the first
        failing check or on any exception (the reason is printed).
    """
    # (quantization type, required substring, success message, failure message)
    expectations = (
        (
            "int8_weight_only",
            "GPU",
            "βœ… int8 model card creation successful",
            "❌ int8 model card creation failed",
        ),
        (
            "int4_weight_only",
            "CPU",
            "βœ… int4 model card creation successful",
            "❌ int4 model card creation failed",
        ),
    )
    try:
        with tempfile.TemporaryDirectory() as scratch:
            dummy_dir = Path(scratch) / "dummy_model"
            dummy_dir.mkdir()
            dummy_dir.joinpath("config.json").write_text('{"model_type": "test"}')
            dummy_dir.joinpath("pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(dummy_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            for quant_type, marker, ok_message, fail_message in expectations:
                card = quantizer.create_quantized_model_card(quant_type, "test/model")
                if quant_type not in card or marker not in card:
                    print(fail_message)
                    return False
                print(ok_message)
            return True
    except Exception as error:
        print(f"❌ Model card creation test failed: {error}")
        return False
def test_quantized_readme_creation():
    """Check the text produced by ``create_quantized_readme``.

    For ``"int8_weight_only"`` the README must contain that type name and the
    phrase ``"GPU optimized"``; for ``"int4_weight_only"`` it must contain
    that type name and the phrase ``"CPU optimized"``. Both READMEs are
    requested for the base model ``"test/model"``.

    Returns:
        bool: True when both READMEs satisfy their checks, False on the first
        failing check or on any exception (the reason is printed).
    """
    # (quantization type, required phrase, success message, failure message)
    expectations = (
        (
            "int8_weight_only",
            "GPU optimized",
            "βœ… int8 README creation successful",
            "❌ int8 README creation failed",
        ),
        (
            "int4_weight_only",
            "CPU optimized",
            "βœ… int4 README creation successful",
            "❌ int4 README creation failed",
        ),
    )
    try:
        with tempfile.TemporaryDirectory() as scratch:
            dummy_dir = Path(scratch) / "dummy_model"
            dummy_dir.mkdir()
            dummy_dir.joinpath("config.json").write_text('{"model_type": "test"}')
            dummy_dir.joinpath("pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(dummy_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            for quant_type, phrase, ok_message, fail_message in expectations:
                readme = quantizer.create_quantized_readme(quant_type, "test/model")
                if quant_type not in readme or phrase not in readme:
                    print(fail_message)
                    return False
                print(ok_message)
            return True
    except Exception as error:
        print(f"❌ README creation test failed: {error}")
        return False
def main(tests=None):
    """Run a suite of quantization tests and print a summary.

    Each test is called with no arguments. It counts as passed when it
    returns a truthy value, and as failed when it returns a falsy value or
    raises an exception; an exception in one test does not stop the others.

    Args:
        tests: Optional iterable of ``(label, callable)`` pairs to run. When
            omitted, the six quantization tests defined in this module are
            run in their usual order.

    Returns:
        int: ``0`` when every test passed, ``1`` otherwise, so the value can
        be used directly as a process exit code.
    """
    if tests is None:
        # Default suite; resolved at call time so callers that supply their
        # own tests do not need these functions to exist.
        tests = [
            ("Import Test", test_quantization_imports),
            ("Initialization Test", test_quantizer_initialization),
            ("Config Creation Test", test_quantization_config_creation),
            ("Model Validation Test", test_model_validation),
            ("Model Card Test", test_quantized_model_card_creation),
            ("README Test", test_quantized_readme_creation),
        ]
    else:
        # Materialize so that len() works for any iterable, e.g. a generator.
        tests = list(tests)

    print("πŸ§ͺ Running Quantization Tests")
    print("=" * 40)

    passed = 0
    total = len(tests)
    for test_name, test_func in tests:
        print(f"\nπŸ“‹ Running {test_name}...")
        try:
            succeeded = test_func()
        except Exception as e:
            # A crashing test is reported as a failure, not propagated.
            print(f"❌ {test_name} failed with exception: {e}")
            continue
        if succeeded:
            passed += 1
            print(f"βœ… {test_name} passed")
        else:
            print(f"❌ {test_name} failed")

    print("\n" + "=" * 40)
    print(f"πŸ“Š Test Results: {passed}/{total} tests passed")
    if passed == total:
        print("πŸŽ‰ All quantization tests passed!")
        return 0
    print("⚠️ Some tests failed. Check the output above.")
    return 1
if __name__ == "__main__":
    # Set up logging at INFO level with a timestamped record format.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    # Use sys.exit rather than the builtin exit(): the latter is injected by
    # the site module for interactive use and is missing under `python -S`.
    # main() returns 0 or 1, which becomes the process exit status.
    sys.exit(main())