VikramSingh178 commited on
Commit
88e9206
1 Parent(s): a76141d

refactor: Update import statement for accelerator and image augmentation functionality

Browse files
scripts/__pycache__/config.cpython-312.pyc ADDED
Binary file (3.22 kB). View file
 
scripts/utils.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
  from ultralytics import YOLO
3
- from transformers import SamModel,SamProcessor
4
  import numpy as np
5
- from PIL import Image
6
  from config import SEGMENTATION_MODEL_NAME
7
 
8
 
@@ -14,15 +14,14 @@ def accelerator():
14
  str: The name of the device accelerator ('cuda', 'mps', or 'cpu').
15
  """
16
  if torch.cuda.is_available():
17
- device = 'cuda'
18
  elif torch.backends.mps.is_available():
19
- device = 'mps'
20
  else:
21
- device = 'cpu'
22
  return device
23
 
24
 
25
-
26
  class ImageAugmentation:
27
  """
28
  Class for centering an image on a white background using ROI.
@@ -54,7 +53,10 @@ class ImageAugmentation:
54
  w, h = self.background_size
55
  bg = np.ones((h, w, 3), dtype=np.uint8) * 255 # White background
56
  x, y, roi_w, roi_h = roi
57
- bg[(h - roi_h) // 2:(h - roi_h) // 2 + roi_h, (w - roi_w) // 2:(w - roi_w) // 2 + roi_w] = image
 
 
 
58
  return bg
59
 
60
  def detect_region_of_interest(self, image):
@@ -69,11 +71,12 @@ class ImageAugmentation:
69
  """
70
  # Convert image to grayscale
71
  grayscale_image = np.array(Image.fromarray(image).convert("L"))
72
-
73
  # Calculate bounding box of non-zero region
74
  bbox = Image.fromarray(grayscale_image).getbbox()
75
  return bbox
76
 
 
77
  def generate_bbox(image):
78
  """
79
  Generate bounding box for the input image.
@@ -85,17 +88,39 @@ def generate_bbox(image):
85
  tuple: Bounding box coordinates (x, y, width, height).
86
  """
87
  # Load YOLOv5 model
88
- model = YOLO("yolov8s.pt")
89
  results = model(image)
90
  # Get bounding box coordinates
91
  bbox = results[0].boxes.xyxy.int().tolist()
92
  return bbox
93
 
94
- def generate_mask():
95
- model = SamModel.from_pretrained("SEGMENTATION_MODEL_NAMEz")
96
- processor = SamProcessor.from_pretrained("SEGMENTATION_MODEL_NAME")
97
-
98
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
101
  if __name__ == "__main__":
@@ -104,8 +129,7 @@ if __name__ == "__main__":
104
  image = np.array(Image.open(image_path).convert("RGB"))
105
  roi = augmenter.detect_region_of_interest(image)
106
  centered_image = augmenter.center_image_on_background(image, roi)
107
- bbox = generate_bbox(centered_image)
108
- print(bbox)
109
-
110
-
111
-
 
1
  import torch
2
  from ultralytics import YOLO
3
+ from transformers import SamModel, SamProcessor
4
  import numpy as np
5
+ from PIL import Image
6
  from config import SEGMENTATION_MODEL_NAME
7
 
8
 
 
14
  str: The name of the device accelerator ('cuda', 'mps', or 'cpu').
15
  """
16
  if torch.cuda.is_available():
17
+ device = "cuda"
18
  elif torch.backends.mps.is_available():
19
+ device = "mps"
20
  else:
21
+ device = "cpu"
22
  return device
23
 
24
 
 
25
  class ImageAugmentation:
26
  """
27
  Class for centering an image on a white background using ROI.
 
53
  w, h = self.background_size
54
  bg = np.ones((h, w, 3), dtype=np.uint8) * 255 # White background
55
  x, y, roi_w, roi_h = roi
56
+ bg[
57
+ (h - roi_h) // 2 : (h - roi_h) // 2 + roi_h,
58
+ (w - roi_w) // 2 : (w - roi_w) // 2 + roi_w,
59
+ ] = image
60
  return bg
61
 
62
  def detect_region_of_interest(self, image):
 
71
  """
72
  # Convert image to grayscale
73
  grayscale_image = np.array(Image.fromarray(image).convert("L"))
74
+
75
  # Calculate bounding box of non-zero region
76
  bbox = Image.fromarray(grayscale_image).getbbox()
77
  return bbox
78
 
79
+
80
  def generate_bbox(image):
81
  """
82
  Generate bounding box for the input image.
 
88
  tuple: Bounding box coordinates (x, y, width, height).
89
  """
90
  # Load YOLOv5 model
91
+ model = YOLO("../models/yolov8s.pt")
92
  results = model(image)
93
  # Get bounding box coordinates
94
  bbox = results[0].boxes.xyxy.int().tolist()
95
  return bbox
96
 
97
+
98
+ def generate_mask(image):
99
+ """
100
+ Generates masks for the given image using a segmentation model.
101
+
102
+ Args:
103
+ image: The input image for which masks need to be generated.
104
+
105
+ Returns:
106
+ masks: A tensor containing the generated masks.
107
+
108
+ Raises:
109
+ None
110
+ """
111
+ model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME).to(device=accelerator())
112
+ processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)
113
+ inputs = processor(
114
+ image, input_boxes=[generate_bbox(image)], return_tensors="pt"
115
+ ).to(torch.float)
116
+ inputs.to(device=accelerator())
117
+ outputs = model(**inputs)
118
+ mask = processor.image_processor.post_process_masks(
119
+ outputs.pred_masks.cpu(),
120
+ inputs["original_sizes"].cpu(),
121
+ inputs["reshaped_input_sizes"].cpu(),
122
+ )
123
+ return mask
124
 
125
 
126
  if __name__ == "__main__":
 
129
  image = np.array(Image.open(image_path).convert("RGB"))
130
  roi = augmenter.detect_region_of_interest(image)
131
  centered_image = augmenter.center_image_on_background(image, roi)
132
+ masks = generate_mask(Image.fromarray(centered_image))
133
+ masks = np.array(masks)
134
+ mask_image = Image.fromarray(masks[0])
135
+ mask_image.save("mask.jpg")