import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

# Load the segmentation model (replace with the actual checkpoint path).
# This assumes the checkpoint contains a full pickled nn.Module; a checkpoint that
# stores only a state_dict needs the model class built first (see the note at the end).
model_path = 'medsam_lite/lite_medsam.pth'
segmentation_model = torch.load(model_path, map_location=torch.device('cpu'))
segmentation_model.eval()

# Preprocess the input image: resize to the model's expected input size and
# convert it to a batched tensor.
def preprocess(image):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img = Image.fromarray(image)       # Gradio delivers a numpy array (type="numpy")
    img = transform(img).unsqueeze(0)  # add a batch dimension: (1, C, 256, 256)
    return img

# Run the model on one image and turn its output into a displayable mask.
def segment_image(input_image):
    input_tensor = preprocess(input_image)
    with torch.no_grad():
        output = segmentation_model(input_tensor)
    # Per-pixel argmax over the class channel gives integer class indices.
    mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()
    # Scale the indices into the 0-255 range so the mask is visible when rendered as an image.
    segmented_image = (mask * (255 // max(int(mask.max()), 1))).astype('uint8')
    return segmented_image

# Define the Gradio interface. The input type is "numpy" because preprocess()
# expects an array; gr.Image does not accept a custom preprocess callable, and
# segment_image() already performs its own preprocessing.
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
)

# Launch the Gradio app
iface.launch()
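
# Note: depending on how it was saved, 'lite_medsam.pth' may hold only a state_dict
# rather than a pickled model object, in which case torch.load() alone does not return
# something callable. A minimal sketch for that case, assuming a hypothetical
# build_medsam_lite() factory that constructs the network architecture:
#
#   segmentation_model = build_medsam_lite()                # hypothetical constructor
#   checkpoint = torch.load(model_path, map_location='cpu')
#   state_dict = checkpoint.get('model', checkpoint)        # some checkpoints nest weights under a key
#   segmentation_model.load_state_dict(state_dict)
#   segmentation_model.eval()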