fix bfloat16 mismatch bug when loading the model with half()
#10
by secularbird
modeling_internlm_xcomposer2.py
CHANGED
```diff
@@ -123,7 +123,7 @@ class InternLMXComposer2ForCausalLM(InternLM2PreTrainedModel):
             print ('Batch Size >1 is not supported.')
             assert 0
         #print (img_embeds.shape)
-        img_embeds = self.vision_proj(img_embeds)
+        img_embeds = self.vision_proj(img_embeds.to(self.dtype))
         atts_img = torch.ones(
             img_embeds.size()[:-1], dtype=torch.long).to(img_embeds.device)
```
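For context, here is a minimal sketch of the mismatch this patch addresses (an assumed reproduction: the layer dimensions and variable names below are illustrative, not the model's actual configuration). Calling `.half()` on the model casts the `vision_proj` weights to float16, while the incoming image embeddings can still be bfloat16, so the projection's matmul raises a dtype error. Casting the embeddings to the module's dtype first, as the patch does with `self.dtype`, resolves it.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Loading with .half() casts the projection weights to float16.
vision_proj = nn.Linear(1024, 4096).half().to(device)

# The image embeddings may still arrive as bfloat16.
img_embeds = torch.randn(1, 256, 1024, dtype=torch.bfloat16, device=device)

try:
    vision_proj(img_embeds)  # float16 weights vs. bfloat16 input
except RuntimeError as e:
    print(f"dtype mismatch: {e}")

# The fix: cast the embeddings to the module's dtype before projecting,
# mirroring img_embeds.to(self.dtype) in the patch.
out = vision_proj(img_embeds.to(vision_proj.weight.dtype))
print(out.dtype)  # torch.float16
```

The same cast works regardless of whether the model was loaded in float16, bfloat16, or float32, since it always follows the weights' dtype rather than hard-coding one.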