File size: 2,508 Bytes
d29ef6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

import torch
import numpy as np
from PIL import Image

def resize_image(image)->Image.Image:
    pixel_number = 1024*1024
    granularity_val = 64
    ratio = image.size[0] / image.size[1]
    width = int((pixel_number * ratio) ** 0.5)
    width = width - (width % granularity_val)
    height = int(pixel_number / width)
    height = height - (height % granularity_val)
    return image.resize((width, height))

def get_masked_background_image(image, image_mask)->tuple:
    image_mask_pil = image_mask.resize(image.size) # fg is white
    image = np.array(image.convert("RGB")).transpose(2, 0, 1).astype(np.float32) / 255.0
    image_mask = np.array(image_mask_pil.convert("L")).astype(np.float32) / 255.0
    image[:,image_mask < 0.5] = 0  # mask background
    return image, image_mask

def get_control_image_tensor(vae, image, mask)->torch.Tensor:
        masked_image, image_mask = get_masked_background_image(image, mask)
        masked_image_tensor = torch.from_numpy(masked_image)
        masked_image_tensor = (masked_image_tensor - 0.5) / 0.5 # normalize for vae
        masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device="cuda:0")
        # encode the image to get the control latents
        control_latents = vae.encode(  
                masked_image_tensor[:, :3, :, :].to(vae.dtype)
            ).latent_dist.sample()   
        control_latents = control_latents * vae.config.scaling_factor 

        mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, None, ...].to(device="cuda:0")
        mask_tensor = torch.where(mask_tensor > 0.5, 1.0, 0) # binarize the mask
        mask_resized = torch.nn.functional.interpolate(mask_tensor, size=(control_latents.shape[2], control_latents.shape[3]), mode='nearest')
        control_tensor = torch.cat([control_latents, mask_resized], dim=1)
        return control_tensor

def remove_bg_from_image(image_path: str)->Image.Image:
    from transformers import pipeline
    pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
    mask = pipe(image_path, return_mask = True) # outputs a pillow mask
    return mask

def paste_fg_over_image(gen_image: Image.Image, orig_image: Image.Image, fg_mask: Image.Image)->Image.Image:
    fg_mask = fg_mask.convert("L")
    fg_mask = fg_mask.resize(orig_image.size, Image.NEAREST)
    gen_image = gen_image.convert("RGBA")
    orig_image = orig_image.convert("RGBA")
    gen_image.paste(orig_image, (0, 0), fg_mask)
    return gen_image.convert("RGB")