"""YOLOv8 segmentation helpers: model loading, inference, mask post-processing.

The module also contains an optional training stub that only runs when the
module-level ``train`` flag is set to ``True``.
"""

import logging
import os
import pdb

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
import ultralytics
from PIL import Image
from skimage.filters import threshold_otsu
from skimage.transform import resize
from ultralytics import YOLO

logger = logging.getLogger(__name__)

HOME = os.getcwd()


def load_yolo_model(weights_path, image_size=(512, 512)):
    """Load a YOLO model from ``weights_path``.

    Args:
        weights_path: Path (or model name) passed straight to ``YOLO``.
        image_size: Stored on the returned model as ``model.img_size``.
            NOTE(review): nothing in this file reads ``img_size`` back;
            confirm whether the inference call actually honours it.

    Returns:
        The constructed ``YOLO`` model.
    """
    model = YOLO(weights_path)
    model.img_size = image_size
    return model


def detect_object(model, image_path, confidence=0.128):
    """Run ``model`` on one image and return the first result.

    Args:
        model: Callable model accepting ``(image_path, conf=...)``.
        image_path: Image source forwarded to the model.
        confidence: Confidence threshold forwarded as ``conf``.

    Returns:
        The first element of the model output, or ``None`` when the model
        produced no results.
    """
    results = list(model(image_path, conf=confidence))
    return results[0] if results else None


def preprocess_mask(mask, target_size=(512, 512)):
    """Convert ``mask`` to a NumPy array and resize it with scikit-image.

    Args:
        mask: Input accepted by ``torchvision.transforms.ToPILImage``
            (presumably a tensor or ndarray mask -- verify against callers).
        target_size: Output shape handed to ``skimage.transform.resize``.

    Returns:
        The resized mask as returned by ``skimage.transform.resize``.
    """
    mask_pil = T.ToPILImage()(mask)
    mask_np = np.array(mask_pil)
    return resize(mask_np, target_size, mode='constant')


def generate_binary_mask(mask, target_size=(512, 512)):
    """Resize ``mask`` and binarise it with Otsu's threshold.

    Args:
        mask: NumPy array mask; values are expected to lie in ``[0, 1]``
            because they are scaled by 255 before thresholding.
        target_size: Size passed to ``cv2.resize`` (OpenCV interprets it as
            ``(width, height)``).

    Returns:
        A 2-D array containing only ``0.0`` and ``1.0``.
    """
    resized_mask = cv2.resize(mask, target_size)
    logger.debug("Resized mask shape: %s", resized_mask.shape)

    # Multi-channel masks are reduced to their first channel.
    if resized_mask.ndim == 3:
        gray_mask = resized_mask[:, :, 0]
    else:
        gray_mask = resized_mask

    # Clip before scaling so out-of-range values saturate instead of
    # wrapping around in the uint8 conversion.
    scaled_mask = np.uint8(np.clip(gray_mask, 0, 1) * 255)
    _, binary_mask = cv2.threshold(
        scaled_mask, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )
    return binary_mask / 255


def overlay_mask_on_image(binary_mask, image_path):
    """Mask out the background of the image stored at ``image_path``.

    Args:
        binary_mask: 2-D array of zeros and ones. If its shape differs from
            the image it is resized to the image size (nearest neighbour,
            so it stays binary).
        image_path: Path of the image to read with OpenCV.

    Returns:
        A ``PIL.Image.Image`` in RGB order where pixels outside the mask
        are black.

    Raises:
        FileNotFoundError: If ``image_path`` does not point to a file.
        ValueError: If the file cannot be decoded as an image or the mask
            is not two-dimensional.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        raise ValueError(f"Could not decode image: {image_path}")

    mask = np.asarray(binary_mask)
    if mask.ndim != 2:
        raise ValueError(
            f"binary_mask must be 2-D, got array with shape {mask.shape}"
        )
    mask = mask.astype(image.dtype)

    height, width = image.shape[:2]
    if mask.shape != (height, width):
        mask = cv2.resize(
            mask, (width, height), interpolation=cv2.INTER_NEAREST
        )

    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Broadcasting the mask over the channel axis zeroes masked-out pixels.
    overlayed_image = image_rgb * mask[:, :, np.newaxis]
    return Image.fromarray(overlayed_image.astype(np.uint8))


def save_overlayed_image(image_array, output_path):
    """Save ``image_array`` to ``output_path``.

    Args:
        image_array: Object exposing ``save(path)``; despite the name this
            is the PIL image returned by ``overlay_mask_on_image``.
        output_path: Destination path passed to ``save``.
    """
    image_array.save(output_path)


## Import dataset from Roboflow Universe
# NOTE: never commit the API key; read it from the environment instead.
# The key that was previously hard-coded here should be revoked.
# from roboflow import Roboflow
# rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])
# project = rf.workspace("lab-ply6t").project("plant-phenotypes")
# dataset = project.version(1).download("yolov8")

# Flip to True to fine-tune the segmentation model when this module runs.
train = False
if train:
    model = YOLO("yolov8s-seg.pt")
    results = model.train(
        batch=8,
        device="cpu",
        data="C:/Users/aksha/Peeples_Lab/AgriLife_Data_Analysis-main/Segmentation/datasets/data.yaml",
        epochs=10,
        imgsz=640,
    )
    # Model validation
    results = model.val()

# Disabled band-visualisation snippet. To re-enable it, import
# `make_axes_locatable` from `mpl_toolkits.axes_grid1` and define the
# `green`, `red_edge` and `red` band arrays first.
#
# # Create a figure with multiple subplots
# fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # 1 row, 3 columns
#
# # Display the green band in the first subplot
# im0 = axes[0].imshow(green.squeeze())
# axes[0].set_title('Green Band')
# divider0 = make_axes_locatable(axes[0])
# cax0 = divider0.append_axes("right", size="5%", pad=0.05)
# fig.colorbar(im0, cax=cax0)
# axes[0].set_xticks([])
# axes[0].set_yticks([])
#
# # Display the red-edge band in the second subplot
# im1 = axes[1].imshow(red_edge.squeeze())
# axes[1].set_title('Red_edge Band')
# divider1 = make_axes_locatable(axes[1])
# cax1 = divider1.append_axes("right", size="5%", pad=0.05)
# fig.colorbar(im1, cax=cax1)  # Add color bar to the second subplot
# axes[1].set_xticks([])
# axes[1].set_yticks([])
#
# # Display the red band in the third subplot
# im2 = axes[2].imshow(red.squeeze())
# axes[2].set_title('Red Band')
# divider2 = make_axes_locatable(axes[2])
# cax2 = divider2.append_axes("right", size="5%", pad=0.05)
# fig.colorbar(im2, cax=cax2)  # Add color bar to the third subplot
# axes[2].set_xticks([])
# axes[2].set_yticks([])
#
# # Adjust spacing between subplots
# plt.tight_layout()
#
# # Display the plot
# plt.show()