nxphi47 commited on
Commit
5d92b76
1 Parent(s): 2328811

Update multipurpose_chatbot/engines/sealmmm_engine.py

Browse files
multipurpose_chatbot/engines/sealmmm_engine.py CHANGED
@@ -215,7 +215,8 @@ class SeaLMMMv0Engine(TransformersEngine):
215
  with torch.no_grad():
216
  inputs = self.processor(prompt, images, return_tensors='pt')
217
  # inputs = {k: v.to("cuda", torch.bfloat16) for k, v in inputs.items() if v is not None}
218
- inputs = {k: v.to("cuda") for k, v in inputs.items() if v is not None}
 
219
  num_tokens = self.get_multimodal_tokens(prompt, image_paths)
220
  # non-streaming generation
221
  # output = self._model.generate(
 
215
  with torch.no_grad():
216
  inputs = self.processor(prompt, images, return_tensors='pt')
217
  # inputs = {k: v.to("cuda", torch.bfloat16) for k, v in inputs.items() if v is not None}
218
+ # model.device
219
+ inputs = {k: v.to(self._model.device) for k, v in inputs.items() if v is not None}
220
  num_tokens = self.get_multimodal_tokens(prompt, image_paths)
221
  # non-streaming generation
222
  # output = self._model.generate(