Spaces:
Runtime error
Runtime error
import os | |
HOME = os.getcwd() | |
import ultralytics | |
from ultralytics import YOLO | |
import matplotlib.pyplot as plt | |
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from skimage.transform import resize | |
from skimage.filters import threshold_otsu | |
from PIL import Image | |
import pdb | |
import torchvision.transforms as T | |
def load_yolo_model(weights_path, image_size=(512, 512)):
    """Build a YOLO model from a weights file and tag it with an image size.

    Args:
        weights_path: Path (or model name) handed straight to ``YOLO``.
        image_size: Value stored on the returned model as ``img_size``.

    Returns:
        The constructed ``YOLO`` instance.
    """
    yolo = YOLO(weights_path)
    # NOTE(review): this only sets an attribute on the object. Ultralytics
    # normally takes the inference size via the ``imgsz`` argument, so this
    # assignment may have no effect on prediction -- confirm.
    yolo.img_size = image_size
    return yolo
def detect_object(model, image_path, confidence=0.128):
    """Run the model on one image and return its first result.

    Args:
        model: Callable invoked as ``model(image_path, conf=confidence)``;
            it must return an iterable of results.
        image_path: Image source forwarded unchanged to ``model``.
        confidence: Confidence threshold passed as the ``conf`` keyword.

    Returns:
        The first item produced by the model, or ``None`` when the model
        produced nothing.
    """
    detections = list(model(image_path, conf=confidence))
    if not detections:
        return None
    return detections[0]
def preprocess_mask(mask, target_size=(512, 512)):
    """Convert a mask to a NumPy array and resize it to ``target_size``.

    The mask is first turned into a PIL image with torchvision's
    ``ToPILImage`` transform, then into a NumPy array, and finally resized
    with ``skimage.transform.resize`` using ``mode='constant'``.

    Args:
        mask: Input accepted by ``torchvision.transforms.ToPILImage``
            (presumably a tensor or ndarray mask -- verify against caller).
        target_size: Output shape handed to ``skimage.transform.resize``.

    Returns:
        The resized mask array.
    """
    to_pil = T.ToPILImage()
    mask_array = np.array(to_pil(mask))
    return resize(mask_array, target_size, mode='constant')
def generate_binary_mask(mask, target_size=(512, 512)):
    """Resize a mask and binarise it with Otsu's threshold.

    Args:
        mask: NumPy array accepted by ``cv2.resize``. Either a float mask
            with values in [0, 1], an integer 0/1 mask, or an integer mask
            already expressed on the 0-255 scale.
        target_size: ``dsize`` passed to ``cv2.resize``; OpenCV interprets it
            as ``(width, height)``.

    Returns:
        A 2-D float array containing only 0.0 and 1.0.
    """
    resized_mask = cv2.resize(mask, target_size)

    # A 3-D result means the mask has a channel axis; only the first channel
    # is used for thresholding.
    if resized_mask.ndim == 3:
        gray_mask = resized_mask[:, :, 0]
    else:
        gray_mask = resized_mask

    if np.issubdtype(gray_mask.dtype, np.integer) and gray_mask.max() > 1:
        # Already on the 0-255 scale: multiplying by 255 would silently wrap
        # around in integer arithmetic, so use the values as they are.
        scaled = np.clip(gray_mask, 0, 255).astype(np.uint8)
    else:
        # Unit-range mask: scale in floating point and saturate instead of
        # overflowing when converting to uint8.
        scaled = np.clip(gray_mask.astype(np.float64) * 255, 0, 255).astype(np.uint8)

    # With THRESH_OTSU the threshold argument (0) is ignored and computed
    # from the image histogram.
    _, binary_mask = cv2.threshold(
        scaled, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )
    # Map {0, 255} to {0.0, 1.0}.
    return binary_mask / 255
def overlay_mask_on_image(binary_mask, image_path):
    """Black out every pixel of an image that lies outside a binary mask.

    Args:
        binary_mask: 2-D array of 0/1 values. If its height and width differ
            from the image's, it is resized to the image with
            nearest-neighbour interpolation so the values stay 0/1.
        image_path: Path of the image to load with ``cv2.imread``.

    Returns:
        A ``PIL.Image.Image`` in RGB order with masked-out pixels set to 0.

    Raises:
        FileNotFoundError: If the image cannot be read from ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    mask = np.asarray(binary_mask).astype(image.dtype)

    height, width = image.shape[:2]
    if mask.shape[:2] != (height, width):
        # cv2.resize expects the target size as (width, height).
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    # Replicate the single-channel mask across the three colour channels.
    mask_rgb = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    # OpenCV loads images as BGR; PIL expects RGB.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    overlayed_image = mask_rgb * image_rgb
    return Image.fromarray(overlayed_image.astype(np.uint8))
def save_overlayed_image(image_array, output_path):
    """Write an image to disk by calling its own ``save`` method.

    Args:
        image_array: Object exposing ``save(path)``. Despite the name this is
            presumably a PIL image rather than a NumPy array -- verify
            against caller.
        output_path: Destination passed unchanged to ``save``.
    """
    destination = output_path
    image_array.save(destination)
# Import the dataset from Roboflow Universe.
# NOTE: do not hard-code the API key here; read it from the environment.
# from roboflow import Roboflow
# rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])
# project = rf.workspace("lab-ply6t").project("plant-phenotypes")
# dataset = project.version(1).download("yolov8")
from ultralytics import YOLO | |
# Switch to True to fine-tune the segmentation model; when False nothing
# below runs.
train = False
if train:
    # Start from the pretrained YOLOv8-small segmentation checkpoint.
    model = YOLO("yolov8s-seg.pt")
    results = model.train(
        data="C:/Users/aksha/Peeples_Lab/AgriLife_Data_Analysis-main/Segmentation/datasets/data.yaml",
        epochs=10,
        imgsz=640,
        batch=8,
        device="cpu",
    )
    # Validate the freshly trained model; `model` only exists in this branch.
    results = model.val()
# # # Create a figure with multiple subplots | |
# fig, axes = plt.subplots(1, 3, figsize=(15, 5)) # 1 row, 3 columns | |
# # Display the green band in the first subplot
# im0 = axes[0].imshow(green.squeeze()) | |
# axes[0].set_title('Green Band') | |
# divider0 = make_axes_locatable(axes[0]) | |
# cax0 = divider0.append_axes("right", size="5%", pad=0.05) | |
# fig.colorbar(im0, cax=cax0) | |
# axes[0].set_xticks([]) | |
# axes[0].set_yticks([]) | |
# # Display the red_edge band in the second subplot
# im1 = axes[1].imshow(red_edge.squeeze()) | |
# axes[1].set_title('Red_edge Band') | |
# divider1 = make_axes_locatable(axes[1]) | |
# cax1 = divider1.append_axes("right", size="5%", pad=0.05) | |
# fig.colorbar(im1, cax=cax1) # Add color bar to the second subplot | |
# axes[1].set_xticks([]) | |
# axes[1].set_yticks([]) | |
# # Display the red band in the third subplot
# im2 = axes[2].imshow(red.squeeze()) | |
# axes[2].set_title('Red Band') | |
# divider2 = make_axes_locatable(axes[2]) | |
# cax2 = divider2.append_axes("right", size="5%", pad=0.05) | |
# fig.colorbar(im2, cax=cax2) # Add color bar to the third subplot | |
# axes[2].set_xticks([]) | |
# axes[2].set_yticks([]) | |
# # Adjust spacing between subplots | |
# plt.tight_layout() | |
# # Display the plot | |
# plt.show() | |