import torch
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from transformers import SamModel, SamProcessor
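
# pick a device and load the SAM processor plus the sidewalk-segmentation SAM checkpoint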
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
model = SamModel.from_pretrained('hmdliu/sidewalks-seg-base')
model.to(device)
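
# segment sidewalks in the input image: the full image extent is used as a box prompt,
# and the predicted mask probabilities are thresholded at the slider value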
def segment_sidewalk(image, threshold):
    # init data
    width, height = image.size
    prompt = [0, 0, width, height]
    inputs = processor(image, input_boxes=[[prompt]], return_tensors='pt')
    # make prediction
    outputs = model(pixel_values=inputs['pixel_values'].to(device),
                    input_boxes=inputs['input_boxes'].to(device),
                    multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
    prediction = (prob_map > threshold).float()
    prob_map, prediction = prob_map.numpy(), prediction.numpy()
    # visualize results
    save_image(image, 'image.png')
    save_image(prob_map, 'prob.png', cmap='jet')
    save_image(prediction, 'mask.png', cmap='gray')
    return Image.open('image.png'), Image.open('mask.png'), Image.open('prob.png')
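
# render a PIL image or NumPy array to disk via matplotlib, without axes or padding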
def save_image(image, path, **kwargs):
    plt.figure(figsize=(8, 8))
    plt.imshow(image, interpolation='nearest', **kwargs)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close()
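
# Gradio UI: image upload and threshold slider on the left, segmentation mask and probability map on the right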
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type='pil', label='TIFF Image')
            threshold_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
            segment_button = gr.Button('Segment')
        with gr.Column():
            prediction = gr.Image(type='pil', label='Segmentation Result')
            prob_map = gr.Image(type='pil', label='Probability Map')
    segment_button.click(
        segment_sidewalk,
        inputs=[image_input, threshold_slider],
        outputs=[image_input, prediction, prob_map]
    )
demo.launch(debug=True, show_error=True)