Ayesha-Majeed commited on
Commit
ba7cc56
·
verified ·
1 Parent(s): 28e9444

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +11 -1
binary_segmentation.py CHANGED
@@ -448,6 +448,11 @@ class BinarySegmenter:
448
  self.model = None
449
  self.transform = None
450
  self._load_model()
 
 
 
 
 
451
 
452
  def _load_model(self):
453
  """Load the specified segmentation model"""
@@ -497,7 +502,9 @@ class BinarySegmenter:
497
  self.model = AutoModelForImageSegmentation.from_pretrained(
498
  'ZhengPeng7/BiRefNet',
499
  trust_remote_code=True,
500
- cache_dir=str(self.cache_dir)
 
 
501
  )
502
 
503
  self.transform = transforms.Compose([
@@ -559,6 +566,9 @@ class BinarySegmenter:
559
 
560
  # Transform
561
  input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
 
 
 
562
 
563
  # Inference
564
  with torch.no_grad():
 
448
  self.model = None
449
  self.transform = None
450
  self._load_model()
451
+
452
+ if DEVICE == "cpu":
453
+ self.model = self.model.float()
454
+ self.model.to(DEVICE)
455
+ self.model.eval()
456
 
457
  def _load_model(self):
458
  """Load the specified segmentation model"""
 
502
  self.model = AutoModelForImageSegmentation.from_pretrained(
503
  'ZhengPeng7/BiRefNet',
504
  trust_remote_code=True,
505
+ cache_dir=str(self.cache_dir),
506
+ torch_dtype=torch.float32,
507
+ low_cpu_mem_usage=False
508
  )
509
 
510
  self.transform = transforms.Compose([
 
566
 
567
  # Transform
568
  input_tensor = self.transform(image_pil).unsqueeze(0).to(DEVICE)
569
+ if DEVICE == "cpu":
570
+ input_tensor = input_tensor.float()
571
+
572
 
573
  # Inference
574
  with torch.no_grad():