Upload mmalaya_arch.py
Browse files- mmalaya_arch.py +3 -3
mmalaya_arch.py
CHANGED
@@ -303,15 +303,15 @@ class MMAlayaMetaForCausalLM(ABC):
|
|
303 |
stopping_criteria = KeywordsStoppingCriteria(
|
304 |
[conv.sep2],
|
305 |
tokenizer,
|
306 |
-
torch.tensor(input_ids, dtype=torch.long),
|
307 |
)
|
308 |
# 加载图像
|
309 |
image_processor = model.get_vision_tower().image_processor
|
310 |
-
image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half()
|
311 |
|
312 |
if return_tensors is not None:
|
313 |
if return_tensors == 'pt':
|
314 |
-
return torch.tensor(input_ids, dtype=torch.long), image_tensor, stopping_criteria
|
315 |
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
316 |
|
317 |
return input_ids, image_tensor, stopping_criteria
|
|
|
303 |
stopping_criteria = KeywordsStoppingCriteria(
|
304 |
[conv.sep2],
|
305 |
tokenizer,
|
306 |
+
torch.tensor(input_ids, dtype=torch.long).unsqueeze(0),
|
307 |
)
|
308 |
# 加载图像
|
309 |
image_processor = model.get_vision_tower().image_processor
|
310 |
+
image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half()
|
311 |
|
312 |
if return_tensors is not None:
|
313 |
if return_tensors == 'pt':
|
314 |
+
return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0), image_tensor, stopping_criteria
|
315 |
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
316 |
|
317 |
return input_ids, image_tensor, stopping_criteria
|