File size: 8,878 Bytes
40fd629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#!/usr/bin/env python3
"""
Test script for quantization functionality
"""

import os
import sys
import tempfile
import shutil
from pathlib import Path
import logging

# Add the project root to the path so the `scripts.model_tonic` import below
# resolves when this file is run directly as a script.
# project_root is the directory two levels above this file.
# NOTE(review): append() puts it LAST on sys.path, so an already-importable
# top-level `scripts` package elsewhere would shadow this project's one.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from scripts.model_tonic.quantize_model import ModelQuantizer

def test_quantization_imports():
    """Check that the torch / transformers / torchao quantization stack imports.

    Imports torch, the transformers auto classes plus ``TorchAoConfig``, the
    torchao weight-only and dynamic-activation configs, and ``Int4CPULayout``.

    Returns:
        True if every import succeeds, False if any raises ``ImportError``
        (the error is printed). Other exception types propagate to the caller.
    """
    try:
        import torch  # noqa: F401
        from transformers import (  # noqa: F401
            AutoModelForCausalLM,
            AutoTokenizer,
            TorchAoConfig,
        )
        from torchao.quantization import (  # noqa: F401
            Int8WeightOnlyConfig,
            Int4WeightOnlyConfig,
            Int8DynamicActivationInt8WeightConfig,
        )
        from torchao.dtypes import Int4CPULayout  # noqa: F401
    except ImportError as missing:
        print(f"❌ Import error: {missing}")
        return False

    print("βœ… All quantization imports successful")
    return True

def test_quantizer_initialization():
    """Check that ``ModelQuantizer`` can be constructed for a stub model dir.

    A throwaway directory holding a minimal ``config.json`` and a placeholder
    ``pytorch_model.bin`` is created and handed to ``ModelQuantizer`` together
    with a dummy repo name and token.

    Returns:
        True if construction succeeds, False if any exception is raised
        (the exception is printed).
    """
    try:
        with tempfile.TemporaryDirectory() as scratch:
            stub_dir = Path(scratch, "dummy_model")
            stub_dir.mkdir()

            # Just enough files to look like a saved model; contents are fake.
            stub_dir.joinpath("config.json").write_text('{"model_type": "test"}')
            stub_dir.joinpath("pytorch_model.bin").write_text('dummy')

            # Keep a reference so the instance lives until the function returns.
            _quantizer = ModelQuantizer(
                model_path=str(stub_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            print("βœ… Quantizer initialization successful")
            return True
    except Exception as exc:
        print(f"❌ Quantizer initialization failed: {exc}")
        return False

def test_quantization_config_creation():
    """Check that ``create_quantization_config`` builds int8 and int4 configs.

    A stub model directory is created, a ``ModelQuantizer`` is built for it,
    and a config is requested for ``int8_weight_only`` and then
    ``int4_weight_only`` (both with a group size of 128).

    Returns:
        True if both configs are created, False if either call returns
        ``None`` or any exception is raised (the reason is printed).
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token"
            )

            # (short label used in messages, quantization type under test)
            cases = (
                ("int8", "int8_weight_only"),
                ("int4", "int4_weight_only"),
            )
            for label, quant_type in cases:
                config = quantizer.create_quantization_config(quant_type, 128)
                # Without this check the result is never inspected, so a
                # failed creation that returns None instead of raising would
                # still be reported as a success.
                # NOTE(review): assumes a successful call never returns None.
                if config is None:
                    print(f"❌ {label} config creation failed")
                    return False
                print(f"βœ… {label} config creation successful")

            return True
    except Exception as e:
        print(f"❌ Config creation failed: {e}")
        return False

def test_model_validation():
    """Check ``validate_model_path`` accepts a model dir and rejects an empty one.

    Two directories are created under a single temporary root:

    * ``valid_model`` with a stub ``config.json`` and ``pytorch_model.bin``;
      validation is expected to succeed.
    * ``invalid_model`` left empty; validation is expected to fail.

    Returns:
        True if both expectations hold, False otherwise or if any exception
        is raised (the reason is printed).
    """
    try:
        with tempfile.TemporaryDirectory() as scratch:
            root = Path(scratch)

            good_dir = root / "valid_model"
            good_dir.mkdir()
            (good_dir / "config.json").write_text('{"model_type": "test"}')
            (good_dir / "pytorch_model.bin").write_text('dummy')

            good_quantizer = ModelQuantizer(
                model_path=str(good_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if not good_quantizer.validate_model_path():
                print("❌ Valid model validation failed")
                return False
            print("βœ… Valid model validation successful")

            # Deliberately left empty: no model files at all.
            bad_dir = root / "invalid_model"
            bad_dir.mkdir()

            bad_quantizer = ModelQuantizer(
                model_path=str(bad_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if bad_quantizer.validate_model_path():
                print("❌ Invalid model validation failed")
                return False
            print("βœ… Invalid model validation successful")

            return True
    except Exception as exc:
        print(f"❌ Model validation test failed: {exc}")
        return False

def test_quantized_model_card_creation():
    """Check the generated model card for int8 and int4 quantization.

    The card for ``int8_weight_only`` must mention that type and "GPU"; the
    card for ``int4_weight_only`` must mention that type and "CPU". Cases are
    checked in that order and the first mismatch stops the test.

    Returns:
        True if both cards match, False on the first mismatch or if any
        exception is raised (the reason is printed).
    """
    # (message label, quantization type, keyword the card must contain)
    expectations = (
        ("int8", "int8_weight_only", "GPU"),
        ("int4", "int4_weight_only", "CPU"),
    )
    try:
        with tempfile.TemporaryDirectory() as scratch:
            stub_dir = Path(scratch) / "dummy_model"
            stub_dir.mkdir()
            (stub_dir / "config.json").write_text('{"model_type": "test"}')
            (stub_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(stub_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            for label, quant_type, keyword in expectations:
                card = quantizer.create_quantized_model_card(quant_type, "test/model")
                if quant_type in card and keyword in card:
                    print(f"βœ… {label} model card creation successful")
                else:
                    print(f"❌ {label} model card creation failed")
                    return False

            return True
    except Exception as exc:
        print(f"❌ Model card creation test failed: {exc}")
        return False

def test_quantized_readme_creation():
    """Check the generated README for int8 and int4 quantization.

    The README for ``int8_weight_only`` must mention that type and
    "GPU optimized"; the one for ``int4_weight_only`` must mention that type
    and "CPU optimized". Cases are checked in that order and the first
    mismatch stops the test.

    Returns:
        True if both READMEs match, False on the first mismatch or if any
        exception is raised (the reason is printed).
    """
    # (message label, quantization type, phrase the README must contain)
    expectations = (
        ("int8", "int8_weight_only", "GPU optimized"),
        ("int4", "int4_weight_only", "CPU optimized"),
    )
    try:
        with tempfile.TemporaryDirectory() as scratch:
            stub_dir = Path(scratch) / "dummy_model"
            stub_dir.mkdir()
            (stub_dir / "config.json").write_text('{"model_type": "test"}')
            (stub_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(stub_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            for label, quant_type, phrase in expectations:
                readme = quantizer.create_quantized_readme(quant_type, "test/model")
                if quant_type in readme and phrase in readme:
                    print(f"βœ… {label} README creation successful")
                else:
                    print(f"❌ {label} README creation failed")
                    return False

            return True
    except Exception as exc:
        print(f"❌ README creation test failed: {exc}")
        return False

def main():
    """Run every quantization check in order and print a summary.

    Each check is a zero-argument callable whose truthy return value means
    success. An exception raised by a check is caught and reported as a
    failure of that check only; the remaining checks still run.

    Returns:
        0 if every check succeeded, 1 otherwise (usable as an exit code).
    """
    print("πŸ§ͺ Running Quantization Tests")
    print("=" * 40)

    suite = (
        ("Import Test", test_quantization_imports),
        ("Initialization Test", test_quantizer_initialization),
        ("Config Creation Test", test_quantization_config_creation),
        ("Model Validation Test", test_model_validation),
        ("Model Card Test", test_quantized_model_card_creation),
        ("README Test", test_quantized_readme_creation),
    )

    succeeded = 0
    for label, check in suite:
        print(f"\nπŸ“‹ Running {label}...")
        try:
            if not check():
                print(f"❌ {label} failed")
                continue
            succeeded += 1
            print(f"βœ… {label} passed")
        except Exception as exc:
            print(f"❌ {label} failed with exception: {exc}")

    print("\n" + "=" * 40)
    print(f"πŸ“Š Test Results: {succeeded}/{len(suite)} tests passed")

    if succeeded != len(suite):
        print("⚠️ Some tests failed. Check the output above.")
        return 1

    print("πŸŽ‰ All quantization tests passed!")
    return 0

if __name__ == "__main__":
    # Send INFO-and-above log records to stderr with timestamps.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Use sys.exit: the bare `exit` helper is injected by the `site` module
    # for interactive use and is unavailable when Python runs with -S.
    sys.exit(main())