ai_project / sidewalkModel.py
Rrrrr26's picture
Update sidewalkModel.py
f6cdbdd verified
from transformers import SamModel, SamConfig, SamProcessor
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# Base SAM architecture/processor identifier and the fine-tuned weights file
# that is loaded on top of it.
_SAM_BASE_CHECKPOINT = "facebook/sam-vit-base"
_FINE_TUNED_WEIGHTS = "checkpoint_2.pth"
_DEVICE = "cpu"

# Side length (pixels) the input image is resized to before inference.
_INPUT_SIZE = (256, 256)

# Probability above which a pixel is counted as part of the predicted mask.
_MASK_THRESHOLD = 0.65

# Colour stops for the overlay colormap: mask value 0 is fully transparent,
# mask value 1 is opaque yellow.
_OVERLAY_COLORS = [(0, 0, 0, 0), (1, 1, 1, 0), (1.0, 1.0, 10 / 255, 1)]

# Filled on first use so the config, processor and checkpoint are read from
# disk once per process instead of once per Predict() call.
_MODEL_CACHE = {}


def _load_model_and_processor():
    """Return the ``(model, processor)`` pair, loading it on first use."""
    if "pair" not in _MODEL_CACHE:
        config = SamConfig.from_pretrained(_SAM_BASE_CHECKPOINT)
        processor = SamProcessor.from_pretrained(_SAM_BASE_CHECKPOINT)
        model = SamModel(config=config)
        model.load_state_dict(
            torch.load(_FINE_TUNED_WEIGHTS, map_location=torch.device(_DEVICE))
        )
        model.to(_DEVICE)
        model.eval()
        # Stored only after everything succeeded, so a failed load is retried
        # on the next call instead of leaving a half-initialised entry behind.
        _MODEL_CACHE["pair"] = (model, processor)
    return _MODEL_CACHE["pair"]


def Predict(filename):
    """Segment ``./userImage/<filename>`` and save the overlay to ``./results``.

    The image is resized to 256x256, converted to RGB and passed to the
    fine-tuned SAM model with a single box prompt covering the whole image.
    The sigmoid of the predicted mask is thresholded at 0.65 and drawn on top
    of the resized image.

    Args:
        filename: Name of the image file inside ``./userImage``. The same name
            is used for the output file inside ``./results``.

    Returns:
        The path of the saved overlay image, ``"./results/" + filename``.

    Raises:
        FileNotFoundError: If the input image (or the checkpoint on first use)
            does not exist.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # NOTE(review): filename is concatenated into both paths unchecked. If it
    # can come from an untrusted client, values such as "../x" escape the
    # intended directories -- validate it at the call site.
    image_path = "./userImage/" + filename
    result_image_path = "./results/" + filename

    model, processor = _load_model_and_processor()

    # Resize first, then convert, exactly as before; the context manager makes
    # sure the underlying file handle is released.
    with Image.open(image_path) as user_image:
        resized_rgb = transforms.Resize(_INPUT_SIZE)(user_image).convert("RGB")

    # Box prompt spanning the entire resized image: [x_min, y_min, x_max, y_max].
    width, height = resized_rgb.size
    full_image_box = [0, 0, width, height]

    inputs = processor(
        resized_rgb, input_boxes=[[full_image_box]], return_tensors="pt"
    )
    inputs = {key: value.to(_DEVICE) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    seg_prob = seg_prob.cpu().numpy().squeeze()
    seg_mask = (seg_prob > _MASK_THRESHOLD).astype(np.uint8)

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(result_image_path), exist_ok=True)

    overlay_cmap = LinearSegmentedColormap.from_list(
        "BlackToTransparent", _OVERLAY_COLORS
    )

    # An explicit figure avoids drawing onto whatever pyplot figure happens to
    # be current, and try/finally closes it even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.imshow(np.array(resized_rgb))
        # vmin/vmax are pinned: with autoscaling, a mask that is all ones has
        # vmin == vmax and is normalised to 0, i.e. rendered fully transparent.
        ax.imshow(seg_mask, cmap=overlay_cmap, vmin=0, vmax=1, alpha=1)
        ax.axis('off')
        # bbox_inches/pad_inches strip the white border around the image.
        fig.savefig(result_image_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    print(f"Predicted image saved to: {result_image_path}")
    return result_image_path