# Source: ImageSegmentation2 / segmentation_model.py
# (Hugging Face page residue, kept as provenance: uploaded by JaydeepR;
#  commit a4a65e0 "Rename segmentation_model to segmentation_model.py")
import torch
import torchvision.transforms as T
from torchvision.models.detection import maskrcnn_resnet50_fpn
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import uuid
import os
import cv2
import json
# Storage locations for raw uploads and per-object crops. Created eagerly at
# import time so the save_* functions below can assume they exist.
input_images_dir = 'data/input_images/'
segmented_objects_dir = 'data/segmented_objects/'
os.makedirs(input_images_dir, exist_ok=True)  # exist_ok: safe on re-run/re-import
os.makedirs(segmented_objects_dir, exist_ok=True)
#Loading the model
def load_model():
    """Build a Mask R-CNN (ResNet-50 FPN) instance segmenter pre-trained on COCO.

    The network is switched to evaluation mode before being returned: a
    pre-trained model served for inference must have dropout and other
    training-only behaviors deactivated.

    Returns:
        torchvision.models.detection.MaskRCNN: model in eval mode.
    """
    # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
    # favor of `weights=` — confirm the installed version before migrating.
    net = maskrcnn_resnet50_fpn(pretrained=True)
    # Alternative-backbone experiment kept for reference:
    # net = maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False,
    #                             backbone_name='resnext50_32x4d')
    net.eval()
    return net
model = load_model()  # Module-level model instance, built once at import time.
def transform_image(image):
    """Convert a PIL image into a batched float tensor for Mask R-CNN.

    torchvision detection models expect float tensors with values in [0, 1];
    the model normalizes internally (GeneralizedRCNNTransform with ImageNet
    statistics). The previous extra ``T.Normalize((0.5,...), (0.5,...))``
    shifted inputs to [-1, 1], double-normalizing them and degrading
    detections, so it has been removed.

    Args:
        image (PIL.Image): Input RGB image.

    Returns:
        torch.Tensor: Shape [1, C, H, W] — batch dimension added; C is the
        channel count (3 for RGB, 1 for grayscale). Values in [0, 1].
    """
    transform = T.Compose([
        T.Resize((256, 256)),  # Fixed working resolution for the pipeline
        T.ToTensor(),          # PIL -> float tensor in [0, 1], C x H x W
    ])
    return transform(image).unsqueeze(0)
# # Test image transformation
# image_path = "D:\multiobject.jpeg" # Replace with the path to your image
# image_tensor = transform_image(image_path)
def run_inference(model, image_tensor):
    """Run the segmentation model on a batched image tensor.

    Gradient tracking is disabled — this path is inference-only, so skipping
    autograd bookkeeping saves memory and time.

    Args:
        model: Callable detection model (e.g. Mask R-CNN in eval mode).
        image_tensor (torch.Tensor): Batched input of shape [1, C, H, W].

    Returns:
        The model's raw output (for torchvision detection models, a list of
        dicts with masks, labels, boxes and scores).
    """
    with torch.no_grad():
        return model(image_tensor)
def extract_object(image, mask):
    """Return a copy of *image* showing only the masked object (black elsewhere).

    Args:
        image (PIL.Image): Original image — assumed to carry 3 RGB channels
          (only the first three channels are masked; any extra channel is
          left zeroed).
        mask (np.ndarray): 2-D object mask; it is resized to the image's
          dimensions before being applied.

    Returns:
        PIL.Image: Same size as *image*, with the background zeroed out.
    """
    pixels = np.array(image)
    height, width = pixels.shape[0], pixels.shape[1]
    # Nearest-neighbour interpolation keeps a binary mask binary after resizing.
    fitted_mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    # Multiply the RGB channels by the mask in one broadcast; channels beyond
    # the third (if any) stay zero, matching the original per-channel loop.
    cut_out = np.zeros_like(pixels)
    cut_out[:, :, :3] = pixels[:, :, :3] * fitted_mask[:, :, None]
    return Image.fromarray(cut_out)
# def extract_object(image, mask):
# object_img = Image.fromarray((np.array(image) * mask[:, :, None]).astype(np.uint8))
# return object_img
# Save the input image
def save_input_image(image, master_id):
    """Persist the original input image under ``input_images_dir``.

    Args:
        image (PIL.Image): Image to store.
        master_id (str): Unique identifier used as the file stem.

    Returns:
        str: Filesystem path the image was written to
        (``<input_images_dir>/<master_id>.png``).
    """
    destination = os.path.join(input_images_dir, f'{master_id}.png')
    image.save(destination)
    return destination
# Save the extracted objects and their metadata
def save_objects_and_metadata(extracted_objects, master_id):
    """Save each extracted object image and record its metadata.

    Every object image is written under ``segmented_objects_dir`` with a fresh
    UUID filename; the collected records are also dumped to
    ``<master_id>_metadata.json`` in the same directory.

    Args:
        extracted_objects (list[PIL.Image]): Cut-out object images.
        master_id (str): ID of the original image these objects came from.

    Returns:
        list[dict]: One record per object with keys ``object_id``,
        ``master_id`` and ``object_image_path``.
    """
    object_metadata = []
    # The positional index was unused, so plain iteration replaces enumerate().
    for obj_img in extracted_objects:
        object_id = str(uuid.uuid4())
        object_image_path = os.path.join(segmented_objects_dir, f'{object_id}.png')
        obj_img.save(object_image_path)
        object_metadata.append({
            'object_id': object_id,
            'master_id': master_id,
            'object_image_path': object_image_path,
        })
    metadata_file = os.path.join(segmented_objects_dir, f'{master_id}_metadata.json')
    with open(metadata_file, 'w') as f:
        json.dump(object_metadata, f, indent=4)
    return object_metadata
# Run inference
#print(outputs) # This will print the model's output, including masks, labels, and scores
# def extract_objects(image, masks):
# """
# Extract objects from the segmented image using masks.
# Args:
# - image (PIL.Image): The original image.
# - masks (Tensor): Masks obtained from the segmentation model.
# Returns:
# - List of extracted objects as images.
# """
# image_np = np.array(image)
# extracted_objects = []
# for i, mask in enumerate(masks):
# # Convert mask to binary
# binary_mask = mask[0].mul(255).byte().cpu().numpy()
# # Extract object using the mask
# masked_image = cv2.bitwise_and(image_np, image_np, mask=binary_mask)
# # Find the bounding box of the object
# x, y, w, h = cv2.boundingRect(binary_mask)
# cropped_object = masked_image[y:y+h, x:x+w]
# # Convert cropped object back to PIL Image
# cropped_object_pil = Image.fromarray(cropped_object)
# extracted_objects.append(cropped_object_pil)
# return extracted_objects
# import os
# import uuid
# from PIL import Image
# import json
# # Directories to save the input images and segmented objects
# input_images_dir = 'data/input_images/'
# segmented_objects_dir = 'data/segmented_objects/'
# os.makedirs(input_images_dir, exist_ok=True)
# os.makedirs(segmented_objects_dir, exist_ok=True)
# def save_input_image(image, master_id):
# """
# Save the original input image with a unique master ID.
# Args:
# - image (PIL.Image): The original input image.
# - master_id (str): Unique ID for the original image.
# Returns:
# - str: Path to the saved input image.
# """
# input_image_path = os.path.join(input_images_dir, f'{master_id}.png')
# image.save(input_image_path)
# return input_image_path
# def save_objects_and_metadata(extracted_objects, master_id):
# """
# Save the extracted objects as images and store their metadata.
# Args:
# - extracted_objects (List[PIL.Image]): List of extracted objects as images.
# - master_id (str): Unique ID for the original image.
# Returns:
# - List of metadata dictionaries for each object.
# """
# object_metadata = []
# for i, obj_img in enumerate(extracted_objects):
# # Generate a unique ID for each object
# object_id = str(uuid.uuid4())
# # Save the object image
# object_image_path = os.path.join(segmented_objects_dir, f'{object_id}.png')
# obj_img.save(object_image_path)
# # Prepare metadata for the object
# metadata = {
# 'object_id': object_id,
# 'master_id': master_id,
# 'object_image_path': object_image_path
# }
# object_metadata.append(metadata)
# # Save metadata to JSON (or you can save to a database)
# metadata_file = os.path.join(segmented_objects_dir, f'{master_id}_metadata.json')
# with open(metadata_file, 'w') as f:
# json.dump(object_metadata, f, indent=4)
# return object_metadata
# # Example usage
# master_id = str(uuid.uuid4()) # Generate a unique master ID for the original image
# # Save the input image
# input_image_path = save_input_image(image, master_id)
# # Save the objects and their metadata
# metadata = save_objects_and_metadata(extracted_objects, master_id)
# import cv2
# import os
# import json
# import uuid
# import numpy as np
# from PIL import Image
# # Directories to save the segmented objects and metadata
# segmented_objects_dir = 'data/segmented_objects/'
# metadata_file = 'data/segmented_objects_metadata.json'
# # Ensure directories exist
# os.makedirs(segmented_objects_dir, exist_ok=True)
# def extract_objects(image_path, masks, master_id):
# # Load the original image
# image = Image.open(image_path)
# image_np = np.array(image)
# object_metadata = []
# for i, mask in enumerate(masks):
# # Generate a unique ID for each object
# object_id = str(uuid.uuid4())
# # Extract object using the mask
# masked_image = cv2.bitwise_and(image_np, image_np, mask=mask)
# # Find the bounding box of the object
# x, y, w, h = cv2.boundingRect(mask)
# cropped_object = masked_image[y:y+h, x:x+w]
# # Save the object image
# object_image_path = os.path.join(segmented_objects_dir, f'{object_id}.png')
# cv2.imwrite(object_image_path, cropped_object)
# # Save metadata
# object_metadata.append({
# 'object_id': object_id,
# 'master_id': master_id,
# 'object_image_path': object_image_path,
# 'bounding_box': (x, y, w, h)
# })
# # Save metadata to JSON
# with open(metadata_file, 'w') as f:
# json.dump(object_metadata, f, indent=4)
# return object_metadata
# # Example usage:
# # Assuming `masks` is a list of binary masks (numpy arrays) from your segmentation model
# # and `image_path` is the path to the original image
# master_id = str(uuid.uuid4())
# image_path = 'data/input_images/sample_image.png'
# masks = [...] # Replace with actual masks
# object_metadata = extract_objects(image_path, masks, master_id)
# #Extracting and saving segmented objects
# # def save_segmented_objects(image_path, outputs, output_dir='data\segmented_objects'):
# # image = Image.open(image_path).convert("RGB")
# # image_np = np.array(image)
# # masks = outputs[0]['masks']
# # scores = outputs[0]['scores']
# # if not os.path.exists(output_dir):
# # os.makedirs(output_dir)
# # for i in range(len(scores)):
# # if scores[i] > 0.5: # Confidence threshold
# # mask = masks[i].squeeze().cpu().numpy()
# # mask = np.where(mask > 0.5, 1, 0).astype(np.uint8) # Binarize mask
# # # Create a new image for the masked object
# # masked_image = np.zeros_like(image_np)
# # for c in range(3): # Apply the mask to each channel (R, G, B)
# # masked_image[:, :, c] = image_np[:, :, c] * mask
# # # Save the masked object
# # masked_image_pil = Image.fromarray(masked_image)
# # masked_image_pil.save(f"{output_dir}object_{i+1}.png")
# # # Run the function to save segmented objects
# # save_segmented_objects(image_path, outputs)