# Plant_Phenotyping_Analysis / Image_Segmentation.py
# Uploaded by vusr ("Upload 2 files", commit 3dce7c6, verified)
import os
HOME = os.getcwd()
import ultralytics
from ultralytics import YOLO
import matplotlib.pyplot as plt
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.filters import threshold_otsu
from PIL import Image
import pdb
import torchvision.transforms as T
def load_yolo_model(weights_path, image_size=(512, 512)):
    """Construct a YOLO model from *weights_path* and tag it with *image_size*.

    Args:
        weights_path: Weights file (or model name) handed directly to ``YOLO``.
        image_size: Value stored on the returned model as ``img_size``.

    Returns:
        The newly constructed ``YOLO`` instance.

    NOTE(review): ``img_size`` is set as a plain attribute and nothing in this
    file reads it back; ultralytics normally takes the inference size per call
    (``imgsz=...``). Confirm whether this assignment has any effect.
    """
    yolo = YOLO(weights_path)
    yolo.img_size = image_size
    return yolo
def detect_object(model, image_path, confidence=0.128):
    """Run *model* on *image_path* and return its first result, if any.

    Args:
        model: Callable detector (e.g. a ``YOLO`` instance), invoked as
            ``model(image_path, conf=confidence)``; its output is fully
            materialized before the first element is taken.
        image_path: Image source forwarded to the model.
        confidence: Confidence threshold forwarded as ``conf``.

    Returns:
        The first item produced by the model call, or ``None`` when the call
        produces nothing.
    """
    predictions = [*model(image_path, conf=confidence)]
    if not predictions:
        return None
    return predictions[0]
def preprocess_mask(mask, target_size=(512, 512)):
    """Convert *mask* to a NumPy array via PIL, then resize it.

    Args:
        mask: Anything ``torchvision.transforms.ToPILImage`` accepts
            (presumably a per-instance mask tensor from the detector -- confirm
            at the call site).
        target_size: ``output_shape`` passed to ``skimage.transform.resize``.

    Returns:
        The array returned by ``skimage.transform.resize`` (called with
        ``mode='constant'``).
    """
    to_pil = T.ToPILImage()
    as_array = np.array(to_pil(mask))
    return resize(as_array, target_size, mode='constant')
def generate_binary_mask(mask, target_size=(512, 512)):
    """Resize *mask* and binarize it with Otsu's threshold.

    Args:
        mask: 2-D mask, or a 3-D array of which only the first channel is used.
            Values are expected in ``[0, 1]`` (they are scaled by 255 before
            thresholding); boolean masks are accepted and treated as 0/1.
        target_size: ``dsize`` for ``cv2.resize`` -- note OpenCV's order is
            ``(width, height)``.

    Returns:
        float64 ``np.ndarray`` of shape ``(height, width)`` containing only
        0.0 and 1.0.
    """
    mask = np.asarray(mask)
    # cv2.resize rejects boolean arrays; 0/1 uint8 carries the same meaning.
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8)
    resized_mask = cv2.resize(mask, target_size)

    # Multi-channel (3-D) masks: keep only the first channel.
    gray_mask = resized_mask[:, :, 0] if resized_mask.ndim == 3 else resized_mask

    # Scale to 0..255 and clip *before* casting: a bare uint8 cast wraps
    # out-of-range values (uint8 inputs times 255, or floats slightly above
    # 1.0) instead of saturating them, which corrupts the Otsu histogram.
    scaled = np.clip(gray_mask * 255.0, 0, 255).astype(np.uint8)

    # Otsu picks the threshold automatically; the explicit 0 is ignored.
    _, binary_mask = cv2.threshold(scaled, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary_mask / 255
def overlay_mask_on_image(binary_mask, image_path):
    """Black out every pixel of the image at *image_path* that lies outside *binary_mask*.

    Args:
        binary_mask: 2-D array of 0/1 values (e.g. the output of
            ``generate_binary_mask``). If its height/width differ from the
            image's, it is resized to match using nearest-neighbour
            interpolation so it stays binary.
        image_path: Path of the image, read with ``cv2.imread``.

    Returns:
        ``PIL.Image.Image`` in RGB channel order containing only the masked pixels.

    Raises:
        FileNotFoundError: If OpenCV cannot read *image_path*.
    """
    image_bgr = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    height, width = image_bgr.shape[:2]

    mask = np.asarray(binary_mask)
    if mask.shape[:2] != (height, width):
        # cv2.resize takes (width, height); nearest-neighbour keeps values 0/1.
        mask = cv2.resize(mask.astype(np.float32), (width, height),
                          interpolation=cv2.INTER_NEAREST)

    # Replicate the mask across the three colour channels.
    mask_rgb = np.repeat(mask.astype(image_bgr.dtype)[:, :, np.newaxis], 3, axis=2)
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray((mask_rgb * image_rgb).astype(np.uint8))
def save_overlayed_image(image_array, output_path):
    """Write *image_array* to *output_path* through its ``save`` method.

    Despite the parameter name, *image_array* must be an object exposing
    ``save(path)`` -- e.g. the ``PIL.Image.Image`` returned by
    ``overlay_mask_on_image``. Always returns ``None``.
    """
    saver = image_array.save
    saver(output_path)
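## Import the dataset from Roboflow Universe (optional; requires the `roboflow` package).
# NOTE: an API key was previously hard-coded here and committed publicly -- revoke it
# and supply the key through the environment instead.
# from roboflow import Roboflow
# rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])
# project = rf.workspace("lab-ply6t").project("plant-phenotypes")
# dataset = project.version(1).download("yolov8")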
from ultralytics import YOLO

# Flip to True to fine-tune YOLOv8s-seg on the dataset described by data.yaml
# and then validate it; while False, this section does nothing.
train = False

if train:
    model = YOLO("yolov8s-seg.pt")
    # NOTE(review): machine-specific absolute path -- point it at your own data.yaml.
    results = model.train(
        data="C:/Users/aksha/Peeples_Lab/AgriLife_Data_Analysis-main/Segmentation/datasets/data.yaml",
        imgsz=640,
        epochs=10,
        batch=8,
        device="cpu",
    )
    # Model validation: run it on the same model object right after training.
    results = model.val()
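# NOTE: the plotting snippet below is unused; to enable it, define the `green`,
# `red_edge` and `red` band arrays and add
# `from mpl_toolkits.axes_grid1 import make_axes_locatable`.
# # Create a figure with multiple subplots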
# fig, axes = plt.subplots(1, 3, figsize=(15, 5)) # 1 row, 3 columns
# # Display the green band in the first subplot
# im0 = axes[0].imshow(green.squeeze())
# axes[0].set_title('Green Band')
# divider0 = make_axes_locatable(axes[0])
# cax0 = divider0.append_axes("right", size="5%", pad=0.05)
# fig.colorbar(im0, cax=cax0)
# axes[0].set_xticks([])
# axes[0].set_yticks([])
# # Display the red-edge band in the second subplot
# im1 = axes[1].imshow(red_edge.squeeze())
# axes[1].set_title('Red_edge Band')
# divider1 = make_axes_locatable(axes[1])
# cax1 = divider1.append_axes("right", size="5%", pad=0.05)
# fig.colorbar(im1, cax=cax1) # Add color bar to the second subplot
# axes[1].set_xticks([])
# axes[1].set_yticks([])
# # Display the red band in the third subplot
# im2 = axes[2].imshow(red.squeeze())
# axes[2].set_title('Red Band')
# divider2 = make_axes_locatable(axes[2])
# cax2 = divider2.append_axes("right", size="5%", pad=0.05)
# fig.colorbar(im2, cax=cax2) # Add color bar to the third subplot
# axes[2].set_xticks([])
# axes[2].set_yticks([])
# # Adjust spacing between subplots
# plt.tight_layout()
# # Display the plot
# plt.show()