import concurrent.futures
import io
import os

import numpy as np
import oss2
import requests
from PIL import Image, ImageDraw, ImageFont

from .log import logger

# OSS (Alibaba Cloud Object Storage Service) configuration, read from environment variables.
access_key_id = os.getenv("ACCESS_KEY_ID")
access_key_secret = os.getenv("ACCESS_KEY_SECRET")
bucket_name = os.getenv("BUCKET_NAME")
endpoint = os.getenv("ENDPOINT")

bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
oss_path = "hejunjie.hjj/TransferAnythingHF"
oss_path_img_gallery = "hejunjie.hjj/TransferAnythingHF_img_gallery"


def download_img_pil(index, img_url):
    # Download a single image and return it together with its original index.
    r = requests.get(img_url, stream=True)
    if r.status_code == 200:
        img = Image.open(io.BytesIO(r.content))
        return (index, img)
    else:
        logger.error(f"Fail to download: {img_url}")
        return (index, None)


def download_images(img_urls, batch_size):
    # Download images concurrently; each result is written back at its original
    # index so that imgs_pil[i] corresponds to img_urls[i].
    imgs_pil = [None] * batch_size
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        to_do = []
        for i, url in enumerate(img_urls):
            future = executor.submit(download_img_pil, i, url)
            to_do.append(future)

        for future in concurrent.futures.as_completed(to_do):
            index, img_pil = future.result()
            imgs_pil[index] = img_pil  # keep URL order; later downloads of associated images or SVGs rely on it

    return imgs_pil
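
# Usage sketch for download_images (URLs are hypothetical): the returned list preserves
# the order of the input URLs even though the downloads complete out of order.
#
#   urls = ["https://example.com/a.png", "https://example.com/b.png"]
#   images = download_images(urls, batch_size=len(urls))
#   # images[i] is a PIL.Image, or None if that URL failed to download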


def upload_np_2_oss(input_image, name="cache.png", gallery=False):
    assert name.lower().endswith((".png", ".jpg")), name
    imgByteArr = io.BytesIO()
    if name.lower().endswith(".png"):
        Image.fromarray(input_image).save(imgByteArr, format="PNG")
    else:
        Image.fromarray(input_image).save(imgByteArr, format="JPEG", quality=95)
    imgByteArr = imgByteArr.getvalue()

    if gallery:
        path = oss_path_img_gallery
    else:
        path = oss_path

    bucket.put_object(path + "/" + name, imgByteArr)  # the payload can be raw bytes, e.g. an encoded image
    ret = bucket.sign_url('GET', path + "/" + name, 60 * 60 * 24)  # signed GET URL: method, object key, expiry (seconds)
    del imgByteArr
    return ret


def upload_json_string_2_oss(jsonStr, name="cache.txt", gallery=False):
    if gallery:
        path = oss_path_img_gallery
    else:
        path = oss_path

    bucket.put_object(path + "/" + name, bytes(jsonStr, "utf-8"))  # upload the string as UTF-8 bytes
    ret = bucket.sign_url('GET', path + "/" + name, 60 * 60 * 24)  # signed GET URL: method, object key, expiry (seconds)
    return ret


def upload_preprocess(pil_base_image_rgba, pil_layout_image_dict, pil_style_image_dict, pil_color_image_dict,
                      pil_fg_mask):
    # Blend each provided input 50% toward blue ([0, 0, 255]) outside its active (kept)
    # region, so the retained area stands out in the uploaded preview.
    np_out_base_image = np_out_layout_image = np_out_style_image = np_out_color_image = None

    if pil_base_image_rgba is not None:
        np_fg_image = np.array(pil_base_image_rgba)[..., :3]
        np_fg_mask = np.expand_dims(np.array(pil_fg_mask).astype(float), axis=-1) / 255.
        np_fg_mask = np_fg_mask * 0.5 + 0.5
        np_out_base_image = (
            np_fg_image * np_fg_mask + (1 - np_fg_mask) * np.array([0, 0, 255])
        ).round().clip(0, 255).astype(np.uint8)

    if pil_layout_image_dict is not None:
        np_layout_image = np.array(pil_layout_image_dict["image"].convert("RGBA"))
        np_layout_image, np_layout_alpha = np_layout_image[..., :3], np_layout_image[..., 3]
        np_layout_mask = np.array(pil_layout_image_dict["mask"].convert("L"))
        np_layout_mask = ((np_layout_alpha > 127) * (np_layout_mask < 127)).astype(float)[..., None]
        np_layout_mask = np_layout_mask * 0.5 + 0.5
        np_out_layout_image = (
            np_layout_image * np_layout_mask + (1 - np_layout_mask) * np.array([0, 0, 255])
        ).round().clip(0, 255).astype(np.uint8)

    if pil_style_image_dict is not None:
        np_style_image = np.array(pil_style_image_dict["image"].convert("RGBA"))
        np_style_image, np_style_alpha = np_style_image[..., :3], np_style_image[..., 3]
        np_style_mask = np.array(pil_style_image_dict["mask"].convert("L"))
        np_style_mask = ((np_style_alpha > 127) * (np_style_mask < 127)).astype(float)[..., None]
        np_style_mask = np_style_mask * 0.5 + 0.5
        np_out_style_image = (
            np_style_image * np_style_mask + (1 - np_style_mask) * np.array([0, 0, 255])
        ).round().clip(0, 255).astype(np.uint8)

    if pil_color_image_dict is not None:
        np_color_image = np.array(pil_color_image_dict["image"].convert("RGBA"))
        np_color_image, np_color_alpha = np_color_image[..., :3], np_color_image[..., 3]
        np_color_mask = np.array(pil_color_image_dict["mask"].convert("L"))
        np_color_mask = ((np_color_alpha > 127) * (np_color_mask < 127)).astype(float)[..., None]
        np_color_mask = np_color_mask * 0.5 + 0.5
        np_out_color_image = (
            np_color_image * np_color_mask + (1 - np_color_mask) * np.array([0, 0, 255])
        ).round().clip(0, 255).astype(np.uint8)

    return np_out_base_image, np_out_layout_image, np_out_style_image, np_out_color_image


def pad_image(image, target_size):
    iw, ih = image.size  # original image size
    w, h = target_size  # target canvas size
    scale = min(w / iw, h / ih)  # smallest scale that fits the image inside the target
    # Scale so that at least one side matches the target; the +0.5 rounds to the nearest integer.
    nw = int(iw * scale + 0.5)
    nh = int(ih * scale + 0.5)
    image = image.resize((nw, nh), Image.BICUBIC)  # resize with bicubic interpolation
    new_image = Image.new('RGB', target_size, (255, 255, 255))  # white canvas
    new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))  # paste centered, leaving white padding around it
    return new_image


def add_text(image, text):
    w, h = image.size
    text_image = image.copy()
    text_image_draw = ImageDraw.Draw(text_image)

    ttf = ImageFont.truetype("assets/ttf/AlibabaPuHuiTi-2-55-Regular.ttf", int(h / 10))
    left, top, right, bottom = ttf.getbbox(text)
    # Draw a white box behind the label, then blend it at 50% so the underlying image stays visible.
    text_image_draw.rectangle((0, 0, right + left, bottom + top), fill=(255, 255, 255))

    image = Image.blend(image, text_image, 0.5)

    image_draw = ImageDraw.Draw(image)
    fillColor = (0, 0, 0, 255)  # text color: black
    pos = (0, 0)  # top-left position of the text (offset from the left edge, offset from the top edge)
    image_draw.text(pos, text, font=ttf, fill=fillColor)
    return image.convert("RGB")


def compose_image(image_list, text_list, pil_size, nrow, ncol):
    w, h = pil_size  # size of each tile in the grid

    if len(image_list) > nrow * ncol:
        raise ValueError("Number of images exceeds the nrow * ncol grid of the composite.")

    assert len(image_list) == len(text_list)
    new_image_list = []
    new_text_list = []
    for image, text in zip(image_list, text_list):
        if image is not None:
            new_image_list.append(image)
            new_text_list.append(text)
    if len(new_image_list) == 1:
        ncol = nrow = 1
    to_image = Image.new('RGB', (ncol * w, nrow * h), (255, 255, 255))  # blank white canvas for the grid
    for y in range(1, nrow + 1):
        for x in range(1, ncol + 1):
            if ncol * (y - 1) + x - 1 < len(new_image_list):
                from_image = new_image_list[ncol * (y - 1) + x - 1].resize((w, h), Image.BICUBIC)
                from_text = new_text_list[ncol * (y - 1) + x - 1]
                if from_text is not None:
                    from_image = add_text(from_image, from_text)
                to_image.paste(from_image, ((x - 1) * w, (y - 1) * h))
    return to_image


def split_text_lines(text, max_w, ttf):
    # Greedily wrap text into lines no wider than max_w pixels; returns the wrapped
    # lines together with the total pixel height of the block.
    text_split_lines = []
    text_h = 0
    if text != "":
        line_start = 0
        while line_start < len(text):
            line_count = 0
            _, _, right, bottom = ttf.getbbox(text[line_start: line_start + line_count + 1])
            while right < max_w and line_start + line_count < len(text):
                line_count += 1
                _, _, right, bottom = ttf.getbbox(text[line_start: line_start + line_count + 1])
            line_count = max(line_count, 1)  # always consume at least one character to avoid an infinite loop
            text_split_lines.append(text[line_start:line_start + line_count])
            text_h += bottom
            line_start += line_count
    return text_split_lines, text_h


def add_prompt(image, prompt, negative_prompt):
    # Append the word-wrapped prompt and negative prompt as a white text strip below the image.
    if prompt == "" and negative_prompt == "":
        return image
    if prompt != "":
        prompt = "Prompt: " + prompt
    if negative_prompt != "":
        negative_prompt = "Negative prompt: " + negative_prompt

    w, h = image.size

    ttf = ImageFont.truetype("assets/ttf/AlibabaPuHuiTi-2-55-Regular.ttf", int(h / 20))

    prompt_split_lines, prompt_h = split_text_lines(prompt, w, ttf)
    negative_prompt_split_lines, negative_prompt_h = split_text_lines(negative_prompt, w, ttf)
    text_h = prompt_h + negative_prompt_h
    text = "\n".join(prompt_split_lines + negative_prompt_split_lines)
    text_image = Image.new(image.mode, (w, text_h), color=(255, 255, 255))
    text_image_draw = ImageDraw.Draw(text_image)
    text_image_draw.text((0, 0), text, font=ttf, fill=(0, 0, 0))

    out_image = Image.new(image.mode, (w, h + text_h), color=(255, 255, 255))
    out_image.paste(image, (0, 0))
    out_image.paste(text_image, (0, h))

    return out_image


def merge_images(np_fg_image, np_layout_image, np_style_image, np_color_image, np_res_image, prompt, negative_prompt):
    # Compose a summary image: a 2x2 grid of the inputs (Layout / Style / Color / Subject)
    # placed side by side with the result, with the prompts rendered underneath.
    pil_res_image = Image.fromarray(np_res_image)

    w, h = pil_res_image.size
    pil_fg_image = None if np_fg_image is None else pad_image(Image.fromarray(np_fg_image), (w, h))
    pil_layout_image = None if np_layout_image is None else pad_image(Image.fromarray(np_layout_image), (w, h))
    pil_style_image = None if np_style_image is None else pad_image(Image.fromarray(np_style_image), (w, h))
    pil_color_image = None if np_color_image is None else pad_image(Image.fromarray(np_color_image), (w, h))

    input_images = [pil_layout_image, pil_style_image, pil_color_image, pil_fg_image]
    input_texts = ['Layout', 'Style', 'Color', 'Subject']
    input_compose_image = compose_image(input_images, input_texts, (w, h), nrow=2, ncol=2)
    input_compose_image = input_compose_image.resize((w, h), Image.BICUBIC)
    output_compose_image = compose_image([input_compose_image, pil_res_image], [None, None], (w, h), nrow=1,
                                         ncol=2)
    output_compose_image = add_prompt(output_compose_image, prompt, negative_prompt)

    output_compose_image = np.array(output_compose_image)

    return output_compose_image
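

# Usage sketch (assumes this module is imported from its package and that the OSS
# environment variables above are set so the module-level bucket can be created;
# this snippet itself makes no network call). It pads a synthetic image onto a square
# canvas and composes it with the original into a 1x2 grid; captions are None so the
# bundled TTF font is not needed.
#
#   import numpy as np
#   from PIL import Image
#
#   demo = Image.fromarray(np.random.randint(0, 256, (120, 200, 3), dtype=np.uint8))
#   padded = pad_image(demo, (256, 256))
#   grid = compose_image([demo, padded], [None, None], (256, 256), nrow=1, ncol=2)
#   grid.size  # (512, 256)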