etemkocaaslan commited on
Commit
f252cc2
1 Parent(s): 1ec05ea

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from typing import Union
6
+
7
+ class Preprocessor:
8
+ def __init__(self):
9
+ """
10
+ Initialize the preprocessing transformations.
11
+ """
12
+ self.transform = transforms.Compose([
13
+ transforms.ToTensor(),
14
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
15
+ ])
16
+
17
+ def __call__(self, image: Image.Image) -> torch.Tensor:
18
+ """
19
+ Apply preprocessing to the input image.
20
+
21
+ :param image: Input image to be preprocessed.
22
+ :return: Preprocessed image as a tensor.
23
+ """
24
+ return self.transform(image)
25
+
26
+ class SegmentationModel:
27
+ def __init__(self):
28
+ """
29
+ Initialize and load the DeepLabV3 ResNet101 model.
30
+ """
31
+ self.model = models.segmentation.deeplabv3_resnet101(pretrained=True)
32
+ self.model.eval()
33
+ if torch.cuda.is_available():
34
+ self.model.to('cuda')
35
+
36
+ def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
37
+ """
38
+ Perform inference using the model on the input batch.
39
+
40
+ :param input_batch: Batch of preprocessed images.
41
+ :return: Model output tensor.
42
+ """
43
+ with torch.no_grad():
44
+ if torch.cuda.is_available():
45
+ input_batch = input_batch.to('cuda')
46
+ output: torch.Tensor = self.model(input_batch)['out'][0]
47
+ return output
48
+
49
+ class OutputColorizer:
50
+ def __init__(self):
51
+ """
52
+ Initialize the color palette for segmentations.
53
+ """
54
+ palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
55
+ colors : torch.Tensor = torch.as_tensor([i for i in range(21)])[:, None] * palette
56
+ self.colors = (colors % 255).numpy().astype("uint8")
57
+
58
+ def colorize(self, output: torch.Tensor) -> Image.Image:
59
+ """
60
+ Apply colorization to the segmentation output.
61
+
62
+ :param output: Segmentation output tensor.
63
+ :return: Colorized segmentation image.
64
+ """
65
+ colorized_output = Image.fromarray(output.byte().cpu().numpy(), mode='P')
66
+ colorized_output.putpalette(self.colors.ravel())
67
+ return colorized_output
68
+
69
+ class Segmenter:
70
+ def __init__(self):
71
+ """
72
+ Initialize the Segmenter with Preprocessor, SegmentationModel, and OutputColorizer.
73
+ """
74
+ self.preprocessor = Preprocessor()
75
+ self.model = SegmentationModel()
76
+ self.colorizer = OutputColorizer()
77
+
78
+ def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
79
+ """
80
+ Perform the complete segmentation process on the input image.
81
+
82
+ :param image: Input image to be segmented.
83
+ :return: Colorized segmentation image.
84
+ """
85
+ input_image: Image.Image = image.convert("RGB")
86
+ input_tensor: torch.Tensor = self.preprocessor(input_image)
87
+ input_batch: torch.Tensor = input_tensor.unsqueeze(0)
88
+ output: torch.Tensor = self.model.predict(input_batch)
89
+ output_predictions: torch.Tensor = output.argmax(0)
90
+ return self.colorizer.colorize(output_predictions)
91
+
92
+ segmenter = Segmenter()
93
+
94
+ interface = gr.Interface(
95
+ fn=segmenter.segment,
96
+ inputs=gr.Image(type="pil"),
97
+ outputs=gr.Image(type="pil"),
98
+ title="Deeplabv3 Segmentation",
99
+ description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101."
100
+ )
101
+
102
+ if __name__ == "__main__":
103
+ interface.launch()