# WeCanopy / generate_tree_images / generate_tree_images.py
# Last commit 97605b3 by moritzmoritzmoritzmoritz: "added the python file" (6.07 kB)
import os
import rasterio
import geopandas as gpd
from shapely.geometry import box
from rasterio.mask import mask
from PIL import Image
import numpy as np
import warnings
from rasterio.errors import NodataShadowWarning
import sys
# Silence rasterio's NodataShadowWarning (a nodata value shadowing an alpha
# band), which would otherwise be repeated for every masked polygon —
# presumably relevant for RGBA rasters; confirm.
warnings.simplefilter("ignore", category=NodataShadowWarning)
def _report(position, total, suffix=""):
    """Overwrite the current terminal line with a progress message.

    ``position`` is the 0-based row position and ``total`` the number of rows;
    ``suffix`` is appended after ``" | "`` when given.
    """
    percent = round(position / total * 100) if total else 0
    message = f"{percent} % complete --> {position}/{total}"
    if suffix:
        message += f" | {suffix}"
    sys.stdout.write('\r' + message)
    sys.stdout.flush()


def _trim_nodata(image, nodata):
    """Crop the rows/columns of an (H, W, C) array whose first band is all ``nodata``.

    Returns the trimmed array, or ``None`` when no pixel of the first band
    holds data. A NaN ``nodata`` is matched with ``np.isnan`` because NaN never
    compares equal to itself.
    """
    band = image[:, :, 0]
    if isinstance(nodata, float) and np.isnan(nodata):
        valid = ~np.isnan(band)
    else:
        valid = band != nodata
    rows = np.any(valid, axis=1)
    cols = np.any(valid, axis=0)
    if not rows.any() or not cols.any():
        return None
    return image[rows][:, cols]


def cut_trees(output_dir, geojson_path, tif_path):
    """Cut one image per tree polygon out of a raster and save each as a PNG.

    Parameters
    ----------
    output_dir : str
        Folder for the cut images; created if it does not exist. Each image is
        saved as ``tree_<id>.png``, where ``<id>`` is the row's ``id`` value
        (or the row's index label when there is no ``id`` column).
    geojson_path : str
        Vector file with the tree polygons. It is read with
        ``geopandas.read_file`` and reprojected to the raster's CRS.
    tif_path : str
        Raster to cut the images from.

    Polygons that do not intersect the raster, do not overlap it when masked,
    or whose masked area holds no data in the first band are skipped and
    reported on the progress line. Returns nothing.
    """
    os.makedirs(output_dir, exist_ok=True)
    gdf = gpd.read_file(geojson_path)
    # Clear the terminal so the carriage-return progress line starts clean.
    os.system('cls' if os.name == 'nt' else 'clear')

    with rasterio.open(tif_path) as src:
        tif_bounds = box(*src.bounds)
        gdf = gdf.to_crs(src.crs)
        # mask() fills pixels outside the polygon with src.nodata, or with 0
        # when the raster defines none, so trim against that same value.
        nodata = src.nodata if src.nodata is not None else 0

        total = len(gdf)
        # Report roughly every 10 %; max() prevents a modulo-by-zero when
        # there are fewer than 10 polygons.
        step = max(1, total // 10)
        image_counter = 0
        # Progress uses the positional counter: iterrows() yields index
        # labels, which need not be the integers 0..N-1.
        for position, (idx, row) in enumerate(gdf.iterrows()):
            if position % step == 0:
                _report(position, total)

            geom = row['geometry']
            name = row.get('id', idx)

            if not geom.intersects(tif_bounds):
                _report(position, total,
                        f"Polygon {idx} is outside the image bounds and will be skipped.")
                continue

            try:
                out_image, _ = mask(src, [geom], crop=True)
            except ValueError:
                # mask() raises this when the shape does not overlap the
                # raster, e.g. a polygon that only touches the raster edge.
                _report(position, total,
                        f"Polygon {idx} does not overlap the image and will be skipped.")
                continue

            # rasterio returns (bands, H, W); PIL expects (H, W, bands).
            out_image = out_image.transpose(1, 2, 0)
            if out_image.size == 0:
                _report(position, total,
                        f"Polygon {idx} resulted in an empty image and will be skipped.")
                continue

            out_image = _trim_nodata(out_image, nodata)
            if out_image is None:
                _report(position, total,
                        f"Polygon {idx} resulted in an invalid image area and will be skipped.")
                continue

            # PIL cannot build an image from an (H, W, 1) array; use (H, W).
            if out_image.shape[2] == 1:
                out_image = out_image[:, :, 0]
            # NOTE(review): values outside 0-255 wrap on this cast; this
            # assumes an 8-bit raster — confirm for other dtypes.
            png = Image.fromarray(out_image.astype(np.uint8))
            png.save(os.path.join(output_dir, f'tree_{name}.png'))
            image_counter += 1

    print(f'\n {image_counter}/{total} Tree images have been successfully saved in the "{output_dir}" folder.')
def resize_images(input_folder, output_folder, target_size):
    """Fit every ``.png`` in ``input_folder`` onto a black canvas of ``target_size``.

    Each image is shrunk to fit within ``target_size`` while keeping its
    aspect ratio. It is then centred on an opaque black canvas and saved as
    RGB under the same file name in ``output_folder``, which is created if
    missing. Transparent pixels end up black.

    Parameters
    ----------
    input_folder : str
        Folder scanned (non-recursively) for files ending in ``.png``.
    output_folder : str
        Destination folder for the processed images.
    target_size : tuple[int, int]
        Canvas size as ``(width, height)``.
    """
    os.makedirs(output_folder, exist_ok=True)

    counter = 0
    for filename in os.listdir(input_folder):
        if not filename.endswith('.png'):
            continue

        with Image.open(os.path.join(input_folder, filename)) as img:
            # Normalise to RGBA so the image can act as its own paste mask.
            # paste() rejects RGB images as masks ("bad transparency mask").
            img = img.convert("RGBA")
            # NOTE(review): thumbnail() only ever shrinks, so crowns smaller
            # than target_size stay small in the centre of the canvas. Use
            # resize() instead if they should fill the frame — confirm intent.
            img.thumbnail(target_size, Image.LANCZOS)
            paste_pos = ((target_size[0] - img.size[0]) // 2,
                         (target_size[1] - img.size[1]) // 2)
            canvas = Image.new("RGBA", target_size, (0, 0, 0, 255))
            canvas.paste(img, paste_pos, img)
            # Drop the alpha channel; the canvas is already opaque black.
            canvas.convert("RGB").save(os.path.join(output_folder, filename))

        counter += 1
        if counter % 50 == 0:
            print(f"Processed {counter} images", end='\r')

    print(f"Processed a total of {counter} images.")
# Public entry point: this is the function to import.
def generate_tree_images(geojson_path, tif_path, target_size=(224, 224), *,
                         cut_dir="detected_trees", images_dir="tree_images"):
    """Cut tree images from a raster and prepare them for species recognition.

    Parameters
    ----------
    geojson_path : str
        Vector file with one polygon per tree.
    tif_path : str
        Raster containing the trees.
    target_size : tuple[int, int], optional
        ``(width, height)`` of the final images; defaults to ``(224, 224)``.
    cut_dir : str, keyword-only
        Folder for the raw cut images; defaults to ``"detected_trees"``.
    images_dir : str, keyword-only
        Folder for the processed images; defaults to ``"tree_images"``.

    Returns nothing. The function creates two folders, relative to the
    current working directory unless absolute paths are given:

    * ``cut_dir``: one PNG per tree polygon, cut from the raster.
    * ``images_dir``: those images fitted onto ``target_size`` canvases.

    Every PNG found in ``cut_dir`` is processed, including files left over
    from earlier runs.
    """
    cut_trees(output_dir=cut_dir, geojson_path=geojson_path, tif_path=tif_path)
    resize_images(input_folder=cut_dir, output_folder=images_dir,
                  target_size=target_size)