# SemanticSegmentationModel/semantic-segmentation/SemanticModel/.ipynb_checkpoints/prediction-checkpoint.py
import os | |
import cv2 | |
import time | |
import torch | |
import imageio | |
import tifffile | |
import numpy as np | |
import slidingwindow | |
import rasterio as rio | |
import geopandas as gpd | |
from shapely.geometry import Polygon | |
from rasterio import mask as riomask | |
from torch.utils.data import DataLoader | |
from SemanticModel.visualization import generate_color_mapping | |
from SemanticModel.image_preprocessing import get_validation_augmentations | |
from SemanticModel.data_loader import InferenceDataset, StreamingDataset | |
from SemanticModel.utilities import calc_image_size, convert_coordinates | |
class PredictionPipeline:
    """Run a segmentation model over images, directories, rasters and video frames.

    The pipeline reads ``model``, ``classes``, ``background_flag``,
    ``n_classes`` and ``preprocessing`` from the ``model_config`` object it is
    constructed with.
    """

    # RGB tints used by ``predict_video_frames``: the first foreground mask is
    # blended with red, every further mask with blue.
    _OVERLAY_RED = (255.0, 0.0, 0.0)
    _OVERLAY_BLUE = (0.0, 0.0, 255.0)

    def __init__(self, model_config, device=None):
        """Move the configured model to ``device`` and switch it to eval mode.

        Args:
            model_config: Configuration object providing the model, class
                names, background flag and preprocessing callable.
            device: Torch device to run on. Defaults to CUDA when available,
                otherwise CPU.
        """
        self.config = model_config
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if model_config.background_flag:
            self.classes = ['background'] + model_config.classes
        else:
            self.classes = model_config.classes
        self.colors = generate_color_mapping(len(self.classes))
        self.model = model_config.model.to(self.device)
        self.model.eval()

    def _preprocess_image(self, image_path, target_size=None):
        """Load one image from disk and prepare it for the model.

        Args:
            image_path: Path of the image to read.
            target_size: Size handed to ``calc_image_size``; defaults to the
                longer side of the image.

        Returns:
            Tuple ``(image, (height, width))`` where ``image`` is the
            augmented and preprocessed array and the second item holds the
            dimensions of the image as read from disk.

        Raises:
            OSError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = image.shape[:2]
        target_size = target_size or max(height, width)
        test_height, test_width = calc_image_size(image, target_size)
        augmentation = get_validation_augmentations(test_width, test_height)
        image = augmentation(image=image)['image']
        image = self.config.preprocessing(image=image)['image']
        return image, (height, width)

    def _predict_scores(self, image):
        """Run the model on one preprocessed image and return its raw output.

        Args:
            image: Preprocessed image as a numpy array accepted by
                ``torch.from_numpy``.

        Returns:
            The model output with singleton dimensions squeezed out, as a
            numpy array on the CPU.
        """
        x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
        with torch.no_grad():
            prediction = self.model.predict(x_tensor)
        return prediction.squeeze().cpu().numpy()

    def _scores_to_labels(self, scores):
        """Convert raw model scores into a ``uint8`` label map.

        With more than one class the label is the arg-max over the first
        axis; with a single class the score is rounded to 0 or 1.
        """
        if self.config.n_classes > 1:
            labels = np.argmax(scores, axis=0)
        else:
            labels = scores.round()
        # cv2.resize cannot handle the int64 arrays produced by np.argmax, so
        # labels are narrowed to uint8 before any resizing happens.
        return labels.astype('uint8')

    def _predict_labels(self, image):
        """Run the model on one preprocessed image and return a label map."""
        return self._scores_to_labels(self._predict_scores(image))

    @staticmethod
    def _resize_to(array, height, width, interpolation):
        """Resize ``array`` to ``(height, width)`` unless it already matches."""
        height, width = int(height), int(width)
        if array.shape[:2] == (height, width):
            return array
        # cv2 expects the target size as (width, height).
        return cv2.resize(array, (width, height), interpolation=interpolation)

    def predict_single_image(self, image_path, target_size=None, output_dir=None,
                             format='integer', save_output=True):
        """Generate a prediction for a single image.

        Args:
            image_path: Path of the image to segment.
            target_size: Optional size forwarded to the preprocessing step.
            output_dir: Directory for the saved prediction. Defaults to a
                ``predictions`` folder next to the image.
            format: ``'integer'`` for a label map or ``'color'`` for a
                colour-coded image.
            save_output: Whether to write the prediction to disk.

        Returns:
            The formatted prediction at the original image resolution.

        Raises:
            OSError: If the image cannot be read or the result cannot be saved.
            ValueError: If ``format`` is not supported.
        """
        image, (height, width) = self._preprocess_image(image_path, target_size)
        labels = self._predict_labels(image)
        labels = self._resize_to(labels, height, width, cv2.INTER_NEAREST)
        prediction = self._format_prediction(labels, format)
        if save_output:
            if output_dir is None:
                # Mirror predict_directory's default location.
                output_dir = os.path.join(os.path.dirname(image_path), 'predictions')
            os.makedirs(output_dir, exist_ok=True)
            self._save_prediction(prediction, image_path, output_dir, format)
        return prediction

    def predict_directory(self, input_dir, target_size=None, output_dir=None,
                          fixed_size=True, format='integer'):
        """Generate predictions for every image in a directory.

        Args:
            input_dir: Directory handed to ``InferenceDataset``.
            target_size: Optional square size for the validation
                augmentations; no augmentation is applied when omitted.
            output_dir: Destination directory. Defaults to
                ``<input_dir>/predictions``.
            fixed_size: Forwarded to ``get_validation_augmentations``.
            format: ``'integer'`` or ``'color'``.

        Returns:
            The directory the predictions were written to.
        """
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)
        augmentation = None
        if target_size:
            augmentation = get_validation_augmentations(
                target_size, target_size, fixed_size=fixed_size
            )
        dataset = InferenceDataset(
            input_dir,
            classes=self.classes,
            augmentation=augmentation,
            preprocessing=self.config.preprocessing
        )
        total_images = len(dataset)
        start_time = time.time()
        for idx in range(total_images):
            if (idx + 1) % 10 == 0 or idx == total_images - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_images} images in {elapsed:.1f}s',
                      end='')
            image, height, width = dataset[idx]
            filename = dataset.filenames[idx]
            labels = self._predict_labels(image)
            labels = self._resize_to(labels, height, width, cv2.INTER_NEAREST)
            prediction = self._format_prediction(labels, format)
            self._save_prediction(prediction, filename, output_dir, format)
        print(f'\nPredictions saved to: {output_dir}')
        return output_dir

    def predict_raster(self, raster_path, tile_size=1024, overlap=0.175,
                       boundary_path=None, output_path=None, format='integer'):
        """Segment a large raster tile by tile.

        Overlapping tiles are merged per pixel by keeping the label with the
        highest confidence seen so far.

        Args:
            raster_path: Path of the raster to read with rasterio.
            tile_size: Maximum tile edge length passed to ``slidingwindow``.
            overlap: Fractional tile overlap passed to ``slidingwindow``.
            boundary_path: Optional vector file; only tiles intersecting the
                geometry of its first feature are processed, and the saved
                output is cropped to it.
            output_path: Where to save the result. When omitted and a boundary
                is given, a path derived from ``raster_path`` is used.
            format: ``'integer'`` or ``'color'``.

        Returns:
            Tuple ``(pred_raster, profile)``. When the result is saved the
            profile's ``dtype`` and ``count`` describe the written prediction.
        """
        print('Loading raster...')
        with rio.open(raster_path) as src:
            raster = src.read()
            # Bands-first to bands-last, keeping at most the first three bands.
            raster = np.moveaxis(raster, 0, 2)[:, :, :3]
            profile = src.profile
            transform = src.transform

        boundary_geom = None
        if boundary_path:
            boundary = gpd.read_file(boundary_path)
            boundary = boundary.to_crs(profile['crs'])
            boundary_geom = boundary.iloc[0].geometry

        tiles = slidingwindow.generate(
            raster,
            slidingwindow.DimOrder.HeightWidthChannel,
            tile_size,
            overlap
        )
        total_tiles = len(tiles)
        pred_raster = np.zeros_like(raster[:, :, 0], dtype='uint8')
        # Start below any possible score so the first tile covering a pixel is
        # always recorded, even when the model emits non-positive scores.
        confidence = np.full(pred_raster.shape, -np.inf, dtype=np.float32)
        aug = get_validation_augmentations(tile_size, tile_size, fixed_size=False)

        for idx, tile in enumerate(tiles):
            if (idx + 1) % 10 == 0 or idx == total_tiles - 1:
                print(f'\rProcessed {idx+1}/{total_tiles} tiles', end='')
            rows, cols = tile.indices()
            tile_image = raster[rows, cols]
            if boundary_geom is not None:
                corners = [
                    convert_coordinates(transform, cols.start, rows.start),
                    convert_coordinates(transform, cols.stop, rows.start),
                    convert_coordinates(transform, cols.stop, rows.stop),
                    convert_coordinates(transform, cols.start, rows.stop)
                ]
                if not Polygon(corners).intersects(boundary_geom):
                    continue

            processed = aug(image=tile_image)['image']
            processed = self.config.preprocessing(image=processed)['image']
            scores = self._predict_scores(processed)
            if self.config.n_classes > 1:
                tile_conf = np.max(scores, axis=0)
            else:
                # Distance from the 0.5 decision threshold.
                tile_conf = np.abs(scores - 0.5)
            tile_pred = self._scores_to_labels(scores)

            tile_height, tile_width = tile_image.shape[:2]
            tile_pred = self._resize_to(tile_pred, tile_height, tile_width,
                                        cv2.INTER_NEAREST)
            tile_conf = self._resize_to(tile_conf, tile_height, tile_width,
                                        cv2.INTER_LINEAR)

            # Slicing with the tile's slices yields views, so the masked
            # assignments below write straight into the full-size arrays.
            pred_view = pred_raster[rows, cols]
            conf_view = confidence[rows, cols]
            better = conf_view < tile_conf
            pred_view[better] = tile_pred[better]
            conf_view[better] = tile_conf[better]

        pred_raster = self._format_prediction(pred_raster, format)
        if output_path or boundary_path:
            self._save_raster_prediction(
                pred_raster, raster_path, output_path,
                profile, boundary_geom
            )
        return pred_raster, profile

    def _format_prediction(self, prediction, format):
        """Convert a label map to the requested output representation.

        Args:
            prediction: Two-dimensional label map.
            format: ``'integer'`` returns the labels as ``uint8``; ``'color'``
                returns a colour-coded image.

        Raises:
            ValueError: If ``format`` is neither ``'integer'`` nor ``'color'``.
        """
        if format == 'integer':
            return prediction.astype('uint8')
        if format == 'color':
            return self._apply_color_mapping(prediction)
        raise ValueError(f"Unsupported format: {format}")

    def _save_prediction(self, prediction, source_path, output_dir, format):
        """Write a prediction to ``<output_dir>/<source name>_pred.png``.

        Raises:
            OSError: If OpenCV reports that the file could not be written.
        """
        filename = os.path.splitext(os.path.basename(source_path))[0]
        output_path = os.path.join(output_dir, f"{filename}_pred.png")
        # NOTE(review): cv2.imwrite treats 3-channel data as BGR; confirm the
        # channel order produced by generate_color_mapping for 'color' output.
        if not cv2.imwrite(output_path, prediction):
            # imwrite returns False rather than raising when the write fails.
            raise OSError(f"Failed to write prediction to: {output_path}")

    @staticmethod
    def _default_raster_output_path(source_path):
        """Return ``source_path`` with its extension replaced by ``_predicted.tif``."""
        # Only the trailing extension is swapped; str.replace would also
        # rewrite matching text elsewhere in the path and misbehaves when the
        # path has no extension at all.
        return os.path.splitext(source_path)[0] + '_predicted.tif'

    def _save_raster_prediction(self, prediction, source_path, output_path,
                                profile, boundary=None):
        """Write a raster prediction with geospatial metadata.

        Args:
            prediction: ``(H, W)`` label map or ``(H, W, 3)`` colour image.
            source_path: Path of the source raster; used to derive the output
                path when ``output_path`` is not given.
            output_path: Destination path, or a falsy value for the default.
            profile: Rasterio profile of the source raster. Its ``dtype`` and
                ``count`` are updated in place to describe the prediction.
            boundary: Optional geometry to crop the written raster to.
        """
        output_path = output_path or self._default_raster_output_path(source_path)
        is_color = prediction.ndim == 3
        profile.update(
            dtype='uint8',
            count=3 if is_color else 1
        )
        with rio.open(output_path, 'w', **profile) as dst:
            if is_color:
                for band in range(3):
                    dst.write(prediction[:, :, band], band + 1)
            else:
                dst.write(prediction, 1)

        if boundary is not None and not boundary.is_empty:
            with rio.open(output_path) as src:
                cropped, transform = riomask.mask(src, [boundary], crop=True)
            # Crop geometry goes into a copy so the caller's profile keeps
            # describing the uncropped array it still holds.
            cropped_profile = dict(profile)
            cropped_profile.update(
                height=cropped.shape[1],
                width=cropped.shape[2],
                transform=transform
            )
            os.remove(output_path)
            with rio.open(output_path, 'w', **cropped_profile) as dst:
                dst.write(cropped)
        print(f'\nPrediction saved to: {output_path}')

    def predict_video_frames(self, input_dir, target_size=None, output_dir=None):
        """Segment video frames and save them with tinted class overlays.

        The first foreground class is blended with red and any further class
        with blue. Frames that cannot be read or whose mask does not match
        the frame are reported and skipped.

        Args:
            input_dir: Directory handed to ``StreamingDataset``; frames are
                re-read from here by file name for the overlay.
            target_size: Optional square size for the validation
                augmentations; no augmentation is applied when omitted.
            output_dir: Destination directory. Defaults to
                ``<input_dir>/predictions``.

        Returns:
            The directory the overlaid frames were written to.
        """
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)
        augmentation = None
        if target_size:
            augmentation = get_validation_augmentations(target_size, target_size)
        dataset = StreamingDataset(
            input_dir,
            classes=self.classes,
            augmentation=augmentation,
            preprocessing=self.config.preprocessing
        )
        red = np.array(self._OVERLAY_RED)
        blue = np.array(self._OVERLAY_BLUE)

        total_frames = len(dataset)
        start_time = time.time()
        for idx in range(total_frames):
            if (idx + 1) % 10 == 0 or idx == total_frames - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_frames} frames in {elapsed:.1f}s', end='')
            frame, height, width = dataset[idx]
            filename = dataset.filenames[idx]

            labels = self._predict_labels(frame)
            labels = self._resize_to(labels, height, width, cv2.INTER_NEAREST)
            # Masks are derived after resizing so they line up with the frame.
            if self.config.n_classes > 1:
                masks = [labels == i for i in range(1, self.config.n_classes)]
            else:
                masks = [labels == 1]

            original = cv2.imread(os.path.join(input_dir, filename))
            if original is None or original.shape[:2] != labels.shape:
                print(f"\nWarning: Error processing frame {filename}")
                continue
            original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)

            for i, mask in enumerate(masks):
                color = red if i == 0 else blue
                # A constant colour broadcasts over the masked pixels, so no
                # frame-sized colour planes (or rotated copies) are needed.
                original[mask, :] = 0.45 * original[mask, :] + 0.55 * color

            output_path = os.path.join(output_dir, filename)
            imageio.imwrite(output_path, original, quality=100)
        print(f'\nProcessed frames saved to: {output_dir}')
        return output_dir

    def _apply_color_mapping(self, prediction):
        """Turn a two-dimensional label map into a ``uint8`` colour image.

        Pixels of a class named ``background`` (case-insensitive) stay black;
        every other class ``i`` is painted with ``self.colors[i]``.
        """
        height, width = prediction.shape
        colored = np.zeros((height, width, 3), dtype='uint8')
        for i, class_name in enumerate(self.classes):
            if class_name.lower() == 'background':
                continue
            colored[prediction == i] = self.colors[i]
        return colored