# File size: 4,908 Bytes
# 537fd2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)  # nopep8
warnings.filterwarnings("ignore", category=UserWarning)  # nopep8
import os
import math
from tqdm import tqdm
import torch
from PIL import Image, ImageFilter
from scipy.ndimage import binary_dilation
import numpy as np

from captioner import init as init_captioner, derive_caption
from upscaler import init as init_upscaler
from segmenter import init as init_segmenter, segment
from depth_estimator import init as init_depth_estimator, get_depth_map
from pipeline import init as init_pipeline, run_pipeline
from image_utils import ensure_resolution, crop_centered

# Developer mode is opt-in through the DEV_MODE environment variable.
# os.getenv returns the raw string when the variable is set, so using it
# directly would treat "0" or "false" as enabled; parse the text instead so
# the usual "off" spellings (and an unset/empty variable) disable the mode.
developer_mode = os.getenv('DEV_MODE', '').strip().lower() not in (
    '', '0', 'false', 'no', 'off')

# Initialise every processing stage once, at import time, so that
# replace_background() can call the stage functions directly.
# NOTE(review): the init functions live in sibling modules; presumably they
# load model weights — confirm in those modules.
init_captioner()
init_upscaler()
init_segmenter()
init_depth_estimator()
init_pipeline()

# Ask PyTorch to release unused cached GPU memory left over from the
# initialisation calls above.
torch.cuda.empty_cache()

# Appended to the end of every positive prompt built in replace_background().
POSITIVE_PROMPT_SUFFIX = "commercial product photography, 24mm lens f/8"
# Appended to the end of every negative prompt built in replace_background().
NEGATIVE_PROMPT_SUFFIX = "cartoon, drawing, anime, semi-realistic, illustration, painting, art, text, greyscale, (black and white), lens flare, watermark, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, floating, levitating"

# Target size in megapixels, passed to ensure_resolution() before the image
# is segmented and depth-mapped.
MEGAPIXELS = 1.0


def _feather_depth_map(depth_map, crop_mask, threshold, dilation_iterations,
                       blur_radius):
    """Limit a depth map to a grown, soft-edged version of the crop mask.

    Args:
        depth_map: Image exposing ``convert('L')``; the depth estimate.
        crop_mask: Image exposing ``convert('L')``; the segmentation mask.
        threshold: Grayscale value above which a mask pixel is foreground.
        dilation_iterations: Number of dilation passes used to grow the
            mask. Values below 1 mean "do not dilate".
        blur_radius: Gaussian blur radius used to soften the mask edge.

    Returns:
        A ``(dilated_mask, dilated_mask_blurred, masked_depth_map)`` tuple of
        PIL images; the last one is converted to RGB.
    """
    # Binarise the grayscale mask: anything brighter than the threshold is
    # treated as foreground.
    foreground = np.array(crop_mask.convert('L')) > threshold

    # scipy's binary_dilation interprets iterations < 1 as "repeat until the
    # result stops changing", which grows the mask without limit. A setting
    # of 0 should simply skip the dilation, so guard it here.
    iterations = int(dilation_iterations)
    if iterations >= 1:
        foreground = binary_dilation(foreground, iterations=iterations)

    # Back to an 8-bit PIL image so PIL's Gaussian blur can be applied.
    dilated_mask = Image.fromarray((foreground * 255).astype(np.uint8))
    dilated_mask_blurred = dilated_mask.filter(
        ImageFilter.GaussianBlur(radius=blur_radius))

    # Scale both images to [0, 1], weight the depth by the soft mask, and
    # scale back to 8-bit.
    mask_weights = np.array(dilated_mask_blurred) / 255.0
    depth = np.array(depth_map.convert('L')) / 255.0
    masked_depth = (depth * mask_weights * 255).astype(np.uint8)

    masked_depth_map = Image.fromarray(masked_depth).convert('RGB')
    return dilated_mask, dilated_mask_blurred, masked_depth_map


def replace_background(
    original,
    positive_prompt,
    negative_prompt,
    options,
):
    """Generate new backgrounds for an image and composite the subject on top.

    The stages run in order: captioning, resizing to ``MEGAPIXELS``,
    segmentation, depth estimation, depth-map feathering, generation, and
    compositing. Progress is reported through a 7-step tqdm bar and printed
    status lines.

    Args:
        original: Input image (its ``size`` attribute is printed and it is
            handed to the captioner and resizer).
        positive_prompt: User prompt, placed between the derived caption and
            ``POSITIVE_PROMPT_SUFFIX``.
        negative_prompt: User negative prompt, placed before
            ``NEGATIVE_PROMPT_SUFFIX``.
        options: Mapping read with ``.get`` for the keys
            ``'depth_map_feather_threshold'``,
            ``'depth_map_dilation_iterations'``, ``'depth_map_blur_radius'``
            and ``'seed'``.

    Returns:
        A 4-item list. In developer mode it is
        ``[composited_images, generated_images, pre_processing_images,
        caption]`` where ``pre_processing_images`` is a list of
        ``[image, label]`` pairs; otherwise it is
        ``[composited_images, None, None, None]``.
    """
    # Using the bar as a context manager guarantees it is closed even when
    # one of the stages raises.
    with tqdm(total=7) as pbar:
        print("Original size:", original.size)

        print("Captioning...")
        caption = derive_caption(original)
        pbar.update(1)

        print("Caption:", caption)

        # GPU cache is released between stages to keep peak memory down.
        torch.cuda.empty_cache()

        print(f"Ensuring resolution ({MEGAPIXELS}MP)...")
        resized = ensure_resolution(original, megapixels=MEGAPIXELS)
        pbar.update(1)

        print("Resized size:", resized.size)

        torch.cuda.empty_cache()

        print("Segmenting...")
        cropped, crop_mask = segment(resized)
        pbar.update(1)

        torch.cuda.empty_cache()

        print("Depth mapping...")
        depth_map = get_depth_map(resized)
        pbar.update(1)

        torch.cuda.empty_cache()

        print("Feathering the depth map...")
        dilated_mask, dilated_mask_blurred, masked_depth_map = _feather_depth_map(
            depth_map,
            crop_mask,
            threshold=options.get('depth_map_feather_threshold'),
            dilation_iterations=options.get('depth_map_dilation_iterations'),
            blur_radius=options.get('depth_map_blur_radius'),
        )
        pbar.update(1)

        final_positive_prompt = f"{caption}, {positive_prompt}, {POSITIVE_PROMPT_SUFFIX}"
        final_negative_prompt = f"{negative_prompt}, {NEGATIVE_PROMPT_SUFFIX}"

        print("Final positive prompt:", final_positive_prompt)
        print("Final negative prompt:", final_negative_prompt)

        print("Generating...")

        generated_images = run_pipeline(
            positive_prompt=final_positive_prompt,
            negative_prompt=final_negative_prompt,
            image=[masked_depth_map],
            seed=options.get('seed')
        )
        pbar.update(1)

        torch.cuda.empty_cache()

        print("Compositing...")

        # Lay the segmented subject over each generated background; both
        # layers must be RGBA and the same size for alpha_composite.
        composited_images = [
            Image.alpha_composite(
                generated_image.convert('RGBA'),
                crop_centered(cropped, generated_image.size)
            ) for generated_image in generated_images
        ]
        pbar.update(1)

    print("Done!")

    if not developer_mode:
        return [composited_images, None, None, None]

    # Developer mode also returns every intermediate image with a label.
    pre_processing_images = [
        [resized, "Resized"],
        [crop_mask, "Crop mask"],
        [cropped, "Cropped"],
        [depth_map, "Depth map"],
        [dilated_mask, "Dilated mask"],
        [dilated_mask_blurred, "Dilated mask blurred"],
        [masked_depth_map, "Masked depth map"]
    ]
    return [
        composited_images,
        generated_images,
        pre_processing_images,
        caption,
    ]