"""Gradio demo that runs Segformer semantic segmentation on an image."""

import random
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor

# Example images offered in the demo UI.
image_list = [
    "data/1.png",
    "data/2.png",
    "data/3.png",
    "data/4.png",
]

# Model checkpoints selectable in the dropdown.
model_path = ['deprem-ml/deprem_satellite_semantic_whu']


def visualize_instance_seg_mask(mask):
    """Colorize a 2-D label mask with one random color per distinct label.

    Args:
        mask: 2-D array of labels, indexed as ``mask[row, column]``.

    Returns:
        A float array of shape ``(rows, columns, 3)`` with values in
        ``[0, 1]``; every pixel carrying the same label gets the same color.
        Colors are drawn with ``random.randint`` and therefore differ
        between calls unless the ``random`` module is seeded.
    """
    mask = np.asarray(mask)

    # `inverse` holds, for every pixel, the index of its label in `labels`.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One (r, g, b) row per label. The reshape keeps the palette 2-D even
    # when the mask is empty and there are no labels at all.
    palette = np.array(
        [[random.randint(0, 255) for _ in range(3)] for _ in labels],
        dtype=np.float64,
    ).reshape(-1, 3)

    # Depending on the NumPy version `inverse` is flat or mask-shaped;
    # reshaping makes the palette lookup produce (rows, columns, 3) in both.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255


@lru_cache(maxsize=4)
def _load_model_and_processor(model_id):
    """Load (and cache) the Segformer model and image processor for *model_id*."""
    model = SegformerForSemanticSegmentation.from_pretrained(model_id)
    try:
        processor = SegformerImageProcessor.from_pretrained(model_id)
    except OSError:
        # The checkpoint could not supply a preprocessing config; fall back
        # to the processor's default settings.
        processor = SegformerImageProcessor()
    return model, processor


def Segformer_Segmentation(image_path, model_id):
    """Segment the image at *image_path* and save a colorized mask.

    Args:
        image_path: Path of the input image file.
        model_id: Identifier passed to ``from_pretrained`` to load the
            Segformer checkpoint.

    Returns:
        A tuple ``(image_path, output_save)``: the unchanged input path and
        the path of the written mask image (``"output.png"``).

    Raises:
        ValueError: If *image_path* is ``None``.
    """
    if image_path is None:
        raise ValueError("No input image was provided.")

    # NOTE: the output path is fixed, so each call overwrites the last mask.
    output_save = "output.png"

    # Close the file handle promptly and normalize to three channels.
    with Image.open(image_path) as source:
        test_image = source.convert("RGB")

    model, processor = _load_model_and_processor(model_id)
    inputs = processor(images=test_image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); target_sizes takes (height, width).
    # Requesting it makes the mask match the input image's resolution.
    height_width = test_image.size[::-1]
    result = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[height_width]
    )[0]
    result = result.cpu().numpy()
    result = visualize_instance_seg_mask(result)

    # Write an explicit 8-bit image rather than a float64 array.
    cv2.imwrite(output_save, np.rint(result * 255).astype(np.uint8))

    return image_path, output_save


# Every example image is paired with the (single) available checkpoint.
examples = [[path, model_path[0]] for path in image_list]

title = "Deprem ML - Segformer Semantic Segmentation"

app = gr.Blocks()
with app:
    # Escaped newlines keep the literal on one line; the characters are the
    # same ones the original multi-line literal contained.
    gr.HTML("\n\n{}\n\n".format(title))
    with gr.Row():
        with gr.Column():
            gr.Markdown("Video")
            input_video = gr.Image(type='filepath')
            model_id = gr.Dropdown(value=model_path[0], choices=model_path)
            input_video_button = gr.Button(value="Predict")

        with gr.Column():
            output_orijinal_image = gr.Image(type='filepath')

        with gr.Column():
            output_mask_image = gr.Image(type='filepath')

    gr.Examples(
        examples,
        inputs=[input_video, model_id],
        outputs=[output_orijinal_image, output_mask_image],
        fn=Segformer_Segmentation,
        cache_examples=True,
    )
    input_video_button.click(
        Segformer_Segmentation,
        inputs=[input_video, model_id],
        outputs=[output_orijinal_image, output_mask_image],
    )

app.launch()