import base64
import json
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from skimage import io, transform

from model import create_sam_model

# 1. Setup variables
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Prompt box (x0, y0, x1, y1) in original-image pixel coordinates. The same
# fixed box is used for every image; the UI does not let the user choose it.
DEFAULT_BOX = np.array([[125, 275, 190, 350]])

# 2. Model preparation and loading of the saved weights
medsam_model = create_sam_model(model_type, checkpoint, device)
# The app only runs inference, so switch to eval mode once at load time
# rather than on every request.
medsam_model.eval()


# 3. Predict fn
def show_mask(mask, ax):
    """Overlay ``mask`` on ``ax`` as a semi-transparent blue layer.

    Args:
        mask: array whose last two dimensions are (H, W). Pixels equal to 1
            get the tint; pixels equal to 0 become fully transparent.
        ax: Matplotlib Axes to draw on.
    """
    color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Draw prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: (N,) array of point labels (1 = positive, 0 = negative).
        ax: Matplotlib Axes to draw on.
        marker_size: scatter marker size.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", marker="*",
               s=marker_size, edgecolor="white", linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", marker="*",
               s=marker_size, edgecolor="white", linewidth=1.25)


def show_box(box, ax):
    """Draw an unfilled green rectangle for ``box`` = (x0, y0, x1, y1) on ``ax``."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green",
                               facecolor=(0, 0, 0, 0), lw=2))


@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Predict a binary mask for the prompt box(es) from a precomputed image embedding.

    Args:
        medsam_model: SAM-style model exposing ``prompt_encoder`` and
            ``mask_decoder``.
        img_embed: image-encoder output tensor; the box tensor is placed on
            its device.
        box_1024: boxes (x0, y0, x1, y1) already scaled to the 1024x1024
            encoder frame, as a (B, 4) or (B, 1, 4) array-like.
        H, W: size of the original image; the mask is resized to it.

    Returns:
        uint8 array of 0/1 values. It is (H, W) when B == 1, because
        ``squeeze()`` drops the singleton batch and channel dims.
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)

    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W) for B == 1
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg


def _to_rgb_array(img) -> np.ndarray:
    """Convert ``img`` into an (H, W, 3) NumPy array without reordering channels.

    PIL images (what ``gr.Image(type="pil")`` delivers) are RGB, not
    OpenCV-style BGR, so no BGR->RGB swap is applied. Grayscale inputs are
    replicated to three channels, and an alpha channel, if present, is
    dropped.

    Raises:
        gr.Error: if the array is not (H, W) or (H, W, C) with 1 <= C <= 4.
    """
    arr = np.asarray(img)
    if arr.ndim == 2:
        return np.repeat(arr[:, :, None], 3, axis=-1)
    if arr.ndim == 3:
        channels = arr.shape[-1]
        if channels == 3:
            return arr
        if channels == 4:
            return arr[:, :, :3]  # RGBA -> RGB (drop alpha)
        if channels in (1, 2):
            # Luminance (+ optional alpha): replicate the intensity channel.
            return np.repeat(arr[:, :, :1], 3, axis=-1)
    raise gr.Error(f"Unsupported image shape {arr.shape}.")


def predict(img) -> Tuple[Figure, float, str]:
    """Segment ``img`` inside the fixed prompt box and export its image embedding.

    Args:
        img: the uploaded image (a PIL image from ``gr.Image(type="pil")``);
            any array-like accepted by ``np.asarray`` also works.

    Returns:
        A tuple ``(fig, pred_time, serialized_data)``:
        - fig: Matplotlib figure showing the input with the box, and the
          predicted mask overlay.
        - pred_time: seconds spent on preprocessing and inference, rounded to
          5 decimals. Plotting and serialization are not included.
        - serialized_data: JSON string holding a one-element list with the
          base64-encoded raw bytes of the image-encoder output. Only the bytes
          are stored, with no shape or dtype. Presumably float32 of shape
          (1, 256, 64, 64) — TODO confirm for downstream consumers.

    Raises:
        gr.Error: if no image was supplied or its shape is unsupported.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    start_time = timer()

    img_3c = _to_rgb_array(img)
    H, W = img_3c.shape[:2]

    # Image preprocessing: bicubic resize to the 1024x1024 encoder input,
    # then min-max normalise to [0, 1].
    img_1024 = transform.resize(
        img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    img_1024 = (img_1024 - img_1024.min()) / np.clip(
        img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
    )  # (1024, 1024, 3)

    # HWC -> (1, 3, 1024, 1024) float tensor on the model's device.
    img_1024_tensor = (
        torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
    )

    with torch.inference_mode():
        image_embedding = medsam_model.image_encoder(img_1024_tensor)  # (1, 256, 64, 64)

    input_box = DEFAULT_BOX
    # Rescale the box from original-image pixels to the 1024x1024 frame.
    box_1024 = input_box / np.array([W, H, W, H]) * 1024
    medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)

    pred_time = round(timer() - start_time, 5)

    # Build the figure without pyplot. pyplot keeps a global reference to
    # every figure it opens, which leaks memory in a long-running server and
    # is not thread-safe.
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots(1, 2)
    ax[0].imshow(img_3c)
    show_box(input_box[0], ax[0])
    ax[0].set_title("Input Image and Bounding Box")
    ax[1].imshow(img_3c)
    show_mask(medsam_seg, ax[1])
    show_box(input_box[0], ax[1])
    ax[1].set_title("MedSAM Segmentation")

    # Serialize the raw embedding bytes as base64 inside a JSON list.
    embedding_bytes = image_embedding.cpu().numpy().tobytes()
    serialized_data = json.dumps([base64.b64encode(embedding_bytes).decode("ascii")])

    return fig, pred_time, serialized_data


# 4. Gradio app
# Title, description and article strings
title = "MedSam"
description = ("A specialized SAM model fine-tuned for the segmentation of medical images. "
               "With this app, extract image embeddings from the model's image encoder.")
article = "Created in gradio-sam-predictor-image-embedding-generator.ipynb."

# Examples from the "examples/" directory: sorted for a stable order, with
# hidden files (e.g. .DS_Store) skipped.
example_list = [
    [os.path.join("examples", name)]
    for name in sorted(os.listdir("examples"))
    if not name.startswith(".")
]

# Create the Gradio demo
demo = gr.Interface(fn=predict,  # mapping function from input to output
                    inputs=gr.Image(type="pil"),
                    # predict returns three values, so there are three output components
                    outputs=[gr.Plot(label="Predictions"),
                             gr.Number(label="Prediction time (s)"),
                             gr.JSON(label="Embedding Image")],
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
demo.launch(debug=False,  # print errors locally?
            share=True)  # generate a publicly shareable URL?