File size: 2,810 Bytes
4450790 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import torch
import torchvision.transforms.v2 as T
import torch.nn.functional as F
from .utils import expand_mask
class LoadCLIPSegModels:
    """ComfyUI node that loads the CLIPSeg processor and segmentation model.

    Takes no inputs and emits one ``CLIP_SEG`` value: a ``(processor, model)``
    tuple loaded from the fixed ``CIDAS/clipseg-rd64-refined`` checkpoint.
    """

    @classmethod
    def INPUT_TYPES(s):
        # The checkpoint is hard-coded, so the node exposes no inputs.
        return dict(required=dict())

    RETURN_TYPES = ("CLIP_SEG",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    def execute(self):
        """Load the processor and model via ``from_pretrained``; return ``((processor, model),)``."""
        # Imported here rather than at module level, so transformers is only
        # imported when this node actually executes.
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        checkpoint = "CIDAS/clipseg-rd64-refined"
        # Tuple elements evaluate left to right: processor first, then model.
        clip_seg = (
            CLIPSegProcessor.from_pretrained(checkpoint),
            CLIPSegForImageSegmentation.from_pretrained(checkpoint),
        )
        return (clip_seg,)
class ApplyCLIPSeg:
    """ComfyUI node that turns a text prompt into a segmentation MASK with CLIPSeg.

    Each image in the batch is segmented independently. The per-image sigmoid
    probability map is thresholded into a binary mask, then optionally smoothed,
    passed through ``expand_mask`` and blurred. Finally it is resized back to the
    input image's height and width and clamped to [0, 1].
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip_seg": ("CLIP_SEG",),
                "image": ("IMAGE",),
                "prompt": ("STRING", {"multiline": False, "default": ""}),
                "threshold": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05}),
                "smooth": ("INT", {"default": 9, "min": 0, "max": 32, "step": 1}),
                "dilate": ("INT", {"default": 0, "min": -32, "max": 32, "step": 1}),
                "blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    @staticmethod
    def _odd_kernel_size(size):
        """Return *size* if it is odd, else ``size + 1`` (Gaussian kernel sizes must be odd)."""
        return size if size % 2 else size + 1

    @staticmethod
    def _resize_mask(mask, height, width):
        """Bicubically resize an ``(N, h, w)`` float mask to ``(N, height, width)``, clamped to [0, 1]."""
        resized = F.interpolate(mask.unsqueeze(1), size=(height, width), mode='bicubic').squeeze(1)
        # Bicubic interpolation overshoots around hard edges (ringing), producing
        # values slightly below 0 and above 1; clamp so the output is a valid mask.
        return resized.clamp(0.0, 1.0)

    def execute(self, image, clip_seg, prompt, threshold, smooth, dilate, blur):
        """Segment every image in *image* for *prompt* and return ``(mask,)``.

        Args:
            image: image batch. Assumed to be ComfyUI IMAGE layout (B, H, W, C)
                with floats in [0, 1] (TODO confirm); the code scales by 255 and
                reads the output size from ``image.shape[1]`` / ``image.shape[2]``.
            clip_seg: ``(processor, model)`` tuple, as returned by LoadCLIPSegModels.
            prompt: text describing what to segment.
            threshold: cut-off applied to the sigmoid of the logits.
            smooth: Gaussian kernel size used to smooth the binary mask
                (0 disables; even values are bumped to the next odd value).
            dilate: amount passed to ``expand_mask``; 0 skips it. Presumably the
                sign selects growing vs. shrinking (verify in ``utils.expand_mask``).
            blur: Gaussian kernel size for a final soft edge
                (0 disables; even values are bumped to the next odd value).

        Returns:
            One-element tuple holding a float mask of shape
            (B, image.shape[1], image.shape[2]) with values in [0, 1].
        """
        processor, model = clip_seg

        # uint8 frames for the processor.
        frames = image.mul(255).clamp(0, 255).byte().cpu().numpy()

        masks = []
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for frame in frames:
                inputs = processor(text=prompt, images=[frame], return_tensors="pt")
                logits = model(**inputs, interpolate_pos_encoding=True).logits
                # One image + one prompt gives a single map. The batch dim may or may
                # not already be squeezed (e.g. (H, W) vs (1, H, W)), so collapse any
                # leading dims instead of indexing a fixed rank.
                logits = logits.reshape(-1, *logits.shape[-2:])[0]
                masks.append(torch.sigmoid(logits) > threshold)
        del frames

        outputs = torch.stack(masks, dim=0)

        if smooth > 0:
            # Blur the binary mask in float and re-binarize at 0.5: a local majority
            # vote that removes speckle. This matches the implicit round-and-cast
            # torchvision v2 applies to non-float input, without depending on it.
            outputs = T.functional.gaussian_blur(outputs.float(), self._odd_kernel_size(smooth)) > 0.5

        outputs = outputs.float()

        if dilate != 0:
            outputs = expand_mask(outputs, dilate, True)

        if blur > 0:
            outputs = T.functional.gaussian_blur(outputs, self._odd_kernel_size(blur))

        # Resize to the original image size.
        return (self._resize_mask(outputs, image.shape[1], image.shape[2]),)
# Node registry for this module: node id -> implementing class. Presumably merged
# into ComfyUI's NODE_CLASS_MAPPINGS by the package __init__ (not visible here).
SEG_CLASS_MAPPINGS = {
    "ApplyCLIPSeg+": ApplyCLIPSeg,
    "LoadCLIPSegModels+": LoadCLIPSegModels,
}

# Display titles keyed by the same node ids as SEG_CLASS_MAPPINGS. Presumably merged
# into ComfyUI's NODE_DISPLAY_NAME_MAPPINGS (not visible here). The leading wrench
# emoji was previously mis-encoded as "๐ง".
SEG_NAME_MAPPINGS = {
    "ApplyCLIPSeg+": "🔧 Apply CLIPSeg",
    "LoadCLIPSegModels+": "🔧 Load CLIPSeg Models",
}