Adding `_bnb_4bit_dequantize_and_rescale` in `modeling_rwkv5.py` based on the provided documentation on GitHub

#7
Files changed (1) hide show
  1. modeling_rwkv5.py +21 -0
modeling_rwkv5.py CHANGED
@@ -852,3 +852,24 @@ class Rwkv5ForCausalLM(Rwkv5PreTrainedModel):
852
  hidden_states=rwkv_outputs.hidden_states,
853
  attentions=rwkv_outputs.attentions,
854
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852
  hidden_states=rwkv_outputs.hidden_states,
853
  attentions=rwkv_outputs.attentions,
854
  )
855
+
856
+ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
857
+ r"""
858
+ Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
859
+ be quantized again.
860
+ """
861
+ if not is_bitsandbytes_available():
862
+ raise ImportError("Please install bitsandbytes to use this method.")
863
+ import bitsandbytes as bnb
864
+
865
+ dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
866
+
867
+ dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
868
+
869
+ # re-quantize the model:
870
+ # we need to put it first on CPU then back to the device
871
+ # this will create an overhead :/
872
+ # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
873
+ # bugs with bnb
874
+ quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
875
+ setattr(target_layer, "weight", quant_weight)