Camil Ziane commited on
Commit
412e60a
1 Parent(s): 340a0f3

fix F.conv2d "slow_conv2d_cpu"

Browse files
TinyLLaVA_Factory/tinyllava/serve/app.py CHANGED
@@ -96,7 +96,7 @@ def get_response(params):
96
  # image = [load_image_from_base64(img) for img in images][0]
97
  image = images[0][0]
98
  image = image_processor(image)
99
- image = image.unsqueeze(0).to(model.device, dtype=torch.float16)
100
  num_image_tokens = getattr(model.vision_tower._vision_tower, "num_patches", 336)
101
  else:
102
  image = None
@@ -351,6 +351,8 @@ if __name__ == "__main__":
351
  load_8bit=args.load_8bit
352
  )
353
  model.to(args.device)
 
 
354
  image_processor = ImagePreprocess(image_processor, model.config)
355
  text_processor = TextPreprocess(tokenizer, args.conv_mode)
356
  demo = build_demo()
 
96
  # image = [load_image_from_base64(img) for img in images][0]
97
  image = images[0][0]
98
  image = image_processor(image)
99
+ image = image.unsqueeze(0).to(model.device, dtype=torch.float32)
100
  num_image_tokens = getattr(model.vision_tower._vision_tower, "num_patches", 336)
101
  else:
102
  image = None
 
351
  load_8bit=args.load_8bit
352
  )
353
  model.to(args.device)
354
+ model =model.to(torch.float32)
355
+
356
  image_processor = ImagePreprocess(image_processor, model.config)
357
  text_processor = TextPreprocess(tokenizer, args.conv_mode)
358
  demo = build_demo()