Spaces:
Sleeping
Sleeping
| from transformers import SamModel, SamConfig, SamProcessor | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import LinearSegmentedColormap | |
def Predict(filename):
    """Segment an uploaded image with the fine-tuned SAM model and save an overlay.

    The image ``./userImage/<filename>`` is resized to 256x256, converted to
    RGB and passed to a ``SamModel`` whose weights are loaded from
    ``checkpoint_2.pth``. The whole image is used as the box prompt. The
    predicted mask (sigmoid probability > 0.65) is drawn over the resized
    image and the figure is written to ``./results/<filename>``.

    Args:
        filename: Name of the image file inside ``./userImage``. The same
            name is used for the output file inside ``./results``.

    Returns:
        The path of the saved overlay image, ``"./results/" + filename``.
    """
    # Local import so this block does not depend on a file-level `import os`.
    import os

    device = "cpu"

    # NOTE(review): the config, processor and checkpoint are re-loaded on
    # every call; consider caching them at module level if this is called
    # repeatedly.
    modelConfiguration = SamConfig.from_pretrained("facebook/sam-vit-base")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    trainedModel = SamModel(config=modelConfiguration)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted
    # files (or pass weights_only=True where the installed torch supports it).
    trainedModel.load_state_dict(
        torch.load("checkpoint_2.pth", map_location=torch.device(device))
    )
    trainedModel.to(device)
    trainedModel.eval()

    # Colormap for the mask overlay: position 0.0 and 0.5 are fully
    # transparent, position 1.0 is opaque yellow.
    mapping = [(0, 0, 0, 0), (1, 1, 1, 0), (255 / 255, 255 / 255, 10 / 255, 1)]
    colorMap = LinearSegmentedColormap.from_list("BlackToTransparent", mapping)

    # NOTE(review): `filename` is concatenated into paths unchanged; if it can
    # come from an untrusted client, validate it against path traversal.
    imgPath = "./userImage/" + filename
    result_image_path = "./results/" + filename

    # Close the source file deterministically; the resized/converted copy
    # is an independent in-memory image.
    with Image.open(imgPath) as userImage:
        resized_image = transforms.Resize((256, 256))(userImage)
        resized_image_rgb = resized_image.convert("RGB")
    resized_image_np = np.array(resized_image_rgb)

    # Prompt the model with a box covering the full (resized) image.
    height, width = resized_image_np.shape[:2]
    full_image_box = [0, 0, width, height]

    inputs = processor(
        resized_image_rgb, input_boxes=[[full_image_box]], return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = trainedModel(**inputs, multimask_output=False)

    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    medsam_seg = (medsam_seg_prob > 0.65).astype(np.uint8)

    os.makedirs("./results", exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.imshow(resized_image_np)
        # Pin the colour range to the mask's 0/1 values. With automatic
        # scaling a uniform mask (e.g. all ones) has vmin == vmax and every
        # pixel is mapped to the transparent end of the colormap.
        ax.imshow(medsam_seg, cmap=colorMap, alpha=1, vmin=0, vmax=1)
        ax.axis("off")
        # Save without the surrounding white margin.
        fig.savefig(result_image_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    print(f"Predicted image saved to: {result_image_path}")
    return result_image_path