| import torch | |
| from torchvision import models, transforms | |
| from PIL import Image, ImageDraw | |
| import gradio as gr | |
| from typing import Union | |
# Registry of supported segmentation networks: dropdown label (lower-case)
# -> torchvision constructor callable. Add entries here to expose more models
# in the UI; ModelLoader looks names up case-insensitively against these keys.
SEGMENTATION_MODELS = {
    "deeplabv3_resnet101": models.segmentation.deeplabv3_resnet101,}
class ModelLoader:
    """Loads pretrained segmentation models by name, caching each one.

    Without the cache, every UI click rebuilt the network and potentially
    re-downloaded its weights; loaded models are now reused per name.
    """

    def __init__(self, model_dict: dict):
        # Maps lower-case model names to torchvision constructor callables.
        self.model_dict = model_dict
        # Name -> already-constructed, eval-mode model instance.
        self._cache: dict = {}

    def load_model(self, model_name: str) -> torch.nn.Module:
        """Return an eval-mode model for *model_name* (case-insensitive).

        Raises:
            ValueError: if the name is not in the registry.
        """
        key = model_name.lower()
        if key in self._cache:
            return self._cache[key]
        try:
            model_class = self.model_dict[key]
        except KeyError:
            raise ValueError(f"Model {model_name} is not supported") from None
        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
        # in favour of `weights=...`; kept for compatibility with the version
        # this file was written against -- confirm before upgrading.
        model = model_class(pretrained=True)
        model.eval()
        self._cache[key] = model
        return model
class Preprocessor:
    """Converts a PIL image into a normalized, batched (1, 3, H, W) tensor."""

    # ImageNet statistics expected by torchvision's pretrained weights.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, transform: "transforms.Compose | None" = None):
        # Build the default pipeline here rather than as a default argument:
        # a default argument is evaluated once at definition time and shared
        # by every instance (mutable-default anti-pattern).
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=self._MEAN, std=self._STD),
            ])
        self.transform = transform

    def preprocess(self, image: "Image.Image") -> torch.Tensor:
        """Return the transformed image with a leading batch dimension.

        The image is coerced to RGB first: uploads may be RGBA or grayscale,
        and Normalize with 3-element mean/std fails on any other channel
        count.
        """
        return self.transform(image.convert("RGB")).unsqueeze(0)
class Postprocessor:
    """Turns raw model logits into a palette-colorized segmentation image."""

    def __init__(self):
        # Deterministic pseudo-random palette: one RGB triple per class id,
        # covering the 21 Pascal-VOC classes the torchvision models predict.
        base = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        class_ids = torch.arange(21).unsqueeze(1)
        self.colors = ((class_ids * base) % 255).numpy().astype("uint8")

    def postprocess(self, output: torch.Tensor) -> Image.Image:
        """Pick the argmax class per pixel and render it as a 'P'-mode image."""
        labels = output.argmax(0).byte().cpu().numpy()
        segmented = Image.fromarray(labels, mode='P')
        segmented.putpalette(self.colors.ravel())
        return segmented
class Segmentation:
    """End-to-end pipeline: load model, preprocess, forward pass, colorize."""

    def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def segment(self, image: Image.Image, selected_model: str) -> Image.Image:
        """Run semantic segmentation on *image* using *selected_model*."""
        net = self.model_loader.load_model(selected_model)
        batch = self.preprocessor.preprocess(image)
        # Move both input and model to the GPU when one is present;
        # otherwise everything stays on the CPU.
        if torch.cuda.is_available():
            batch, net = batch.to("cuda"), net.to("cuda")
        with torch.no_grad():
            logits = net(batch)['out'][0]
        return self.postprocessor.postprocess(logits)
class GradioApp:
    """Wires the Segmentation pipeline into a Gradio Blocks UI."""
    def __init__(self, segmentation: Segmentation):
        # Pipeline whose .segment method serves as the button click handler.
        self.segmentation = segmentation
    def launch(self):
        """Build the Blocks layout and start the (blocking) Gradio server."""
        with gr.Blocks() as demo:
            gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>Deeplabv3 Segmentation</h1>")
            gr.Markdown("<p style='text-align: center;'>Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.</p>")
            # Two-column layout: inputs on the left, result on the right.
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    # Choices come from the module-level registry; stored on
                    # self, though nothing else currently reads it.
                    self.model_dropdown = gr.Dropdown(choices=list(SEGMENTATION_MODELS.keys()), label="Select Model")
                    segment_button = gr.Button("Segment")
                with gr.Column():
                    output_image = gr.Image(type='pil', label="Segmented Output")
            # Clicking runs segment(image, model_name) and shows the result.
            segment_button.click(fn=self.segmentation.segment, inputs=[upload_image, self.model_dropdown], outputs=output_image)
            gr.Markdown("### Example Images")
            # NOTE(review): Examples is given `outputs` but no `fn`, so
            # clicking an example presumably only fills the input image
            # without auto-running segmentation -- confirm against the
            # pinned gradio version.
            gr.Examples(
                examples=[
                    ["https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"]
                ],
                inputs=upload_image,
                outputs=output_image,
                label="Click an example to use it"
            )
        demo.launch()
if __name__ == "__main__":
    # Assemble the pipeline from its parts and hand it to the UI layer.
    loader = ModelLoader(SEGMENTATION_MODELS)
    pipeline = Segmentation(loader, Preprocessor(), Postprocessor())
    GradioApp(pipeline).launch()