import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from ..utils import models_dir, np2tensor

# TODO: check if I can make a torch script device independent.
# For now I forced it to use CUDA.


class MTB_LoadVitMatteModel:
    """Load a ViTMatte model exported as a TorchScript file."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "kind": (("Composition-1K", "Distinctions-646"),),
                "autodownload": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("VITMATTE_MODEL",)
    RETURN_NAMES = ("torch_script",)
    CATEGORY = "mtb/vitmatte"
    FUNCTION = "execute"

    def execute(self, *, kind: str, autodownload: bool):
        dest = models_dir / "vitmatte"
        dest.mkdir(exist_ok=True)

        # "com" -> Composition-1K weights, "dist" -> Distinctions-646 weights.
        name = "dist" if kind == "Distinctions-646" else "com"
        file = hf_hub_download(
            repo_id="melmass/pytorch-scripts",
            filename=f"vitmatte_b_{name}.pt",
            local_dir=dest.as_posix(),
            local_files_only=not autodownload,
        )
        model = torch.jit.load(file).to("cuda")
        return (model,)
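
# A hedged sketch for the TODO above (an assumption, not tested against this
# checkpoint): torch.jit.load accepts a map_location, so the script could
# target whatever device is available instead of hard-coding CUDA:
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = torch.jit.load(file, map_location=device).to(device)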


class MTB_GenerateTrimap:
    """Build a trimap (background / unknown / foreground) from a binary mask."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # "image": ("IMAGE",),
                "mask": ("MASK",),
                "erode": ("INT", {"default": 10}),
                "dilate": ("INT", {"default": 10}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("trimap",)
    CATEGORY = "mtb/vitmatte"
    FUNCTION = "execute"

    def execute(
        self,
        # image: torch.Tensor,
        mask: torch.Tensor,
        erode: int = 10,
        dilate: int = 10,
    ):
        # TODO: not sure what's the most practical between IMAGE or MASK
        # image = image.to("cuda").half()
        mask = mask.to("cuda").half()

        trimaps = []
        for m in mask:
            mask_arr = m.squeeze(0).to(torch.uint8).cpu().numpy() * 255
            erode_kernel = np.ones((erode, erode), np.uint8)
            dilate_kernel = np.ones((dilate, dilate), np.uint8)
            eroded = cv2.erode(mask_arr, erode_kernel, iterations=5)
            dilated = cv2.dilate(mask_arr, dilate_kernel, iterations=5)

            # 0 = background, 128 = unknown band, 255 = foreground;
            # the eroded core is written last so it overrides the band.
            trimap = np.zeros_like(mask_arr)
            trimap[dilated == 255] = 128
            trimap[eroded == 255] = 255
            trimaps.append(trimap)
        return (np2tensor(trimaps),)
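
# Hedged toy illustration (not from the repo) of the three trimap levels
# produced above: the eroded core stays definite foreground, the dilation
# ring becomes the unknown band, and the rest is background:
#
#   m = np.zeros((7, 7), np.uint8)
#   m[2:5, 2:5] = 255
#   k = np.ones((3, 3), np.uint8)
#   tri = np.zeros_like(m)
#   tri[cv2.dilate(m, k) == 255] = 128   # 5x5 unknown band
#   tri[cv2.erode(m, k) == 255] = 255    # only the centre pixel survives
#   np.unique(tri)                       # -> array([0, 128, 255], dtype=uint8)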


class MTB_ApplyVitMatte:
    """Run the ViTMatte TorchScript model on an image / trimap pair."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("VITMATTE_MODEL",),
                "image": ("IMAGE",),
                "trimap": ("IMAGE",),
                "returns": (("RGB", "RGBA"),),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("image (rgba)", "mask")
    CATEGORY = "mtb/vitmatte"
    FUNCTION = "execute"

    def execute(
        self, model, image: torch.Tensor, trimap: torch.Tensor, returns: str
    ):
        im_count = image.shape[0]
        tm_count = trimap.shape[0]
        if im_count != tm_count:
            raise ValueError("image and trimap must have the same batch size")

        outputs_m: list[torch.Tensor] = []
        outputs_i: list[torch.Tensor] = []
        for i, im in enumerate(image):
            # Channel-first, half precision, on CUDA for the scripted model.
            tm = trimap[i].half().unsqueeze(2).permute(2, 0, 1).to("cuda")
            im = im.half().permute(2, 0, 1).to("cuda")
            inputs = {"image": im.unsqueeze(0), "trimap": tm.unsqueeze(0)}
            fine_mask = model(inputs)

            # Composite the foreground over white with the predicted alpha.
            foreground = im * fine_mask + (1 - fine_mask)
            if returns == "RGBA":
                rgba_image = torch.cat(
                    (foreground, fine_mask.unsqueeze(0)), dim=0
                )
                outputs_i.append(rgba_image.unsqueeze(0))
            else:
                outputs_i.append(foreground.unsqueeze(0))
            outputs_m.append(fine_mask.unsqueeze(0))

        result_m = torch.cat(outputs_m, dim=0)
        result_i = torch.cat(outputs_i, dim=0)
        # Back to ComfyUI's channel-last (B, H, W, C) layout.
        return (result_i.permute(0, 2, 3, 1), result_m)


__nodes__ = [MTB_LoadVitMatteModel, MTB_GenerateTrimap, MTB_ApplyVitMatte]
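
# Hedged usage sketch (illustrative only; tensor shapes are assumed to follow
# ComfyUI conventions: (B, H, W) masks, (B, H, W, C) images in [0, 1]).
# Inside a graph the nodes chain Load VitMatte Model -> Generate Trimap ->
# Apply VitMatte; called directly that is roughly:
#
#   (model,) = MTB_LoadVitMatteModel().execute(
#       kind="Composition-1K", autodownload=True
#   )
#   (trimap,) = MTB_GenerateTrimap().execute(mask=mask, erode=10, dilate=10)
#   rgba, alpha = MTB_ApplyVitMatte().execute(
#       model=model, image=image, trimap=trimap, returns="RGBA"
#   )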