import torch
from torchvision import models, transforms
from PIL import Image, ImageDraw
import gradio as gr
from typing import Union

# Registry of selectable segmentation models: lowercase name -> torchvision factory.
SEGMENTATION_MODELS = {
    "deeplabv3_resnet101": models.segmentation.deeplabv3_resnet101,
}


class ModelLoader:
    """Build pretrained segmentation models by name and cache them for reuse."""

    def __init__(self, model_dict: dict):
        """
        Args:
            model_dict: Mapping of lowercase model name to a torchvision model factory.
        """
        self.model_dict = model_dict
        # Building a model loads its pretrained weights, which is slow; each model
        # is therefore built once per loader and reused on later calls.
        self._cache: dict = {}

    def load_model(self, model_name: str) -> torch.nn.Module:
        """Return the pretrained model registered under ``model_name``, in eval mode.

        The lookup is case-insensitive. Repeated calls with the same name return
        the same cached module object.

        Raises:
            ValueError: if ``model_name`` is empty/None or not in ``model_dict``.
        """
        if not model_name:
            raise ValueError("No model selected")
        key = model_name.lower()
        if key not in self.model_dict:
            raise ValueError(f"Model {model_name} is not supported")
        if key not in self._cache:
            model = self._build_pretrained(self.model_dict[key])
            model.eval()
            self._cache[key] = model
        return self._cache[key]

    @staticmethod
    def _build_pretrained(factory) -> torch.nn.Module:
        """Instantiate ``factory`` with its default pretrained weights."""
        try:
            # torchvision >= 0.13 deprecates `pretrained` in favour of `weights`.
            return factory(weights="DEFAULT")
        except TypeError:
            # Older torchvision versions do not accept the `weights` keyword.
            return factory(pretrained=True)


class Preprocessor:
    """Convert a PIL image into a normalized, batched tensor for the model."""

    def __init__(self, transform: Union[transforms.Compose, None] = None):
        """
        Args:
            transform: Transform applied to each image. When omitted, uses
                ToTensor followed by Normalize with the ImageNet mean/std below.
        """
        if transform is None:
            # Built per instance rather than as a shared default-argument object.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform to ``image`` and add a leading batch dimension of size 1."""
        # The default Normalize uses 3-channel statistics, so RGBA/grayscale/palette
        # images must be converted to RGB first or the transform fails.
        if image.mode != "RGB":
            image = image.convert("RGB")
        return self.transform(image).unsqueeze(0)


class Postprocessor:
    """Turn raw per-class scores into a color-coded palette image."""

    def __init__(self, num_classes: int = 21):
        """
        Args:
            num_classes: Number of class indices to generate palette colors for.
        """
        # Multiplying each class index by large constants and wrapping modulo 255
        # gives a cheap, deterministic RGB triple per class.
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def postprocess(self, output: torch.Tensor) -> Image.Image:
        """Color-code the argmax over dim 0 of ``output``.

        Args:
            output: Per-class scores with classes on dim 0 (presumably (C, H, W)
                for a single image — confirm against the model).

        Returns:
            A mode "P" image whose pixel values are the winning class indices.
        """
        class_map = output.argmax(0).byte().cpu().numpy()
        # The `mode` argument of Image.fromarray is deprecated in recent Pillow;
        # putpalette converts the resulting "L" image to "P" itself.
        colorized = Image.fromarray(class_map)
        colorized.putpalette(self.colors.ravel())
        return colorized


class Segmentation:
    """Run the load -> preprocess -> infer -> postprocess pipeline."""

    def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor,
                 postprocessor: Postprocessor,
                 device: Union[str, torch.device, None] = None):
        """
        Args:
            model_loader: Supplies models by name.
            preprocessor: Converts input images to model tensors.
            postprocessor: Converts model output to an image.
            device: Device to run inference on; defaults to CUDA when available,
                otherwise CPU.
        """
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def segment(self, image: Image.Image, selected_model: str) -> Image.Image:
        """Segment ``image`` with the model named ``selected_model``.

        Raises:
            ValueError: if no image is given or the model name is missing/unsupported.
        """
        if image is None:
            raise ValueError("Please upload an image first")
        # Module.to moves parameters in place; once the cached model is on the
        # target device, later calls are effectively no-ops.
        model = self.model_loader.load_model(selected_model).to(self.device)
        input_tensor = self.preprocessor.preprocess(image).to(self.device)
        with torch.no_grad():
            output = model(input_tensor)["out"][0]
        return self.postprocessor.postprocess(output)


class GradioApp:
    """Gradio Blocks user interface around a Segmentation pipeline."""

    def __init__(self, segmentation: Segmentation):
        self.segmentation = segmentation

    def _run_segmentation(self, image: Image.Image, selected_model: str) -> Image.Image:
        """Click handler: run segmentation and show validation errors in the UI."""
        try:
            return self.segmentation.segment(image, selected_model)
        except ValueError as err:
            # gr.Error displays its message to the user instead of a generic failure.
            raise gr.Error(str(err)) from err

    def launch(self, **launch_kwargs):
        """Build the interface and start the server.

        Args:
            **launch_kwargs: Forwarded unchanged to ``demo.launch``.
        """
        model_names = list(SEGMENTATION_MODELS.keys())
        with gr.Blocks() as demo:
            gr.Markdown("# Deeplabv3 Segmentation")
            gr.Markdown("Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.")
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    # Pre-select a model so the handler never receives None.
                    self.model_dropdown = gr.Dropdown(
                        choices=model_names,
                        value=model_names[0] if model_names else None,
                        label="Select Model",
                    )
                    segment_button = gr.Button("Segment")
                with gr.Column():
                    output_image = gr.Image(type='pil', label="Segmented Output")
            segment_button.click(
                fn=self._run_segmentation,
                inputs=[upload_image, self.model_dropdown],
                outputs=output_image,
            )
            gr.Markdown("### Example Images")
            gr.Examples(
                examples=[
                    ["https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"],
                ],
                inputs=upload_image,
                outputs=output_image,
                label="Click an example to use it",
            )
        demo.launch(**launch_kwargs)


if __name__ == "__main__":
    model_loader = ModelLoader(SEGMENTATION_MODELS)
    preprocessor = Preprocessor()
    postprocessor = Postprocessor()
    segmentation = Segmentation(model_loader, preprocessor, postprocessor)
    app = GradioApp(segmentation)
    app.launch()