zxdu20 commited on
Commit
5fc46d2
1 Parent(s): bfb1a8f

Fix embedding quantization

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +10 -5
modeling_chatglm.py CHANGED
@@ -1408,6 +1408,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1408
 
1409
  self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
1410
 
 
 
 
 
 
1411
  if quantize_embeddings:
1412
  logger.info("Applying quantization to embeddings")
1413
  self.transformer.word_embeddings = QuantizedEmbedding(
@@ -1415,11 +1420,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1415
  weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
1416
  num_embeddings=self.transformer.word_embeddings.num_embeddings,
1417
  embedding_dim=self.transformer.word_embeddings.embedding_dim,
1418
- dtype=torch.half,
1419
- empty_init=True,
1420
  device=self.transformer.word_embeddings.weight.device,
1421
  )
1422
- self.lm_head = QuantizedLinear(
1423
  weight_bit_width=bits,
1424
  weight_tensor=self.lm_head.weight.to(self.device),
1425
  bias_tensor=None,
@@ -1428,8 +1433,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1428
  bias=False,
1429
  quantized_weight=self.transformer.word_embeddings.weight,
1430
  quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
1431
- dtype=torch.half,
1432
- empty_init=True,
1433
  device=self.lm_head.weight.device,
1434
  )
1435
 
 
1408
 
1409
  self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
1410
 
1411
+ if self.device == torch.device("cpu"):
1412
+ dtype = torch.float32
1413
+ else:
1414
+ dtype = torch.half
1415
+
1416
  if quantize_embeddings:
1417
  logger.info("Applying quantization to embeddings")
1418
  self.transformer.word_embeddings = QuantizedEmbedding(
 
1420
  weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
1421
  num_embeddings=self.transformer.word_embeddings.num_embeddings,
1422
  embedding_dim=self.transformer.word_embeddings.embedding_dim,
1423
+ dtype=dtype,
1424
+ empty_init=empty_init,
1425
  device=self.transformer.word_embeddings.weight.device,
1426
  )
1427
+ self.lm_head = QuantizedLinear(
1428
  weight_bit_width=bits,
1429
  weight_tensor=self.lm_head.weight.to(self.device),
1430
  bias_tensor=None,
 
1433
  bias=False,
1434
  quantized_weight=self.transformer.word_embeddings.weight,
1435
  quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
1436
+ dtype=dtype,
1437
+ empty_init=empty_init,
1438
  device=self.lm_head.weight.device,
1439
  )
1440