import math
import os

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Fall back to True so diffusers uses the token cached by `huggingface-cli login`.
YOUR_TOKEN = os.environ.get("API_TOKEN") or True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = "runwayml/stable-diffusion-inpainting"

# Load the fp16 weights on GPU to halve memory; fall back to full precision on CPU.
if torch.cuda.is_available():
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, revision="fp16", torch_dtype=torch.float16,
                                                          use_auth_token=YOUR_TOKEN).to(device)
else:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, use_auth_token=YOUR_TOKEN).to(device)
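
# Optional (assumption: this diffusers version exposes it, as recent releases do):
# attention slicing trades a little speed for a lower peak VRAM footprint on
# small GPUs. Left commented out to keep the original behavior.
# pipe.enable_attention_slicing()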


def image_grid(imgs, rows, cols):
    # Tile `rows * cols` equally sized images into a single sheet.
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
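
# e.g. image_grid(images, rows=2, cols=2) tiles four 512x512 samples into one
# 1024x1024 contact sheet.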


def mediapipe_segmentation(image_file, mask_file):
    mp_selfie_segmentation = mp.solutions.selfie_segmentation

    # For static images:
    BG_COLOR = (0, 0, 0)  # black
    MASK_COLOR = (255, 255, 255)  # white
    with mp_selfie_segmentation.SelfieSegmentation(model_selection=0) as selfie_segmentation:
        image = cv2.imread(image_file)
        # Convert the BGR image to RGB before processing.
        results = selfie_segmentation.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        # Threshold the soft segmentation mask into a binary person/background mask.
        # To improve segmentation around boundaries, consider applying a joint
        # bilateral filter to "results.segmentation_mask" with "image".
        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
        # Generate solid color images for showing the output selfie segmentation mask.
        fg_image = np.zeros(image.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR
        output_image = np.where(condition, fg_image, bg_image)
        cv2.imwrite(mask_file, output_image)
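
# A minimal sketch of the boundary refinement suggested above, not wired into
# the app. Assumes opencv-contrib-python is installed (it provides cv2.ximgproc);
# the name `refine_mask` and the filter parameters are illustrative, not part of
# the original code.
def refine_mask(image_bgr, segmentation_mask):
    # Guide the filter with the photo so soft mask edges snap to image edges.
    soft_mask = (segmentation_mask * 255).astype(np.uint8)
    return cv2.ximgproc.jointBilateralFilter(image_bgr, soft_mask, d=9,
                                             sigmaColor=75, sigmaSpace=75)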


def image_inpainting(prompt, image_path, mask_image_path, num_samples=4, is_origin=False):
    # Stable Diffusion inpainting works at 512x512; white mask pixels get repainted.
    image = Image.open(image_path).convert("RGB").resize((512, 512))
    mask_image = Image.open(mask_image_path).convert("RGB").resize((512, 512))
    if not is_origin:
        guidance_scale = 7.5
        generator = torch.Generator(device=device).manual_seed(0)  # change the seed to get different results

        images = pipe(prompt=prompt, image=image, mask_image=mask_image, guidance_scale=guidance_scale, generator=generator,
                      num_images_per_prompt=num_samples).images
    else:
        images = pipe(prompt=prompt, image=image, mask_image=mask_image, num_images_per_prompt=num_samples).images

    # insert initial image in the list so we can compare side by side
    # images.insert(0, image)
    # image_grid asserts rows * cols == len(images), so keep num_samples even.
    return image_grid(images, 2, math.ceil(num_samples / 2))
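
# Example call (hypothetical paths): replace the white (person) region of the
# mask with whatever the prompt describes.
# grid = image_inpainting("a bench in a field", "image.png", "m_mask.png")
# grid.save("result.png")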


title = "Person Matting & Stable Diffusion In-Painting"
description = "Inpainting with Stable Diffusion <br/><b>MediaPipe + Stable Diffusion</b><br/>"


# Example rows must match the input order of the interfaces: [image, prompt].
examples = [
    [
        'example.png',
        "a bench in a field"
    ]
]


def predict1(dict, prompt):
    # Auto mode: build the mask with MediaPipe selfie segmentation.
    dict['image'].save('image.png')
    mediapipe_segmentation('image.png', 'm_mask.png')
    image = image_inpainting(prompt, image_path='image.png', mask_image_path='m_mask.png', is_origin=False)
    return image


def predict2(dict, prompt):
    # Manual mode: use the mask the user sketched in the Gradio image editor.
    dict['image'].save('image.png')
    dict['mask'].save('mask.png')
    image = image_inpainting(prompt, image_path='image.png', mask_image_path='mask.png', is_origin=True)
    return image


image_input = gr.Image(source='upload', tool='sketch', type='pil')
prompt = gr.Textbox(label='prompt')

greeter_1 = gr.Interface(predict1, inputs=[image_input, prompt], outputs=gr.Image(label='auto'), examples=examples)
greeter_2 = gr.Interface(predict2, inputs=[image_input, prompt], outputs=gr.Image(label='paint'))
demo = gr.Parallel(greeter_1, greeter_2, title=title, description=description)
demo.launch()