import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Define the custom U-Net-style model (encoder / middle / decoder, no skip connections)
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        # Encoder: two 3x3 convs, then downsample by 2
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Middle block: two 3x3 convs, then downsample by 2 again
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Decoder: two 3x3 convs, then a single stride-2 transposed conv
        # (only one upsampling step, so the output is half the input resolution)
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2)
        )
        self.final_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        enc = self.encoder(x)
        mid = self.middle(enc)
        dec = self.decoder(mid)
        output = self.final_conv(dec)
        return self.sigmoid(output)
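# Shape note: with two MaxPool2d downsamples and only one stride-2 ConvTranspose2d,
# a 256x256 input yields a 128x128 mask (half the input resolution). An illustrative
# sanity check, kept commented out so it does not run at import time:
#
# with torch.no_grad():
#     assert UNet()(torch.randn(1, 3, 256, 256)).shape == (1, 1, 128, 128)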
# Initialize Model
model = UNet(in_channels=3, out_channels=1)
model.eval()
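# Note: no trained weights are loaded above, so predictions come from randomly
# initialized parameters. A minimal sketch of restoring a checkpoint, assuming a
# hypothetical file "unet_flood_weights.pth" saved via torch.save(model.state_dict(), ...):
#
# state_dict = torch.load("unet_flood_weights.pth", map_location="cpu")
# model.load_state_dict(state_dict)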
# Preprocess Images
def preprocess_image(image):
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(image).unsqueeze(0)
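# The helper above returns a batched tensor of shape (1, 3, 256, 256). An illustrative
# check with a blank RGB image (assumption: any RGB PIL image works as input):
#
# dummy = Image.new("RGB", (512, 512))
# assert preprocess_image(dummy).shape == (1, 3, 256, 256)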
# Prediction Function
def predict_flood(image_terrain, image_rainfall):
    # Gradio passes PIL images here (type="pil"), so no Image.open() is needed
    image_terrain = image_terrain.convert("RGB")
    image_rainfall = image_rainfall.convert("RGB")
    terrain_tensor = preprocess_image(image_terrain)
    rainfall_tensor = preprocess_image(image_rainfall)
    # Simple fusion: average the two normalized tensors into one 3-channel input
    combined_tensor = (terrain_tensor + rainfall_tensor) / 2
    with torch.no_grad():
        output = model(combined_tensor)
    # Threshold the sigmoid output into a binary flood mask
    output_predictions = (output.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(output_predictions, cmap='jet', alpha=0.5)
    ax.set_title("Predicted Flooded Area")
    ax.axis("off")
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    ax.margins(0, 0)
    # Render the figure and convert it to a PIL image
    # (buffer_rgba() replaces the removed tostring_rgb() API)
    fig.canvas.draw()
    output_image = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert("RGB")
    plt.close(fig)
    return output_image
# Gradio Interface
def create_gradio_interface():
    inputs = [
        gr.Image(type="pil", label="Upload Terrain Image (RGB)"),
        gr.Image(type="pil", label="Upload Rainfall Image (RGB)")
    ]
    outputs = gr.Image(type="pil", label="Flood Prediction Output")
    gr.Interface(
        fn=predict_flood,
        inputs=inputs,
        outputs=outputs,
        live=True,
        description="Upload terrain and rainfall images to predict flood areas."
    ).launch()

if __name__ == "__main__":
    create_gradio_interface()