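"""Semantic segmentation demo.

Loads a pretrained torchvision DeepLabV3 model, segments uploaded images,
and serves the result through a small Gradio interface.
"""
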
import torch
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from typing import Optional

SEGMENTATION_MODELS = {
    "deeplabv3_resnet101": models.segmentation.deeplabv3_resnet101,
}

class ModelLoader:
    """Builds pretrained torchvision segmentation models by name."""

    def __init__(self, model_dict: dict):
        self.model_dict = model_dict

    def load_model(self, model_name: str) -> torch.nn.Module:
        model_name_lower = model_name.lower()
        if model_name_lower not in self.model_dict:
            raise ValueError(f"Model {model_name} is not supported")
        model_class = self.model_dict[model_name_lower]
        # torchvision >= 0.13 deprecates pretrained=True in favor of the
        # weights API; "DEFAULT" selects the best available pretrained weights.
        model = model_class(weights="DEFAULT")
        model.eval()
        return model
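
# Note: load_model instantiates (and may re-download) weights on every call.
# For a long-running app, caching loaded models, e.g. in a dict keyed by
# model name, would avoid redundant work per request; omitted for brevity.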

class Preprocessor:
    """Converts a PIL image into a normalized, batched tensor."""

    def __init__(self, transform: Optional[transforms.Compose] = None):
        # Default to the ImageNet statistics that torchvision's pretrained
        # segmentation models expect. Building the Compose here avoids the
        # shared-default-argument pitfall of constructing it in the signature.
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        # unsqueeze(0) adds the batch dimension: (C, H, W) -> (1, C, H, W).
        return self.transform(image).unsqueeze(0)

class Postprocessor:
    """Maps per-pixel class predictions to a colorized PIL image."""

    def __init__(self):
        # Generate a distinct color for each of the 21 Pascal VOC classes
        # (the palette trick used in the PyTorch DeepLabV3 tutorial).
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors = torch.as_tensor(list(range(21)))[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def postprocess(self, output: torch.Tensor) -> Image.Image:
        # argmax over the class dimension yields an (H, W) map of class ids.
        output_predictions = output.argmax(0)
        colorized_output = Image.fromarray(output_predictions.byte().cpu().numpy(), mode='P')
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output

class Segmentation:
    """Runs the full pipeline: preprocess, infer, colorize."""

    def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def segment(self, image: Image.Image, selected_model: str) -> Image.Image:
        model = self.model_loader.load_model(selected_model)
        # Force RGB: uploads may arrive as RGBA or grayscale, which would
        # break the 3-channel normalization.
        input_tensor = self.preprocessor.preprocess(image.convert("RGB"))

        if torch.cuda.is_available():
            input_tensor = input_tensor.to("cuda")
            model = model.to("cuda")

        with torch.no_grad():
            output = model(input_tensor)['out'][0]
        return self.postprocessor.postprocess(output)
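
# Minimal standalone usage sketch (file names here are hypothetical, not
# part of the app):
#
#   seg = Segmentation(ModelLoader(SEGMENTATION_MODELS), Preprocessor(), Postprocessor())
#   mask = seg.segment(Image.open("example.jpg"), "deeplabv3_resnet101")
#   mask.save("mask.png")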

class GradioApp:
    """Wires the segmentation pipeline into a Gradio Blocks UI."""

    def __init__(self, segmentation: Segmentation):
        self.segmentation = segmentation

    def launch(self):
        with gr.Blocks() as demo:
            gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>DeepLabV3 Segmentation</h1>")
            gr.Markdown("<p style='text-align: center;'>Upload an image to perform semantic segmentation using DeepLabV3 ResNet101.</p>")
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    # Preselect the first model so clicking "Segment" without
                    # touching the dropdown never passes None to load_model.
                    model_dropdown = gr.Dropdown(
                        choices=list(SEGMENTATION_MODELS.keys()),
                        value=list(SEGMENTATION_MODELS.keys())[0],
                        label="Select Model",
                    )
                    segment_button = gr.Button("Segment")
                with gr.Column():
                    output_image = gr.Image(type='pil', label="Segmented Output")
            segment_button.click(fn=self.segmentation.segment, inputs=[upload_image, model_dropdown], outputs=output_image)
            gr.Markdown("### Example Images")
            # With no fn, gr.Examples simply populates the input on click.
            gr.Examples(
                examples=[
                    ["https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"]
                ],
                inputs=upload_image,
                label="Click an example to use it"
            )
        demo.launch()

if __name__ == "__main__":
    model_loader = ModelLoader(SEGMENTATION_MODELS)
    preprocessor = Preprocessor()
    postprocessor = Postprocessor()
    segmentation = Segmentation(model_loader, preprocessor, postprocessor)
    app = GradioApp(segmentation)
    app.launch()