JaydeepR committed on
Commit
2dc9748
·
verified ·
1 Parent(s): 9b33231

Create segmentation_model

Browse files
Files changed (1) hide show
  1. segmentation_model +296 -0
segmentation_model ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from torchvision.models.detection import maskrcnn_resnet50_fpn
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import uuid
8
+ import os
9
+ import cv2
10
+ import json
11
+
12
+
13
# Directory that receives a copy of every original input image.
input_images_dir = 'data/input_images/'
os.makedirs(input_images_dir, exist_ok=True)

# Directory that receives the per-object cut-outs and their metadata JSON.
segmented_objects_dir = 'data/segmented_objects/'
os.makedirs(segmented_objects_dir, exist_ok=True)
17
+
18
+ #Loading the model
19
+
20
def load_model():
    """Return a pretrained Mask R-CNN (ResNet-50 FPN) switched to eval mode.

    The network is loaded with pretrained weights and is only used for
    prediction here, so it is put in evaluation mode to deactivate dropout
    layers and other training-specific behaviors.

    Returns:
        The model instance returned by ``maskrcnn_resnet50_fpn``.
    """
    # Earlier experiment with a different backbone, kept for reference:
    # model = maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, backbone_name='resnext50_32x4d')
    # NOTE(review): newer torchvision releases prefer the ``weights=``
    # argument over ``pretrained=True`` - confirm against the installed version.
    # ``eval()`` returns the module itself, so the call can be chained.
    return maskrcnn_resnet50_fpn(pretrained=True).eval()


# Module-level instance, built once when this module is imported.
model = load_model()
34
+
35
+
36
def transform_image(image):
    """Turn a PIL image into a batched float tensor for the Mask R-CNN model.

    Args:
        image (PIL.Image.Image): Input image in any PIL mode. It is converted
            to RGB first, so grayscale, palette and RGBA inputs are accepted.

    Returns:
        torch.Tensor: Tensor of shape ``[1, 3, 256, 256]`` (batch, channels,
        height, width) with values in the 0-1 range.
    """
    # Force exactly three channels; the detection model is fed RGB tensors.
    rgb_image = image.convert("RGB")

    transform = T.Compose([
        T.Resize((256, 256)),  # Fixed working resolution for inference
        T.ToTensor(),          # PIL image -> float tensor [C, H, W] in 0-1
        # Deliberately no T.Normalize: torchvision detection models expect
        # inputs in the 0-1 range and apply their own mean/std normalization
        # internally, so normalizing here would shift the inputs twice.
    ])
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension: [1, C, H, W]
43
+
44
+
45
+ # # Test image transformation
46
+ # image_path = "D:\multiobject.jpeg" # Replace with the path to your image
47
+ # image_tensor = transform_image(image_path)
48
+
49
def run_inference(model, image_tensor):
    """Call ``model`` on ``image_tensor`` with gradient tracking disabled.

    Args:
        model: Callable (typically a torch module) to evaluate.
        image_tensor: Input forwarded unchanged to ``model``.

    Returns:
        Whatever ``model(image_tensor)`` returns.
    """
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        return model(image_tensor)
53
+
54
def extract_object(image, mask):
    """Cut one object out of ``image`` by blanking every pixel outside ``mask``.

    The mask is resized (nearest neighbour) to the image size and multiplied
    into every channel, so pixels where the mask is 0 become 0.

    Args:
        image: PIL image (or anything ``np.array`` converts to an ``H x W`` or
            ``H x W x C`` array).
        mask: 2-D numpy array; presumably 0/1-valued, since pixel values are
            multiplied by it - confirm against the caller. Boolean masks are
            accepted and converted to ``uint8``.

    Returns:
        PIL.Image.Image: Image of the same size and channel count as the
        input, containing only the masked object.
    """
    img_np = np.array(image)

    mask_np = np.asarray(mask)
    if mask_np.dtype == np.bool_:
        # cv2.resize does not accept boolean arrays.
        mask_np = mask_np.astype(np.uint8)

    # Resize mask to match image dimensions (cv2 expects (width, height)).
    height, width = img_np.shape[:2]
    mask_resized = cv2.resize(mask_np, (width, height), interpolation=cv2.INTER_NEAREST)

    # Broadcast the mask over however many channels the image has instead of
    # assuming exactly three: RGBA keeps its alpha inside the object and
    # single-channel (grayscale) images no longer fail.
    if img_np.ndim == 3:
        mask_resized = mask_resized[:, :, np.newaxis]

    # Cast back to the image dtype, matching what per-channel assignment
    # into a zeros_like(img_np) buffer would store.
    object_img = (img_np * mask_resized).astype(img_np.dtype)

    return Image.fromarray(object_img)
68
+
69
+ # def extract_object(image, mask):
70
+ # object_img = Image.fromarray((np.array(image) * mask[:, :, None]).astype(np.uint8))
71
+ # return object_img
72
+
73
+ # Save the input image
74
def save_input_image(image, master_id):
    """Save the original input image as ``<master_id>.png``.

    Args:
        image: Object exposing ``save(path)`` (a PIL image in this module).
        master_id: Identifier of the original image; used as the file name.

    Returns:
        str: Path of the written file inside ``input_images_dir``.
    """
    file_name = f'{master_id}.png'
    destination = os.path.join(input_images_dir, file_name)
    image.save(destination)
    return destination
78
+
79
+ # Save the extracted objects and their metadata
80
def save_objects_and_metadata(extracted_objects, master_id):
    """Save each extracted object image and write a metadata JSON for them.

    Every object is stored as ``<uuid4>.png`` in ``segmented_objects_dir``.
    The collected records are then dumped to ``<master_id>_metadata.json``
    in that same directory.

    Args:
        extracted_objects: Iterable of objects exposing ``save(path)``
            (PIL images in this module).
        master_id: Identifier of the original image the objects came from.

    Returns:
        list[dict]: One record per object with the keys ``object_id``,
        ``master_id`` and ``object_image_path``.
    """
    records = []

    for cutout in extracted_objects:
        # A fresh UUID per object keeps file names unique across runs.
        new_id = str(uuid.uuid4())
        cutout_path = os.path.join(segmented_objects_dir, f'{new_id}.png')
        cutout.save(cutout_path)

        records.append({
            'object_id': new_id,
            'master_id': master_id,
            'object_image_path': cutout_path,
        })

    json_path = os.path.join(segmented_objects_dir, f'{master_id}_metadata.json')
    with open(json_path, 'w') as handle:
        json.dump(records, handle, indent=4)

    return records
100
+ # Run inference
101
+ #print(outputs) # This will print the model's output, including masks, labels, and scores
102
+
103
+
104
+ # def extract_objects(image, masks):
105
+ # """
106
+ # Extract objects from the segmented image using masks.
107
+
108
+ # Args:
109
+ # - image (PIL.Image): The original image.
110
+ # - masks (Tensor): Masks obtained from the segmentation model.
111
+
112
+ # Returns:
113
+ # - List of extracted objects as images.
114
+ # """
115
+ # image_np = np.array(image)
116
+ # extracted_objects = []
117
+
118
+ # for i, mask in enumerate(masks):
119
+ # # Convert mask to binary
120
+ # binary_mask = mask[0].mul(255).byte().cpu().numpy()
121
+
122
+ # # Extract object using the mask
123
+ # masked_image = cv2.bitwise_and(image_np, image_np, mask=binary_mask)
124
+
125
+ # # Find the bounding box of the object
126
+ # x, y, w, h = cv2.boundingRect(binary_mask)
127
+ # cropped_object = masked_image[y:y+h, x:x+w]
128
+
129
+ # # Convert cropped object back to PIL Image
130
+ # cropped_object_pil = Image.fromarray(cropped_object)
131
+ # extracted_objects.append(cropped_object_pil)
132
+
133
+ # return extracted_objects
134
+
135
+ # import os
136
+ # import uuid
137
+ # from PIL import Image
138
+ # import json
139
+
140
+ # # Directories to save the input images and segmented objects
141
+ # input_images_dir = 'data/input_images/'
142
+ # segmented_objects_dir = 'data/segmented_objects/'
143
+ # os.makedirs(input_images_dir, exist_ok=True)
144
+ # os.makedirs(segmented_objects_dir, exist_ok=True)
145
+
146
+ # def save_input_image(image, master_id):
147
+ # """
148
+ # Save the original input image with a unique master ID.
149
+
150
+ # Args:
151
+ # - image (PIL.Image): The original input image.
152
+ # - master_id (str): Unique ID for the original image.
153
+
154
+ # Returns:
155
+ # - str: Path to the saved input image.
156
+ # """
157
+ # input_image_path = os.path.join(input_images_dir, f'{master_id}.png')
158
+ # image.save(input_image_path)
159
+ # return input_image_path
160
+
161
+ # def save_objects_and_metadata(extracted_objects, master_id):
162
+ # """
163
+ # Save the extracted objects as images and store their metadata.
164
+
165
+ # Args:
166
+ # - extracted_objects (List[PIL.Image]): List of extracted objects as images.
167
+ # - master_id (str): Unique ID for the original image.
168
+
169
+ # Returns:
170
+ # - List of metadata dictionaries for each object.
171
+ # """
172
+ # object_metadata = []
173
+
174
+ # for i, obj_img in enumerate(extracted_objects):
175
+ # # Generate a unique ID for each object
176
+ # object_id = str(uuid.uuid4())
177
+
178
+ # # Save the object image
179
+ # object_image_path = os.path.join(segmented_objects_dir, f'{object_id}.png')
180
+ # obj_img.save(object_image_path)
181
+
182
+ # # Prepare metadata for the object
183
+ # metadata = {
184
+ # 'object_id': object_id,
185
+ # 'master_id': master_id,
186
+ # 'object_image_path': object_image_path
187
+ # }
188
+ # object_metadata.append(metadata)
189
+
190
+ # # Save metadata to JSON (or you can save to a database)
191
+ # metadata_file = os.path.join(segmented_objects_dir, f'{master_id}_metadata.json')
192
+ # with open(metadata_file, 'w') as f:
193
+ # json.dump(object_metadata, f, indent=4)
194
+
195
+ # return object_metadata
196
+
197
+ # # Example usage
198
+ # master_id = str(uuid.uuid4()) # Generate a unique master ID for the original image
199
+
200
+ # # Save the input image
201
+ # input_image_path = save_input_image(image, master_id)
202
+
203
+ # # Save the objects and their metadata
204
+ # metadata = save_objects_and_metadata(extracted_objects, master_id)
205
+
206
+
207
+
208
+
209
+
210
+
211
+ # import cv2
212
+ # import os
213
+ # import json
214
+ # import uuid
215
+ # import numpy as np
216
+ # from PIL import Image
217
+
218
+ # # Directories to save the segmented objects and metadata
219
+ # segmented_objects_dir = 'data/segmented_objects/'
220
+ # metadata_file = 'data/segmented_objects_metadata.json'
221
+
222
+ # # Ensure directories exist
223
+ # os.makedirs(segmented_objects_dir, exist_ok=True)
224
+
225
+ # def extract_objects(image_path, masks, master_id):
226
+ # # Load the original image
227
+ # image = Image.open(image_path)
228
+ # image_np = np.array(image)
229
+
230
+ # object_metadata = []
231
+
232
+ # for i, mask in enumerate(masks):
233
+ # # Generate a unique ID for each object
234
+ # object_id = str(uuid.uuid4())
235
+
236
+ # # Extract object using the mask
237
+ # masked_image = cv2.bitwise_and(image_np, image_np, mask=mask)
238
+
239
+ # # Find the bounding box of the object
240
+ # x, y, w, h = cv2.boundingRect(mask)
241
+ # cropped_object = masked_image[y:y+h, x:x+w]
242
+
243
+ # # Save the object image
244
+ # object_image_path = os.path.join(segmented_objects_dir, f'{object_id}.png')
245
+ # cv2.imwrite(object_image_path, cropped_object)
246
+
247
+ # # Save metadata
248
+ # object_metadata.append({
249
+ # 'object_id': object_id,
250
+ # 'master_id': master_id,
251
+ # 'object_image_path': object_image_path,
252
+ # 'bounding_box': (x, y, w, h)
253
+ # })
254
+
255
+ # # Save metadata to JSON
256
+ # with open(metadata_file, 'w') as f:
257
+ # json.dump(object_metadata, f, indent=4)
258
+
259
+ # return object_metadata
260
+
261
+ # # Example usage:
262
+ # # Assuming `masks` is a list of binary masks (numpy arrays) from your segmentation model
263
+ # # and `image_path` is the path to the original image
264
+ # master_id = str(uuid.uuid4())
265
+ # image_path = 'data/input_images/sample_image.png'
266
+ # masks = [...] # Replace with actual masks
267
+
268
+ # object_metadata = extract_objects(image_path, masks, master_id)
269
+
270
+
271
+ # #Extracting and saving segmented objects
272
+ # # def save_segmented_objects(image_path, outputs, output_dir='data\segmented_objects'):
273
+ # # image = Image.open(image_path).convert("RGB")
274
+ # # image_np = np.array(image)
275
+ # # masks = outputs[0]['masks']
276
+ # # scores = outputs[0]['scores']
277
+
278
+ # # if not os.path.exists(output_dir):
279
+ # # os.makedirs(output_dir)
280
+
281
+ # # for i in range(len(scores)):
282
+ # # if scores[i] > 0.5: # Confidence threshold
283
+ # # mask = masks[i].squeeze().cpu().numpy()
284
+ # # mask = np.where(mask > 0.5, 1, 0).astype(np.uint8) # Binarize mask
285
+
286
+ # # # Create a new image for the masked object
287
+ # # masked_image = np.zeros_like(image_np)
288
+ # # for c in range(3): # Apply the mask to each channel (R, G, B)
289
+ # # masked_image[:, :, c] = image_np[:, :, c] * mask
290
+
291
+ # # # Save the masked object
292
+ # # masked_image_pil = Image.fromarray(masked_image)
293
+ # # masked_image_pil.save(f"{output_dir}object_{i+1}.png")
294
+
295
+ # # # Run the function to save segmented objects
296
+ # # save_segmented_objects(image_path, outputs)