kadirnar's picture
Update app.py
4551826
raw history blame
No virus
2.86 kB
from transformers import SegformerForSemanticSegmentation
from transformers import SegformerImageProcessor
from PIL import Image
import gradio as gr
import numpy as np
import random
import cv2
import torch
# Sample images (data/1.png .. data/4.png) used to build the example rows.
image_list = [f"data/{index}.png" for index in range(1, 5)]

# Model identifiers offered in the model dropdown; the first is the default.
model_path = ["deprem-ml/deprem_satellite_semantic_whu"]
def visualize_instance_seg_mask(mask):
    """Render a 2-D label mask as a randomly coloured RGB image.

    Every distinct value in ``mask`` gets one colour drawn with
    ``random.randint`` (three draws per label, in ascending label order), so
    colours differ between calls unless the ``random`` module is seeded.

    Args:
        mask: 2-D array of labels with shape (height, width).

    Returns:
        A float64 ``numpy.ndarray`` of shape (height, width, 3) whose channel
        values are scaled to the range [0, 1].
    """
    mask = np.asarray(mask)

    # ``inverse`` stores, for every pixel, the index of its label in ``labels``.
    labels, inverse = np.unique(mask, return_inverse=True)

    # One random RGB triple per label.
    palette = np.array(
        [
            (
                random.randint(0, 255),
                random.randint(0, 255),
                random.randint(0, 255),
            )
            for _ in labels
        ],
        dtype=np.float64,
    ).reshape(-1, 3)  # keeps the palette (0, 3)-shaped for an empty mask

    # Colour all pixels in one vectorised lookup instead of a Python loop per
    # pixel. The explicit reshape makes the result independent of whether
    # NumPy returns ``inverse`` flattened or already shaped like ``mask``.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255
# Loaded (model, processor) pairs keyed by model id, so repeated predictions
# do not reload the checkpoint on every request.
_SEGFORMER_CACHE = {}


def _load_segformer(model_id):
    """Return the cached ``(model, processor)`` pair for ``model_id``."""
    if model_id not in _SEGFORMER_CACHE:
        model = SegformerForSemanticSegmentation.from_pretrained(model_id)
        try:
            # Load the preprocessing settings stored with the checkpoint
            # rather than passing the id as a constructor argument.
            processor = SegformerImageProcessor.from_pretrained(model_id)
        except OSError:
            # No preprocessor config could be loaded for this id: fall back
            # to the processor's default settings.
            processor = SegformerImageProcessor()
        _SEGFORMER_CACHE[model_id] = (model, processor)
    return _SEGFORMER_CACHE[model_id]


def Segformer_Segmentation(image_path, model_id):
    """Segment an image with a SegFormer model and save a coloured mask.

    Args:
        image_path: Path of the input image file.
        model_id: Identifier passed to ``from_pretrained`` to load the model
            and its image processor.

    Returns:
        A ``(image_path, output_save)`` tuple: the unchanged input path and
        the path of the written mask image (``"output.png"``).
    """
    output_save = "output.png"

    # Close the file handle promptly and normalise RGBA / greyscale inputs to
    # three channels before preprocessing.
    with Image.open(image_path) as source:
        test_image = source.convert("RGB")

    model, processor = _load_segformer(model_id)
    inputs = processor(images=test_image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Request the mask at the input image's (height, width) so it lines up
    # with the original picture instead of the logits resolution.
    result = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[(test_image.height, test_image.width)]
    )[0]
    result = visualize_instance_seg_mask(result.cpu().numpy())

    # The visualisation is scaled to [0, 1]; write it as explicit 8-bit data.
    # NOTE(review): the fixed output path is shared by all requests.
    cv2.imwrite(output_save, np.rint(result * 255).astype(np.uint8))
    return image_path, output_save
# One example row per sample image: [image path, model identifier].
examples = [
    [image_list[position], "deprem-ml/deprem_satellite_semantic_whu"]
    for position in range(4)
]

# Heading text shown at the top of the page.
title = "Deprem ML - Segformer Semantic Segmentation"
# Gradio UI: an input column (image, model choice, predict button) next to two
# output columns showing the original image and the predicted mask.
# NOTE(review): nesting below is reconstructed from a flattened listing —
# confirm the Row/Column layout against the deployed app.
app = gr.Blocks()
with app:
    gr.HTML("<h1 style='text-align: center'>{}</h1>".format(title))
    with gr.Row():
        with gr.Column():
            # The input is an image component, so label it accordingly.
            gr.Markdown("Image")
            input_image = gr.Image(type='filepath')
            model_id = gr.Dropdown(value=model_path[0], choices=model_path)
            predict_button = gr.Button(value="Predict")
        with gr.Column():
            output_original_image = gr.Image(type='filepath')
        with gr.Column():
            output_mask_image = gr.Image(type='filepath')
    # Clickable example rows; cache_examples=True is passed so their outputs
    # are cached.
    gr.Examples(examples, inputs=[input_image, model_id], outputs=[output_original_image, output_mask_image], fn=Segformer_Segmentation, cache_examples=True)
    predict_button.click(Segformer_Segmentation, inputs=[input_image, model_id], outputs=[output_original_image, output_mask_image])
# NOTE(review): launch runs at import time as in the original script; consider
# an `if __name__ == "__main__":` guard if this module is ever imported.
app.launch()