etemkocaaslan commited on
Commit
61e99d3
1 Parent(s): 6ac4bd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -56
app.py CHANGED
@@ -1,80 +1,82 @@
1
  import torch
2
  from torchvision import models, transforms
3
- from PIL import Image
4
  import gradio as gr
5
  from typing import Union
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class Preprocessor:
8
- def __init__(self):
9
- self.transform = transforms.Compose([
10
  transforms.ToTensor(),
11
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
12
- ])
13
-
14
- def __call__(self, image: Image.Image) -> torch.Tensor:
15
- return self.transform(image)
16
-
17
- class SegmentationModel:
18
- def __init__(self):
19
- self.model = models.segmentation.deeplabv3_resnet101(pretrained=True)
20
- self.model.eval()
21
- if torch.cuda.is_available():
22
- self.model.to('cuda')
23
-
24
- def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
25
- with torch.no_grad():
26
- if torch.cuda.is_available():
27
- input_batch = input_batch.to('cuda')
28
- output: torch.Tensor = self.model(input_batch)['out'][0]
29
- return output
30
 
31
- class OutputColorizer:
32
  def __init__(self):
33
  palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
34
  colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
35
  self.colors = (colors % 255).numpy().astype("uint8")
36
-
37
- def colorize(self, output: torch.Tensor) -> Image.Image:
38
- colorized_output = Image.fromarray(output.byte().cpu().numpy(), mode='P')
 
39
  colorized_output.putpalette(self.colors.ravel())
40
  return colorized_output
41
 
42
- class Segmenter:
43
- def __init__(self):
44
- self.preprocessor = Preprocessor()
45
- self.model = SegmentationModel()
46
- self.colorizer = OutputColorizer()
 
 
 
 
 
 
 
 
47
 
48
- def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
49
- input_image: Image.Image = image.convert("RGB")
50
- input_tensor: torch.Tensor = self.preprocessor(input_image)
51
- input_batch: torch.Tensor = input_tensor.unsqueeze(0)
52
- output: torch.Tensor = self.model.predict(input_batch)
53
- output_predictions: torch.Tensor = output.argmax(0)
54
- return self.colorizer.colorize(output_predictions)
55
 
56
  class GradioApp:
57
- def __init__(self, segmenter: Segmenter):
58
- self.segmenter = segmenter
59
-
60
  def launch(self):
61
  with gr.Blocks() as demo:
62
  gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>Deeplabv3 Segmentation</h1>")
63
  gr.Markdown("<p style='text-align: center;'>Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.</p>")
64
- gr.Markdown("""
65
- ### Model Information
66
- **DeepLabv3 with ResNet101** is a convolutional neural network model designed for semantic image segmentation.
67
- It utilizes atrous convolution to capture multi-scale context by using different atrous rates.
68
- """)
69
  with gr.Row():
70
  with gr.Column():
71
- image_input = gr.Image(type='pil', label="Input Image", show_label=False)
 
 
72
  with gr.Column():
73
- image_output = gr.Image(type='pil', label="Segmented Output", show_label=False)
74
-
75
- button = gr.Button("Segment")
76
- button.click(fn=self.segmenter.segment, inputs=image_input, outputs=image_output)
77
-
78
  gr.Markdown("### Example Images")
79
  gr.Examples(
80
  examples=[
@@ -82,14 +84,16 @@ class GradioApp:
82
  ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
83
  ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"]
84
  ],
85
- inputs=image_input,
86
- outputs=image_output,
87
  label="Click an example to use it"
88
  )
89
-
90
  demo.launch()
91
 
92
  if __name__ == "__main__":
93
- segmenter = Segmenter()
94
- app = GradioApp(segmenter)
 
 
 
95
  app.launch()
 
1
  import torch
2
  from torchvision import models, transforms
3
+ from PIL import Image, ImageDraw
4
  import gradio as gr
5
  from typing import Union
6
 
7
+ SEGMENTATION_MODELS = {
8
+ "deeplabv3_resnet101": models.segmentation.deeplabv3_resnet101,}
9
+
10
+ class ModelLoader:
11
+ def __init__(self, model_dict: dict):
12
+ self.model_dict = model_dict
13
+
14
+ def load_model(self, model_name: str) -> torch.nn.Module:
15
+ model_name_lower = model_name.lower()
16
+ if model_name_lower in self.model_dict:
17
+ model_class = self.model_dict[model_name_lower]
18
+ model = model_class(pretrained=True)
19
+ model.eval()
20
+ return model
21
+ else:
22
+ raise ValueError(f"Model {model_name} is not supported")
23
+
24
  class Preprocessor:
25
+ def __init__(self, transform: transforms.Compose = transforms.Compose([
 
26
  transforms.ToTensor(),
27
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
28
+ ])):
29
+ self.transform = transform
30
+
31
+ def preprocess(self, image: Image.Image) -> torch.Tensor:
32
+ return self.transform(image).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ class Postprocessor:
35
  def __init__(self):
36
  palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
37
  colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
38
  self.colors = (colors % 255).numpy().astype("uint8")
39
+
40
+ def postprocess(self, output: torch.Tensor) -> Image.Image:
41
+ output_predictions = output.argmax(0)
42
+ colorized_output = Image.fromarray(output_predictions.byte().cpu().numpy(), mode='P')
43
  colorized_output.putpalette(self.colors.ravel())
44
  return colorized_output
45
 
46
+ class Segmentation:
47
+ def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
48
+ self.model_loader = model_loader
49
+ self.preprocessor = preprocessor
50
+ self.postprocessor = postprocessor
51
+
52
+ def segment(self, image: Image.Image, selected_model: str) -> Image.Image:
53
+ model = self.model_loader.load_model(selected_model)
54
+ input_tensor = self.preprocessor.preprocess(image)
55
+
56
+ if torch.cuda.is_available():
57
+ input_tensor = input_tensor.to("cuda")
58
+ model = model.to("cuda")
59
 
60
+ with torch.no_grad():
61
+ output = model(input_tensor)['out'][0]
62
+ return self.postprocessor.postprocess(output)
 
 
 
 
63
 
64
  class GradioApp:
65
+ def __init__(self, segmentation: Segmentation):
66
+ self.segmentation = segmentation
67
+
68
  def launch(self):
69
  with gr.Blocks() as demo:
70
  gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>Deeplabv3 Segmentation</h1>")
71
  gr.Markdown("<p style='text-align: center;'>Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.</p>")
 
 
 
 
 
72
  with gr.Row():
73
  with gr.Column():
74
+ upload_image = gr.Image(type='pil', label="Upload Image")
75
+ self.model_dropdown = gr.Dropdown(choices=list(SEGMENTATION_MODELS.keys()), label="Select Model")
76
+ segment_button = gr.Button("Segment")
77
  with gr.Column():
78
+ output_image = gr.Image(type='pil', label="Segmented Output")
79
+ segment_button.click(fn=self.segmentation.segment, inputs=[upload_image, self.model_dropdown], outputs=output_image)
 
 
 
80
  gr.Markdown("### Example Images")
81
  gr.Examples(
82
  examples=[
 
84
  ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
85
  ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"]
86
  ],
87
+ inputs=upload_image,
88
+ outputs=output_image,
89
  label="Click an example to use it"
90
  )
 
91
  demo.launch()
92
 
93
  if __name__ == "__main__":
94
+ model_loader = ModelLoader(SEGMENTATION_MODELS)
95
+ preprocessor = Preprocessor()
96
+ postprocessor = Postprocessor()
97
+ segmentation = Segmentation(model_loader, preprocessor, postprocessor)
98
+ app = GradioApp(segmentation)
99
  app.launch()