Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import torch | |
import numpy as np | |
import cv2 | |
import matplotlib.pyplot as plt | |
import base64 | |
import json | |
from typing import Tuple, Dict | |
from timeit import default_timer as timer | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
from segment_anything.utils.onnx import SamOnnxModel | |
import torch.nn.functional as F | |
from model import create_sam_model | |
# 1. Setup variables
# SAM backbone variant and the weights file used to build the model.
model_type = "vit_b"
checkpoint = "sam_vit_b_01ec64.pth"
# Prefer the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# 2. Model preparation: build the SAM model and load the saved weights
medsam_model = create_sam_model(model_type, checkpoint, device)

# Automatic mask-generation settings passed to SamAutomaticMaskGenerator.
_mask_generator_kwargs = dict(
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    # Requires OpenCV to run the small-region post-processing.
    min_mask_region_area=100,
)
mask_generator = SamAutomaticMaskGenerator(model=medsam_model, **_mask_generator_kwargs)
# 3. Predict fn
def show_anns(anns, ax=None, alpha=0.35):
    """Overlay segmentation masks on a matplotlib axes, one random colour per mask.

    Args:
        anns: list of mask records (as returned by ``mask_generator.generate``);
            each record needs a 2-D ``'segmentation'`` mask and an ``'area'``.
        ax: axes to draw on. Defaults to ``plt.gca()`` (the current axes),
            which is what the function always used before.
        alpha: opacity given to every mask colour.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if not anns:
        return
    # Largest masks first, so smaller masks are painted last and stay visible.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    if ax is None:
        ax = plt.gca()
    # Disable autoscaling so drawing the overlay does not change the axes limits.
    ax.set_autoscale_on(False)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA canvas that stays fully transparent wherever no mask paints it.
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in sorted_anns:
        # Cast to bool so integer 0/1 masks select pixels instead of indexing rows.
        mask = np.asarray(ann['segmentation'], dtype=bool)
        overlay[mask] = np.concatenate([np.random.random(3), [alpha]])
    ax.imshow(overlay)
def _to_rgb_array(img):
    """Return *img* (PIL image or NumPy array) as an ``(H, W, 3)`` RGB array."""
    if hasattr(img, "convert"):
        # PIL image: normalise greyscale / palette / RGBA modes to 3-channel RGB.
        img = img.convert("RGB")
    image = np.asarray(img)
    if image.ndim == 2:
        # Greyscale array: replicate the single channel three times.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # Drop the alpha channel.
        image = image[..., :3]
    return np.ascontiguousarray(image)


def predict(img) -> Tuple["plt.Figure", float]:
    """Segment *img* with the automatic mask generator and plot the masks.

    Args:
        img: the input image. The Gradio interface passes a PIL image
            (``gr.Image(type="pil")``), which is already RGB; NumPy arrays
            are accepted too.

    Returns:
        ``(figure, seconds)``: a matplotlib figure with every generated mask
        overlaid on the image, and the time taken, rounded to 5 decimals.

    Raises:
        ValueError: if no image was provided.
    """
    if img is None:
        raise ValueError("No image provided.")
    start_time = timer()
    # Keep the image in RGB. The previous BGR->RGB conversion actually turned
    # Gradio's RGB input into BGR before it reached the model and the plot.
    image = _to_rgb_array(img)
    masks = mask_generator.generate(image)
    pred_time = round(timer() - start_time, 5)

    fig = plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    # Remove the figure from pyplot's registry; otherwise every request leaks a
    # figure. The Figure object itself remains usable by gr.Plot.
    plt.close(fig)
    return fig, pred_time
# 4. Gradio app
# Title, description and article strings shown in the interface.
title = "MedSam"
description = "a specialized SAM model finely tuned for the segmentation of medical images. With this app, effortlessly extract image embeddings using the model's advanced mask decoder."
article = "Created at gradio-sam-predictor-image-embedding-generator.ipynb ."


def list_examples(directory="examples"):
    """Return Gradio example rows, one ``[path]`` per regular file in *directory*.

    Hidden entries (names starting with ``.``) and sub-directories are skipped.
    Rows are sorted by file name so the example gallery order is deterministic.
    A missing directory yields an empty list rather than crashing at startup.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    ]


# Create examples list from "examples/" directory
example_list = list_examples("examples")
# Create the Gradio demo: one image in; the mask-overlay plot and timing out.
_image_input = gr.Image(type="pil")
# predict returns two values, so the interface declares two output components.
_prediction_outputs = [
    gr.Plot(label="Predictions"),
    gr.Number(label="Prediction time (s)"),
]
demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_outputs,
    examples=example_list,
    title=title,
    description=description,
    article=article,
)
# Launch the demo only when this file is executed as a script (e.g. `python app.py`).
# Importing the module, for instance from `gradio app.py` reload mode or a test,
# then no longer starts a server as a side effect.
if __name__ == "__main__":
    # share=True requests a publicly shareable URL.
    # NOTE(review): that exposes the app publicly when run locally; confirm intended.
    demo.launch(debug=False, share=True)