File size: 1,933 Bytes
43a66d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
--- ori.openvino_backend.py	2024-03-21 12:35:49.552105914 -0700
+++ openvino_backend.py	2024-03-21 13:33:00.382049221 -0700
@@ -33,6 +33,8 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.weight_lowering import compress_weight
 
+from collections import OrderedDict
+import torch
 
 class OVWeightCompressionAlgoBackend(WeightCompressionAlgoBackend):
     def __init__(self, model: ov.Model):
@@ -123,6 +125,8 @@
     def transform_model(
         self, model: ov.Model, graph: NNCFGraph, weight_compression_parameters: Iterable[WeightCompressionParameters]
     ) -> ov.Model:
+        debug_wc = OrderedDict()
+
         for wc_params in weight_compression_parameters:
             compression_config = wc_params.compression_config
             if compression_config.mode == CompressWeightsMode.NF4:
@@ -149,6 +153,13 @@
             weight = Tensor(get_const_value(const_node))
             original_shape = weight.shape
             compressed_weight = compress_weight(weight, wc_params.reduction_axes, compression_config)
+            dkey = ".".join(const_node_name.split(".")[2:-1])
+            debug_wc[dkey] = {}
+            debug_wc[dkey]['original_shape'] = original_shape
+            debug_wc[dkey]['q_dtype'] = compression_dtype.type_name
+            debug_wc[dkey]['q_weight'] = compressed_weight.tensor.data
+            debug_wc[dkey]['q_scale'] = compressed_weight.scale.data
+            debug_wc[dkey]['q_zero_point'] = compressed_weight.zero_point.data
 
             compressed_const = opset.constant(
                 compressed_weight.tensor.data, dtype=compression_dtype, name=const_node_name
@@ -182,6 +193,7 @@
         # reset name_to_node_mapping
         self.name_to_node_mapping = None
 
+        torch.save(debug_wc, 'llama-2-chat-7b_r0.8_g128.pth')
         return model
 
     @staticmethod