# Image_Dissector / app.py
# Source: Hugging Face Space by wendys-llc, commit de37a0c ("Update app.py"), 2.49 kB.
from PIL import Image, ImageFilter
import numpy as np
from transformers import pipeline
import gradio as gr
import os
# Segmentation checkpoints (Hugging Face model ids) offered in the model
# dropdown, in display order. Callbacks receive the dropdown's selected
# position and index into this list.
models = [
    "facebook/detr-resnet-50-panoptic",
    "CIDAS/clipseg-rd64-refined",
    "facebook/maskformer-swin-large-ade",
    "nvidia/segformer-b1-finetuned-cityscapes-1024-1024",
]

# Initial selection shown in the model dropdown.
current_model = models[0]

# Most recent raw predictions, overwritten by image_objects().
pred: list = []
def img_resize(image, width=1280):
    """Scale *image* to a fixed width while preserving its aspect ratio.

    Args:
        image: a PIL image (anything exposing ``.size`` as ``(w, h)`` and
            ``.resize((w, h))``).
        width: target width in pixels; defaults to 1280. Images narrower
            than this are upscaled.

    Returns:
        The resized image returned by ``image.resize``.
    """
    src_width, src_height = image.size
    # Multiply before dividing to avoid float drift, round rather than
    # truncate, and clamp to 1 so an extremely wide image never ends up with
    # a zero height (not a valid target size).
    height = max(1, round(src_height * width / src_width))
    return image.resize((width, height))
def image_objects(image, model_choice=0):
    """Segment *image* and list the detected objects.

    The raw pipeline output is stored in the module-level ``pred``.

    Args:
        image: PIL image to segment, or ``None`` (e.g. when the input was
            cleared), in which case an empty list is produced.
        model_choice: index into ``models`` selecting the checkpoint;
            defaults to the first entry, the app's initial selection.

    Returns:
        A ``gr.Dropdown.update`` whose choices are ``"<index>_<label>"``
        strings, one per predicted segment.
    """
    global pred
    if image is None:
        pred = []
        return gr.Dropdown.update(choices=[], interactive=True)
    image = img_resize(image)
    # There is no module-level model instance, so build the pipeline for the
    # selected checkpoint here rather than referencing an undefined global.
    segmenter = pipeline("image-segmentation", model=models[model_choice])
    pred = segmenter(image)
    pred_object_list = [f"{i}_{x['label']}" for i, x in enumerate(pred)]
    return gr.Dropdown.update(choices=pred_object_list, interactive=True)
# Most recently used segmentation pipeline, keyed by model id. At most one
# entry is kept so switching models does not keep every checkpoint resident.
_segmenter_cache = {}


def _get_segmenter(model_name):
    """Return an image-segmentation pipeline for *model_name*, loading it once."""
    segmenter = _segmenter_cache.get(model_name)
    if segmenter is None:
        _segmenter_cache.clear()
        segmenter = pipeline("image-segmentation", model=model_name)
        _segmenter_cache[model_name] = segmenter
    return segmenter


def _apply_mask(image_array, mask):
    """Return a uint8 copy of *image_array* with pixels outside *mask* zeroed.

    The mask is scaled from 0..255 to a 0..1 weight (assumes the pipeline's
    masks are 0/255 images) and broadcast over every channel, so grayscale,
    RGB and RGBA inputs are all handled.
    """
    weights = np.asarray(mask, dtype=np.float32) / 255.0
    if image_array.ndim == 3:
        weights = weights[..., np.newaxis]
    return (image_array * weights).astype(np.uint8)


def get_seg(image, model_choice):
    """Segment *image* with the selected model and cut out every segment.

    Args:
        image: PIL image from the input component, or ``None`` (e.g. when
            the input is cleared), in which case empty outputs are returned.
        model_choice: index into ``models`` (the dropdown uses type="index").

    Returns:
        A tuple ``(seg_box, dropdown_update)``: one uint8 array per predicted
        segment, with everything outside that segment's mask set to 0, and a
        ``gr.Dropdown.update`` listing ``"<index>_<label>"`` choices.
    """
    # Guard so img_resize is never called on None.
    if image is None:
        return [], gr.Dropdown.update(choices=[], interactive=True)
    image = img_resize(image)
    # Reuse the loaded pipeline; loading a checkpoint on every call is costly.
    segmenter = _get_segmenter(models[model_choice])
    predictions = segmenter(image)
    image_array = np.asarray(image)  # loop-invariant: convert once
    seg_box = [_apply_mask(image_array, p["mask"]) for p in predictions]
    pred_object_list = [f"{i}_{p['label']}" for i, p in enumerate(predictions)]
    return seg_box, gr.Dropdown.update(choices=pred_object_list, interactive=True)
# UI layout: input image and model picker on the left, a gallery of the
# cut-out segments on the right, and a dropdown of detected objects below.
app = gr.Blocks()
with app:
    gr.Markdown(
        """
## Image Dissector
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input Image", type="pil")
            # type="index" makes callbacks receive the position in `models`.
            model_name = gr.Dropdown(
                show_label=False,
                choices=list(models),
                type="index",
                value=current_model,
                interactive=True,
            )
        with gr.Column():
            gal1 = gr.Gallery(type="filepath").style(grid=6)
    with gr.Row():
        with gr.Column():
            object_output = gr.Dropdown(label="Objects")
    # Re-run segmentation whenever the input image changes.
    image_input.change(
        get_seg,
        inputs=[image_input, model_name],
        outputs=[gal1, object_output],
    )

# Launch only when run as a script, so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    app.launch()