etemkocaaslan's picture
Update app.py
61e99d3 verified
raw
history blame contribute delete
No virus
4.27 kB
import torch
from torchvision import models, transforms
from PIL import Image, ImageDraw
import gradio as gr
from typing import Union
# Registry of selectable segmentation models: name -> torchvision constructor.
# The keys populate the UI dropdown. They must be lowercase because ModelLoader
# lowercases the requested name before looking it up here.
SEGMENTATION_MODELS = dict(
    deeplabv3_resnet101=models.segmentation.deeplabv3_resnet101,
)
class ModelLoader:
    """Build pretrained models by name from a registry of constructors.

    Built models are cached per loader instance. Repeated requests for the same
    model (e.g. one per button click in the UI) therefore reuse the already
    constructed network instead of rebuilding it every time.
    """

    def __init__(self, model_dict: dict):
        """
        Args:
            model_dict: Mapping of lowercase model name -> constructor. Each
                constructor is called as ``constructor(pretrained=True)`` and the
                result must provide an ``eval()`` method (e.g. a torchvision model).
        """
        self.model_dict = model_dict
        # lowercase name -> model that was already built and switched to eval mode
        self._cache = {}

    def load_model(self, model_name: str) -> torch.nn.Module:
        """Return the pretrained model registered under ``model_name``, in eval mode.

        The lookup is case-insensitive. The first call for a name builds the
        model; later calls return that same cached instance.

        Raises:
            ValueError: if ``model_name`` is not a string (e.g. ``None`` when no
                model was selected) or is not a key of ``model_dict``.
        """
        # Reject non-strings explicitly; otherwise None would fail on .lower()
        # with an AttributeError instead of the documented ValueError.
        if not isinstance(model_name, str):
            raise ValueError(f"Model {model_name} is not supported")
        key = model_name.lower()

        cached = self._cache.get(key)
        if cached is not None:
            return cached

        try:
            model_class = self.model_dict[key]
        except KeyError:
            raise ValueError(f"Model {model_name} is not supported") from None

        model = model_class(pretrained=True)
        model.eval()
        self._cache[key] = model
        return model
class Preprocessor:
    """Turn a PIL image into a normalized tensor with a leading batch dimension."""

    def __init__(self, transform: Union[transforms.Compose, None] = None):
        """
        Args:
            transform: Transform applied to the RGB image. When omitted, a fresh
                ``ToTensor`` + ``Normalize`` pipeline is built for this instance,
                using the ImageNet mean/std that torchvision's pretrained weights
                expect.
        """
        if transform is None:
            # Built per instance rather than as a default argument. A default
            # would be evaluated once at definition time and shared by every
            # instance.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Return ``transform(image)`` with a batch dimension of size 1 prepended.

        Images that are not RGB (e.g. RGBA PNGs, grayscale or palette images) are
        converted to RGB first. The default ``Normalize`` has exactly three
        per-channel statistics and fails on 1- or 4-channel tensors.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        return self.transform(image).unsqueeze(0)
class Postprocessor:
    """Convert per-class model scores into a color-coded palette image."""

    def __init__(self, num_classes: int = 21):
        """
        Args:
            num_classes: Number of class indices to generate palette colors for
                (default 21). Label values at or above this count get no
                dedicated color.
        """
        # Deterministic color per class index: multiply the index by three large
        # constants (one per RGB channel) and wrap each product into 0..254.
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def postprocess(self, output: torch.Tensor) -> Image.Image:
        """Colorize a score tensor into a "P"-mode label image.

        Args:
            output: Score tensor whose dim 0 indexes classes, presumably
                (num_classes, H, W). The per-pixel label is the argmax over dim 0.

        Returns:
            A "P"-mode image whose pixel values are class indices and whose
            palette maps each index to its color in ``self.colors``. Labels are
            cast to uint8, so more than 256 classes would wrap around.
        """
        labels = output.argmax(0).byte().cpu().numpy()
        # Build an 8-bit "L" image and let putpalette switch it to "P" mode. This
        # gives the same image as fromarray(..., mode='P'), whose mode argument
        # recent Pillow releases deprecate.
        colorized_output = Image.fromarray(labels)
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output
class Segmentation:
    """End-to-end pipeline: load model, preprocess image, run inference, colorize."""

    def __init__(self, model_loader: ModelLoader, preprocessor: Preprocessor, postprocessor: Postprocessor):
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def segment(self, image: Image.Image, selected_model: str) -> Image.Image:
        """Segment ``image`` with the model named ``selected_model``.

        Inference runs on CUDA when it is available, otherwise on the CPU.

        Returns:
            The colorized label image produced by the postprocessor.

        Raises:
            ValueError: if ``image`` is None (presumably what the UI sends when
                no image was uploaded), or if the model loader rejects
                ``selected_model``.
        """
        # Validate before building the model, so an empty submission fails fast
        # with a clear message rather than deep inside the transform pipeline.
        if image is None:
            raise ValueError("No image provided; please upload an image first.")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = self.model_loader.load_model(selected_model).to(device)
        input_tensor = self.preprocessor.preprocess(image).to(device)
        with torch.no_grad():
            # Assumes a torchvision-style segmentation model returning a dict with
            # an 'out' entry. [0] drops the batch dimension of size 1.
            output = model(input_tensor)['out'][0]
        return self.postprocessor.postprocess(output)
class GradioApp:
    """Gradio Blocks user interface around a Segmentation pipeline."""

    def __init__(self, segmentation: Segmentation):
        self.segmentation = segmentation

    def launch(self):
        """Build the interface and call ``demo.launch()`` on it."""
        model_choices = list(SEGMENTATION_MODELS.keys())
        with gr.Blocks() as demo:
            gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>Deeplabv3 Segmentation</h1>")
            gr.Markdown("<p style='text-align: center;'>Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.</p>")
            with gr.Row():
                with gr.Column():
                    upload_image = gr.Image(type='pil', label="Upload Image")
                    # Preselect the first model explicitly. Without a value, some
                    # Gradio versions submit None, which the model loader cannot
                    # resolve.
                    self.model_dropdown = gr.Dropdown(
                        choices=model_choices,
                        value=model_choices[0] if model_choices else None,
                        label="Select Model",
                    )
                    segment_button = gr.Button("Segment")
                with gr.Column():
                    output_image = gr.Image(type='pil', label="Segmented Output")
            segment_button.click(fn=self.segmentation.segment, inputs=[upload_image, self.model_dropdown], outputs=output_image)
            gr.Markdown("### Example Images")
            gr.Examples(
                examples=[
                    ["https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"]
                ],
                inputs=upload_image,
                outputs=output_image,
                label="Click an example to use it"
            )
        demo.launch()
if __name__ == "__main__":
    # Assemble the pipeline from its three stages, then start the web UI.
    pipeline = Segmentation(
        model_loader=ModelLoader(SEGMENTATION_MODELS),
        preprocessor=Preprocessor(),
        postprocessor=Postprocessor(),
    )
    GradioApp(pipeline).launch()