# Spaces: Sleeping (page-status header left over from web extraction; not program code)
import functools
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from transformers import SamModel, SamConfig, SamProcessor
# Hub identifier used for both the SAM architecture config and the processor.
_SAM_BASE_CHECKPOINT = "facebook/sam-vit-base"
# Fine-tuned weights, resolved relative to the current working directory.
_FINE_TUNED_WEIGHTS = "checkpoint_2.pth"
# Probability above which a pixel is kept in the binary mask.
_MASK_THRESHOLD = 0.65
# Size the input image is resized to before inference.
_INPUT_SIZE = (256, 256)


@functools.lru_cache(maxsize=1)
def _load_sam():
    """Return ``(processor, model)``, building them on first use only.

    The config, processor and checkpoint never change between calls, so they
    are loaded once and reused instead of being re-read for every prediction.
    If loading raises, nothing is cached and the next call retries.
    """
    config = SamConfig.from_pretrained(_SAM_BASE_CHECKPOINT)
    processor = SamProcessor.from_pretrained(_SAM_BASE_CHECKPOINT)
    model = SamModel(config=config)
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be placed at this path (consider weights_only=True if the
    # installed torch version supports it).
    state_dict = torch.load(_FINE_TUNED_WEIGHTS, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.to("cpu")
    model.eval()
    return processor, model


def Predict(filename):
    """Segment ``./userImage/<filename>`` and save the overlay to ``./results/``.

    The image is resized to 256x256, converted to RGB and passed to the
    fine-tuned SAM model with a single box prompt covering the whole image.
    The sigmoid of the predicted mask is thresholded at 0.65 and drawn on top
    of the resized image with a colormap whose low end is fully transparent.

    Args:
        filename: Name of the image file inside ``./userImage/``. The same
            name is used for the output file inside ``./results/``.

    Returns:
        The path of the saved overlay, i.e. ``"./results/" + filename``.

    Raises:
        FileNotFoundError: If the input image or the checkpoint is missing.
    """
    # NOTE(review): filename is concatenated into both paths unchecked; callers
    # that pass user-supplied names should sanitise them first.
    processor, model = _load_sam()

    # The context manager releases the file handle; resize/convert produce
    # independent in-memory images, so nothing depends on the open file later.
    with Image.open("./userImage/" + filename) as user_image:
        resized_image = transforms.Resize(_INPUT_SIZE)(user_image)
        image_rgb = resized_image.convert("RGB")
    image_np = np.array(image_rgb)

    # One box prompt spanning the full (resized) image: [x0, y0, x1, y1].
    height, width = image_np.shape[:2]
    full_image_box = [0, 0, width, height]

    inputs = processor(image_rgb, input_boxes=[[full_image_box]], return_tensors="pt")
    inputs = {k: v.cpu() for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
        seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg = (seg_prob > _MASK_THRESHOLD).astype(np.uint8)

    result_image_path = "./results/" + filename
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(result_image_path), exist_ok=True)

    # RGBA stops: transparent black -> transparent white -> opaque yellow.
    mapping = [(0, 0, 0, 0), (1, 1, 1, 0), (255/255, 255/255, 10/255, 1)]
    color_map = LinearSegmentedColormap.from_list("BlackToTransparent", mapping)

    # Draw on a dedicated figure so a caller's current pyplot figure is never
    # touched, and always close it so figures do not pile up on errors.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image_np)
        ax.imshow(seg, cmap=color_map, alpha=1)  # overlay predicted mask
        ax.axis('off')
        # bbox_inches/pad_inches trim the surrounding white space.
        fig.savefig(result_image_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    print(f"Predicted image saved to: {result_image_path}")
    return result_image_path