fix bfloat mismatch bug when model loading using half()

#10
Files changed (1) hide show
  1. modeling_internlm_xcomposer2.py +1 -1
modeling_internlm_xcomposer2.py CHANGED
@@ -123,7 +123,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
123
  print ('Batch Size >1 is not supported.')
124
  assert 0
125
  #print (img_embeds.shape)
126
- img_embeds = self.vision_proj(img_embeds)
127
  atts_img = torch.ones(
128
  img_embeds.size()[:-1], dtype=torch.long).to(img_embeds.device)
129
 
 
123
  print ('Batch Size >1 is not supported.')
124
  assert 0
125
  #print (img_embeds.shape)
126
+ img_embeds = self.vision_proj(img_embeds.to(self.dtype))
127
  atts_img = torch.ones(
128
  img_embeds.size()[:-1], dtype=torch.long).to(img_embeds.device)
129