Vikramjeet Singh committed
Commit 9e3aee1
Parents: a2d3846, d1a4430

Merge pull request #24 from VikramxD/v2


V2

Former-commit-id: 62e8d7135a648288fb058a3c574167c155d48711

scripts/__pycache__/config.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
 
scripts/config.py CHANGED
@@ -6,7 +6,8 @@ DATASET_NAME= "hahminlew/kream-product-blip-captions"
 PROJECT_NAME = "Product Photography"
 PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
 CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
-SEGMENTATION_MODEL_NAME = "facebook/sam-vit-huge"
+SEGMENTATION_MODEL_NAME = "facebook/sam-vit-large"
+DETECTION_MODEL_NAME = "yolov8s"
 
 
 
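A minimal sketch of how the two constants added here are consumed (assuming ultralytics and transformers are installed; ultralytics resolves the bare stem "yolov8s" to the yolov8s.pt checkpoint and downloads it on first use):

    from ultralytics import YOLO
    from transformers import SamModel, SamProcessor
    from config import DETECTION_MODEL_NAME, SEGMENTATION_MODEL_NAME

    detector = YOLO(DETECTION_MODEL_NAME)  # "yolov8s" -> yolov8s.pt, fetched automatically
    processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)
    segmenter = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME)  # facebook/sam-vit-large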
scripts/extended_image.png ADDED
scripts/mask.png ADDED
scripts/utils.py CHANGED
@@ -2,10 +2,11 @@ import torch
 from ultralytics import YOLO
 from transformers import SamModel, SamProcessor
 import numpy as np
-from PIL import Image
-from config import SEGMENTATION_MODEL_NAME
-import cv2
-import matplotlib.pyplot as plt
+from PIL import Image, ImageOps
+from config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
+from diffusers.utils import load_image
+
+
 
 def accelerator():
     """
@@ -21,7 +22,6 @@ def accelerator():
     else:
         return "cpu"
 
-
 class ImageAugmentation:
     """
     Class for centering an image on a white background using ROI.
@@ -32,119 +32,71 @@ class ImageAugmentation:
         roi_scale (float): Scale factor to determine the size of the region of interest (ROI) in the original image.
     """
 
-    def __init__(self, target_width, target_height, roi_scale=0.5):
-        """
-        Initialize ImageAugmentation class.
-
-        Args:
-            target_width (int): Desired width of the extended image.
-            target_height (int): Desired height of the extended image.
-            roi_scale (float): Scale factor to determine the size of the region of interest (ROI) in the original image.
-        """
+    def __init__(self, target_width, target_height, roi_scale=0.6):
         self.target_width = target_width
         self.target_height = target_height
         self.roi_scale = roi_scale
 
-    def extend_image(self, image_path):
-        """
-        Extends the given image to the specified target dimensions while maintaining the aspect ratio of the original image.
-        The image is centered based on the detected region of interest (ROI).
-
-        Args:
-            image_path (str): The path to the image file.
-
-        Returns:
-            PIL.Image.Image: The extended image with the specified dimensions.
-        """
-        # Open the original image
-        original_image = cv2.imread(image_path)
-
-        # Convert the image to grayscale for better edge detection
-        gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
-
-        # Perform edge detection to find contours
-        edges = cv2.Canny(gray_image, 50, 150)
-        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-        # Find the largest contour (assumed to be the ROI)
-        largest_contour = max(contours, key=cv2.contourArea)
-
-        # Get the bounding box of the largest contour
-        x, y, w, h = cv2.boundingRect(largest_contour)
-
-        # Calculate the center of the bounding box
-        roi_center_x = x + w // 2
-        roi_center_y = y + h // 2
-
-        # Calculate the top-left coordinates of the ROI
-        roi_x = max(0, roi_center_x - self.target_width // 2)
-        roi_y = max(0, roi_center_y - self.target_height // 2)
-
-        # Crop the ROI from the original image
-        roi = original_image[roi_y:roi_y+self.target_height, roi_x:roi_x+self.target_width]
-
-        # Create a new white background image with the target dimensions
-        extended_image = np.ones((self.target_height, self.target_width, 3), dtype=np.uint8) * 255
-
-        # Calculate the paste position for centering the ROI
-        paste_x = (self.target_width - roi.shape[1]) // 2
-        paste_y = (self.target_height - roi.shape[0]) // 2
-
-        # Paste the ROI onto the white background
-        extended_image[paste_y:paste_y+roi.shape[0], paste_x:paste_x+roi.shape[1]] = roi
-
-        return Image.fromarray(cv2.cvtColor(extended_image, cv2.COLOR_BGR2RGB))
-
-    def generate_bbox(self, image):
+    def extend_image(self, image: Image.Image) -> Image.Image:
         """
-        Generate bounding box for the input image.
+        Extends an image to fit within the specified target dimensions while maintaining the aspect ratio.
+        """
+        # Fit the image inside the target canvas, shrink it further by
+        # roi_scale, and paste it centered on a white background.
+        original_width, original_height = image.size
+        scale = min(self.target_width / original_width, self.target_height / original_height)
+        new_width = int(original_width * scale * self.roi_scale)
+        new_height = int(original_height * scale * self.roi_scale)
+        resized_image = image.resize((new_width, new_height))
+        extended_image = Image.new("RGB", (self.target_width, self.target_height), "white")
+        paste_x = (self.target_width - new_width) // 2
+        paste_y = (self.target_height - new_height) // 2
+        extended_image.paste(resized_image, (paste_x, paste_y))
+        return extended_image
+
+    def generate_mask_from_bbox(self, image: Image.Image, segmentation_model: str, detection_model: str) -> Image.Image:
+        """
+        Generates a mask from the bounding box of an image using YOLO and SAM-ViT models.
 
         Args:
-            image: The input image.
+            image (PIL.Image.Image): The input image.
+            segmentation_model (str): Hugging Face id of the SAM checkpoint.
+            detection_model (str): Name of the YOLO detection checkpoint.
 
         Returns:
-            list: Bounding box coordinates [x_min, y_min, x_max, y_max].
+            PIL.Image.Image: The generated mask.
         """
-        model = YOLO("yolov8s.pt")
-        results = model(image)
-        bbox = results[0].boxes.xyxy.tolist()
-        return bbox
-
-    def generate_mask(self, image, bbox):
-        """
-        Generates masks for the given image using a segmentation model.
-
-        Args:
-            image: The input image for which masks need to be generated.
-            bbox: Bounding box coordinates [x_min, y_min, x_max, y_max].
-
-        Returns:
-            numpy.ndarray: The generated mask.
-        """
-        model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME).to(device=accelerator())
-        processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)
-
-        # Ensure bbox is in the correct format
-        bbox_list = [bbox]  # Convert bbox to list of lists
-
-        # Pass bbox as a list of lists to SamProcessor
-        inputs = processor(image, input_boxes=bbox_list, return_tensors="pt").to(device=accelerator())
-        with torch.no_grad():
-            outputs = model(**inputs)
-        masks = processor.image_processor.post_process_masks(
-            outputs.pred_masks,
-            inputs["original_sizes"],
-            inputs["reshaped_input_sizes"],
-        )
-
-        return masks[0].cpu().numpy()
+        yolo = YOLO(detection_model)
+        processor = SamProcessor.from_pretrained(segmentation_model)
+        model = SamModel.from_pretrained(segmentation_model).to(device=accelerator())
+        # Detect the product and keep the first box; SamProcessor expects
+        # boxes nested as (batch, num_boxes, 4).
+        results = yolo(image)
+        bboxes = results[0].boxes.xyxy.tolist()
+        input_boxes = [[bboxes[0]]]
+        inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to(device=accelerator())
+        with torch.no_grad():
+            outputs = model(**inputs)
+        mask = processor.image_processor.post_process_masks(
+            outputs.pred_masks.cpu(),
+            inputs["original_sizes"].cpu(),
+            inputs["reshaped_input_sizes"].cpu(),
+        )[0][0][0].numpy()
+        # Convert the boolean mask to 8-bit so Pillow can encode it.
+        mask_image = Image.fromarray((mask * 255).astype(np.uint8))
+        return mask_image
+
+    def invert_mask(self, mask_image: Image.Image) -> Image.Image:
+        """
+        Inverts the given mask image.
+        """
+        inverted_mask_pil = ImageOps.invert(mask_image.convert("L"))
+        return inverted_mask_pil
 
 if __name__ == "__main__":
-    augmenter = ImageAugmentation(target_width=1920, target_height=1080, roi_scale=0.3)
-    image_path = "/home/product_diffusion_api/sample_data/example1.jpg"
-    extended_image = augmenter.extend_image(image_path)
-    bbox = augmenter.generate_bbox(extended_image)
-    mask = augmenter.generate_mask(extended_image, bbox)
-    plt.imsave('mask.jpg', mask)
-    #Image.fromarray(mask).save("centered_image_with_mask.jpg")
+    augmenter = ImageAugmentation(target_width=2560, target_height=1440, roi_scale=0.7)
+    image_path = "/home/product_diffusion_api/sample_data/example3.jpg"
+    image = Image.open(image_path)
+    extended_image = augmenter.extend_image(image)
+    mask = augmenter.generate_mask_from_bbox(extended_image, SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME)
+    inverted_mask_image = augmenter.invert_mask(mask)
+    mask.save("mask.jpg")
+    inverted_mask_image.save("inverted_mask.jpg")
 
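Two details of the new implementation are worth spelling out. First, the extend_image geometry: with target_width=2560, target_height=1440, and roi_scale=0.7, a 1024x1024 input scales by min(2560/1024, 1440/1024) = 1.40625 and then by 0.7, so a 1008x1008 product lands centered on the white canvas. Second, the [0][0][0] indexing after post_process_masks: per the transformers SAM API, the call returns one boolean tensor per input image, shaped (num_boxes, num_mask_proposals, H, W). A minimal sketch of the shapes for one image and one YOLO box (the 256x256 low-resolution size is SAM's standard mask-head output and is assumed here):

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),              # (1, 1, 3, 256, 256): image, box, proposal, h, w
        inputs["original_sizes"].cpu(),        # original (H, W) of each input image
        inputs["reshaped_input_sizes"].cpu(),  # (H, W) each image was resized to for the model
    )
    per_image = masks[0]          # boolean tensor of shape (1, 3, H, W)
    first_mask = per_image[0][0]  # first proposal for the first box -> (H, W)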
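The script stops at writing mask.jpg and inverted_mask.jpg to disk; the diff does not show their consumer. Given the new diffusers import and the repository context (product_diffusion_api), a plausible next step is feeding the extended image and the inverted mask into an inpainting pipeline to repaint the background around the product. The sketch below is an assumption, not code from this commit, and the checkpoint name is illustrative:

    import torch
    from diffusers import StableDiffusionInpaintPipeline
    from PIL import Image

    # Hypothetical downstream use: regenerate the white background around the
    # product. White pixels in mask_image are repainted, black pixels are kept.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",  # illustrative checkpoint
        torch_dtype=torch.float16,
    ).to("cuda")
    image = Image.open("extended_image.png").convert("RGB").resize((512, 512))
    mask = Image.open("inverted_mask.jpg").convert("L").resize((512, 512))
    result = pipe(
        prompt="professional product photograph, studio lighting",
        image=image,
        mask_image=mask,
    ).images[0]
    result.save("inpainted.png")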