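"""Gradio app for multispectral plant image analysis.

Splits each raw capture into its spectral bands, builds a false-colour composite,
stitches the composites, extracts the plant foreground (connected components and
YOLO segmentation), estimates plant height, and visualises skeleton, SIFT, LBP
and HOG features.
"""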
import gradio as gr
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import os
import time
import pdb
from Connect_Components_Preprocessing import CCA_Preprocess
from mpl_toolkits.axes_grid1 import make_axes_locatable
from Image_stitching import *
from Image_height import *
from Image_Segmentation import *
from skimage.feature import local_binary_pattern,hog
from skimage import exposure
# cv2 and plantcv are used below but not guaranteed by the wildcard imports above, so import them explicitly
import cv2
from plantcv import plantcv as pcv
def tile(img, d=2):
    w, h = img.size
    grid = product(range(0, h - h % d, d), range(0, w - w % d, d))
    img_boxes = []
    for i, j in grid:
        box = (j, i, j + d, i + d)
        img_boxes.append(box)
    return img_boxes
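# Load the YOLO segmentation weights once at start-up.
# load_yolo_model and detect_object are assumed to come from Image_Segmentation (imported above with *).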
model_weights_path = 'yolo_segmentation_model.pt'
yolo_model = load_yolo_model(model_weights_path)
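# Main pipeline invoked by the Gradio interface: takes the uploaded raw images and
# returns (raw image gallery, false-colour gallery, analysis gallery, height text).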
def Image_Processing(filelist):
    # Define some variables
    k = 2
    color_images = []
    input_images = []
    image_names = []
    for image_path in filelist:
        image_name = image_path.split('/')[-1].split('.')[0]
        image_names.append(image_name)
        channel_names = ['Red (660 nm)', 'Green (580 nm)', 'Red Edge (730 nm)', 'NIR (820 nm)']
        im = Image.open(image_path)
        input_images.append(im)
        # Divide image into 4 equal parts (separate channels)
        img_size = im.size[0] // 2
        img_slices = tile(im, d=img_size)
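        # Crop each quadrant and stack the four bands into an (img_size, img_size, 4) array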
        # Visualize each slice (optional, remove if not needed)
        i = 0
        # scaler = MinMaxScaler(feature_range=(0, 1))
        img_stack = np.zeros((img_size, img_size, len(img_slices)))
        for box_coords in img_slices:
            # Grab image based on box_coords
            temp_img = np.array(im.crop(box_coords))
            # Normalize and save to composite image
            img_stack[:, :, i] = temp_img
            i += 1
        # Grab each spectral band and build a false-colour composite (green, red edge, red)
        red = np.expand_dims(img_stack[:, :, 1], axis=-1)
        green = np.expand_dims(img_stack[:, :, 0], axis=-1)
        red_edge = np.expand_dims(img_stack[:, :, 2], axis=-1)
        NIR = np.expand_dims(img_stack[:, :, -1], axis=-1)
        composite_img = np.concatenate((green, red_edge, red), axis=-1) * 255
        normalized_img = ((composite_img - composite_img.min()) * 255 / (composite_img.max() - composite_img.min())).astype(np.uint8)
        save_img = Image.fromarray(np.uint8(normalized_img))
        cv2_image = np.array(save_img.convert('RGB'))[:, :, ::-1].copy()
        color_images.append(cv2_image)
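    # Stitch the per-image false-colour composites into a single mosaic
    # (image_stitching is assumed to return a NumPy array)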
    stitched_image = image_stitching(color_images)
    stitched_image = Image.fromarray(np.uint8(stitched_image))
    stitched_cv_image = np.array(stitched_image.convert('RGB'))[:, :, ::-1].copy()
    gray_image, binary = CCA_Preprocess(stitched_cv_image, k=k)
    preprocessed_img = np.repeat(np.expand_dims(binary, axis=-1), 3, axis=-1) * stitched_cv_image
    normalized_preprocessed_img = (preprocessed_img - preprocessed_img.min()) / (preprocessed_img.max() - preprocessed_img.min())
    normalized_preprocessed_img *= 255
    normalized_preprocessed_img = normalized_preprocessed_img.astype(np.uint8)
    temp_stitched_save_path = 'temp_stitched_image.png'
    stitched_image.save(temp_stitched_save_path)
    result = detect_object(yolo_model, temp_stitched_save_path, confidence=0.128)
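    # Continue only if the YOLO model returned a segmentation result for the stitched image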
    if result:
        mask = preprocess_mask(result.masks.data)
        # Generate binary mask
        binary_mask_np = generate_binary_mask(mask)
        # Overlay mask on the image and save
        overlayed_image = overlay_mask_on_image(binary_mask_np, temp_stitched_save_path)
        overlayed_cv_image = np.array(overlayed_image.convert('RGB'))[:, :, ::-1].copy()
        temp_segmented_save_path = 'temp_segmented_image.png'
        overlayed_image.save(temp_segmented_save_path)
        # Estimate plant height (cm) from the segmented image
        df, height = image_height(temp_segmented_save_path, 30)
        radius = 1
        n_points = 8 * radius
        segmented_image = overlayed_cv_image
        segmented_image = segmented_image.astype(np.float32)
        # overlayed_cv_image is in BGR order, so use COLOR_BGR2GRAY here
        gray_image = cv2.cvtColor(segmented_image, cv2.COLOR_BGR2GRAY)
        skeleton = pcv.morphology.skeletonize(mask=gray_image)
        tips = pcv.morphology.find_tips(skel_img=skeleton, mask=None, label="default")
        branches = pcv.morphology.find_branch_pts(skel_img=skeleton, mask=None, label="default")
        tips_and_branches = np.zeros_like(skeleton)
        tips_and_branches[tips > 0] = 255
        tips_and_branches[branches > 0] = 128
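        # Texture/shape descriptors for visualisation: SIFT keypoints, local binary patterns and HOG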
        sift = cv2.SIFT_create()
        kp, des = sift.detectAndCompute(skeleton, None)
        # drawKeypoints takes an output image (or None) as its third argument, not the descriptor array
        sift_image = cv2.drawKeypoints(skeleton, kp, None)
        lbp = local_binary_pattern(gray_image, n_points, radius)
        # gray_image is single-channel, so no multichannel/channel_axis argument is needed
        fd, hog_image = hog(gray_image, orientations=10, pixels_per_cell=(16, 16), cells_per_block=(1, 1), visualize=True)
        hog_image_rescaled = exposure.rescale_intensity(hog_image, in_range=(0, 10))
        tips = tips.astype(np.uint8)
        branches = branches.astype(np.uint8)
        tips_and_branches = tips_and_branches.astype(np.uint8)
        sift_image = sift_image.astype(np.uint8)
        lbp = lbp.astype(np.uint8)
        hog_image_rescaled = hog_image_rescaled.astype(np.uint8)
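        # Pair every image with a caption for the Gradio galleries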
        input_images = [(input_images[index], image_names[index]) for index in range(len(input_images))]
        color_images = [(cv2.cvtColor(color_images[index], cv2.COLOR_BGR2RGB), image_names[index]) for index in range(len(color_images))]
        processed_images = [(stitched_cv_image, 'Processed & Stitched Color Image'),
                            (normalized_preprocessed_img, 'Foreground Image by Connected Component Analysis'),
                            (overlayed_cv_image, 'Foreground Image by Segmentation'),
                            (tips, 'Tips'),
                            (branches, 'Branches'),
                            (tips_and_branches, 'Tips and Branches'),
                            (sift_image, 'SIFT Features'),
                            (lbp, 'Local Binary Patterns'),
                            (hog_image_rescaled, 'Histogram of Oriented Gradients')]
        height_text = str(round(height, 2)) + ' cm'
        return input_images, color_images, processed_images, height_text
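# Gradio interface: file upload in, three image galleries and a height textbox out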
file_input = gr.File(file_count="multiple",
                     label='Upload Raw Input Images',
                     show_label=True)
gallery_raw_inputs = gr.Gallery(label='Input Raw Plant Images',
                                show_label=True,
                                height=512,
                                allow_preview=True,
                                preview=True)
gallery_color_images = gr.Gallery(label='Preprocessed Color Plant Images',
                                  show_label=True,
                                  height=512,
                                  allow_preview=True,
                                  preview=True)
gallery_output = gr.Gallery(label='Plant Analysis',
                            show_label=True,
                            height=512,
                            allow_preview=True,
                            preview=True)
textbox = gr.Textbox(label='Estimated Plant Height',
                     show_label=True)
iface = gr.Interface(fn=Image_Processing,
                     inputs=file_input,
                     outputs=[gallery_raw_inputs, gallery_color_images, gallery_output, textbox])
iface.launch(share=True)