import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from models import MaskDecoderHQ
from ppc_decoder import sam_decoder_reg
from segment_anything import sam_model_registry
from utils.transforms import ResizeLongestSide

# Resize transform shared by all predictions: longest image side -> 1024 px
trans = ResizeLongestSide(target_length=1024)

def save_prob_visualization(prob, filename="prob_visualization.png"):
    """
    Visualize a 1 x W x H probability map with plt.imshow and save it to disk.

    :param prob: tensor of shape 1 x W x H
    :param filename: output filename, defaults to 'prob_visualization.png'
    """
    # Detach, move to CPU, and drop the leading batch dimension: 1 x W x H -> W x H
    prob_np = prob.detach().cpu().squeeze(0).numpy()
    # Visualize with plt.imshow
    plt.imshow(prob_np)
    # plt.imshow(prob_np, cmap='gray', vmin=0, vmax=1)  # cmap='gray' would force a grayscale display
    plt.axis('off')  # hide the axes
    # Save the figure without surrounding whitespace
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close()
    print(f"Probability map saved as {filename}")

def pad_to_square(x: torch.Tensor, target_size: int) -> torch.Tensor:
    """Pad the input tensor to a square shape with the specified target size."""
    # Get the current height and width of the image
    h, w = x.shape[-2:]
    # Calculate padding for height and width
    padh = target_size - h
    padw = target_size - w
    # Pad the tensor (bottom and right) to the target size
    x = F.pad(x, (0, padw, 0, padh))
    return x
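
# Example (hypothetical shapes): a batched CHW image is padded on the
# bottom/right up to SAM's fixed 1024 x 1024 input resolution:
#   x = torch.zeros(1, 3, 600, 800)
#   pad_to_square(x, target_size=1024).shape  # -> torch.Size([1, 3, 1024, 1024])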

def remove_none_values(input_dict):
    """
    Remove all items with None as their value from the dictionary.

    Args:
        input_dict (dict): The dictionary from which to remove None values.

    Returns:
        dict: A new dictionary with None values removed.
    """
    return {key: value for key, value in input_dict.items() if value is not None}

class PPC_SAM:
    def __init__(self, model_type="vit_h",
                 ckpt_vit="pretrained_checkpoint/sam_vit_h_4b8939.pth",
                 ckpt_ppc="pretrained_checkpoint/ppc_decoder.pth",
                 ckpt_hq="pretrained_checkpoint/sam_hq_vit_h_decoder.pth",
                 device="cpu") -> None:
        self.device = device
        # Initialize the decoders
        self.sam_hq_decoder = MaskDecoderHQ(model_type)
        self.ppc_decoder = sam_decoder_reg['default']()
        # Load their pretrained state dictionaries
        model_state_hq = torch.load(ckpt_hq, map_location=device)
        self.sam_hq_decoder.load_state_dict(model_state_hq)
        print(f"Loaded HQ decoder checkpoint from {ckpt_hq}")
        model_state_ppc = torch.load(ckpt_ppc, map_location=device)
        self.ppc_decoder.load_state_dict(model_state_ppc)
        print(f"Loaded PPC decoder checkpoint from {ckpt_ppc}")
        # Initialize the SAM backbone from the ViT checkpoint
        self.sam = sam_model_registry[model_type](checkpoint=ckpt_vit).to(device)

    def predict(self, prompts, multimask_output=False):
        with torch.no_grad():
            self.sam = self.sam.to(self.device)
            self.sam_hq_decoder = self.sam_hq_decoder.to(self.device)
            self.ppc_decoder = self.ppc_decoder.to(self.device)
            # Drop prompt entries set to None and record the original image size
            batch_input = remove_none_values(prompts[0])
            original_size = batch_input["image"].shape[:2]
            batch_input["original_size"] = original_size
            # Resize so the longest side is 1024 and move the image to CHW layout
            input_image = trans.apply_image(batch_input["image"])
            input_image_torch = torch.as_tensor(input_image, device=self.device)
            input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
            batch_input["image"] = input_image_torch
            # Rescale box and point prompts to the resized image coordinates
            if "boxes" in batch_input:
                batch_input["boxes"] = trans.apply_boxes_torch(batch_input["boxes"], original_size=original_size)
            if "point_coords" in batch_input:
                batch_input["point_coords"] = trans.apply_coords_torch(batch_input["point_coords"], original_size=original_size)
            # Run SAM to get masks plus the intermediate embeddings the decoders need
            batched_output, interm_embeddings = self.sam([batch_input], multimask_output=multimask_output)
            batch_len = len(batched_output)
            encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
            image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
            sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
            dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]
            # The HQ decoder refines the SAM masks from the same prompt embeddings
            masks_sam_in_hq, masks_hq = self.sam_hq_decoder(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )
            masks_sam = batched_output[0]["masks"]
            # The PPC decoder takes the padded image and the HQ masks as input
            input_images_ppc = pad_to_square(input_image_torch[None, :, :, :], target_size=1024).float()
            mask_ppc = self.ppc_decoder(x_img=input_images_ppc, hidden_states_out=interm_embeddings, low_res_mask=masks_hq)
            # Upsample both low-resolution masks back to the original image size
            rescaled_masks_hq = self.sam.postprocess_masks(masks_hq, input_size=input_image_torch.shape[-2:], original_size=original_size)
            rescaled_masks_ppc = self.sam.postprocess_masks(mask_ppc, input_size=input_image_torch.shape[-2:], original_size=original_size)
            # Stack the PPC, HQ, and vanilla SAM masks along a new leading dimension
            stacked_masks = torch.stack([rescaled_masks_ppc, rescaled_masks_hq, masks_sam.to(torch.uint8)], dim=0).cpu().squeeze(1).squeeze(1)
            return stacked_masks, None, None
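
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions: the
    # default checkpoint paths passed to PPC_SAM exist locally, and the random
    # array below stands in for a real H x W x 3 uint8 RGB image.
    import numpy as np

    model = PPC_SAM(device="cpu")

    image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
    # SAM-style batched prompt: None-valued entries are stripped by
    # remove_none_values, so only the box prompt is forwarded here.
    prompts = [{
        "image": image,
        "boxes": torch.tensor([[100.0, 100.0, 400.0, 400.0]]),
        "point_coords": None,
        "point_labels": None,
    }]

    stacked_masks, _, _ = model.predict(prompts)
    # Stacked (PPC, HQ, SAM) masks at the original image resolution
    print(stacked_masks.shape)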