File size: 5,664 Bytes
a637d5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os, glob, json, base64, re
from io import BytesIO
from PIL import Image, PngImagePlugin
from image_process import image_canny,image_pose_mask,image_pose_mask_numpy
from generate_img import generate_image, generate_image_sketch

# Directory holding saved pose PNGs; empty until set_save_dir() is called.
_SAVED_POSES_DIR = ''

# Module-level dict cache (not read or written by the live code in this module).
image_cache = {}

def set_save_dir(dir: str):
    """Configure the directory where pose PNGs are stored.

    ``dir`` may be anything convertible with ``str()`` (e.g. a ``Path``); it is
    canonicalised with ``os.path.realpath`` before being stored.
    """
    global _SAVED_POSES_DIR
    resolved = os.path.realpath(str(dir))
    _SAVED_POSES_DIR = resolved

def get_save_dir():
    """Return the pose directory configured by ``set_save_dir``.

    Raises:
        AssertionError: if no directory has been configured yet. The check is
            an explicit ``raise`` rather than an ``assert`` statement so that
            it still runs under ``python -O``.
    """
    if not _SAVED_POSES_DIR:
        raise AssertionError('save directory is not set; call set_save_dir() first')
    return _SAVED_POSES_DIR

def get_saved_path(name: str):
    """Return ``name`` joined onto the configured save directory.

    The joined path is not resolved with ``realpath`` and ``name`` is not
    validated here; ``name2path`` performs the validation.
    """
    base_dir = get_save_dir()
    return os.path.join(base_dir, name)

def atoi(text):
    """Return ``int(text)`` if ``text`` is made only of decimal digits, else ``text``.

    Used as the per-chunk converter for natural sorting. ``str.isdecimal`` is
    used instead of ``str.isdigit``: ``isdigit`` also accepts characters such
    as superscripts (``'²'``) that ``int()`` rejects, which made this raise
    ``ValueError``. Every character ``isdecimal`` accepts is valid for ``int()``.
    """
    return int(text) if text.isdecimal() else text

def natural_keys(text):
    """Build a natural-sort key by splitting ``text`` into text and number chunks.

    The capturing group makes ``re.split`` keep the digit runs, so
    ``'img10.png'`` becomes ``['img', 10, '.png']`` and numbers compare
    numerically rather than lexically.
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

def sorted_glob(path):
    """Return the paths matching glob pattern ``path``, in natural order."""
    matches = glob.glob(path)
    matches.sort(key=natural_keys)
    return matches

def name2path(name: str):
    """Map a pose name to the path of its PNG file in the save directory.

    Raises:
        ValueError: if ``name`` is not a ``str``, is empty, contains ``.``,
            ``/`` or ``\\``, or if the joined path does not start with the
            save directory.
    """
    if not isinstance(name, str):
        raise ValueError(f'str object expected, but {type(name)}')
    if not name:
        raise ValueError('empty name')
    # Forbid dots and path separators so the name can neither leave the
    # directory nor choose its own extension.
    if any(ch in name for ch in './\\'):
        raise ValueError(f'invalid name: {name}')

    candidate = get_saved_path(f'{name}.png')
    # Defensive second check on the joined (unresolved) path.
    if candidate.startswith(get_save_dir()):
        return candidate
    raise ValueError(f'invalid name: {name}')

def saved_poses():
    """Lazily yield a PIL image for every ``*.png`` in the save directory.

    Files are visited in natural order. Each image is opened with
    ``Image.open`` and handed to the caller unclosed.
    """
    pattern = os.path.join(get_save_dir(), '*.png')
    for png_path in sorted_glob(pattern):
        yield Image.open(png_path)

def all_poses():
    """Yield every saved pose as a JSON-ready dict.

    Each dict holds ``name``, ``image`` (the PNG re-encoded as ASCII base64)
    and the JSON-decoded ``screen``, ``camera`` and ``joints`` text chunks.
    PNGs that lack any of these text chunks are skipped. Each image is closed
    once its dict has been built.
    """
    required = ('name', 'screen', 'camera', 'joints')
    for img in saved_poses():
        with img:
            # Pillow gives PNG images a ``text`` mapping even when it is empty,
            # so a bare ``hasattr`` test never skipped anything; check the keys.
            meta = getattr(img, 'text', None)
            if not meta or any(key not in meta for key in required):
                continue

            buffer = BytesIO()
            img.save(buffer, format='png')
            pose_dict = {
                'name': meta['name'],
                'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
                'screen': json.loads(meta['screen']),
                'camera': json.loads(meta['camera']),
                'joints': json.loads(meta['joints']),
            }
        # Yield after the file is closed so the consumer does not hold it open.
        yield pose_dict

def save_pose(data: dict):
    """Save a pose as a square thumbnail PNG carrying the pose as metadata.

    ``data`` must contain ``name``, ``screen``, ``camera``, ``joints`` and
    ``image``. ``image`` is base64 image data, optionally prefixed with a
    ``data:...;base64,`` header. The image is centred on a square
    (68, 68, 68) grey canvas, downscaled by a factor of 4 and written to
    ``name2path(name)``. ``name`` is stored as a text chunk, and ``screen``,
    ``camera`` and ``joints`` are stored as JSON text chunks.

    Raises:
        KeyError: if a required key is missing from ``data``.
        ValueError: if ``name`` is rejected by ``name2path``.
    """
    name = data['name']
    # Validate the name first so a bad name fails before any decoding work.
    filepath = name2path(name)

    info = PngImagePlugin.PngInfo()
    info.add_text('name', name)
    for key in ('screen', 'camera', 'joints'):
        info.add_text(key, json.dumps(data[key]))

    encoded = data['image']
    # Strip any data-URL header. A fixed-length PNG-specific slice would break
    # raw base64 input or a data URL with a different header.
    if encoded.startswith('data:'):
        encoded = encoded.partition(',')[2]
    image = Image.open(BytesIO(base64.b64decode(encoded)))

    # Centre the image on a square canvas so every thumbnail has the same shape.
    unit = max(image.width, image.height)
    offset = ((unit - image.width) // 2, (unit - image.height) // 2)
    canvas = Image.new('RGB', (unit, unit), color=(68, 68, 68))
    canvas.paste(image, offset)
    # max(1, ...) keeps inputs smaller than 4px from requesting a 0x0 resize.
    side = max(1, unit // 4)
    thumbnail = canvas.resize((side, side))

    thumbnail.save(filepath, pnginfo=info)

def delete_pose(name: str):
    """Remove the saved PNG for pose ``name``.

    Raises:
        ValueError: if ``name`` is rejected by ``name2path``.
        OSError: if ``os.remove`` fails (for example, the file does not exist).
    """
    os.remove(name2path(name))

def load_pose(name: str):
    """Load the saved pose ``name`` and return it as a JSON-ready dict.

    The dict holds ``name``, ``image`` (the PNG re-encoded as ASCII base64)
    and the JSON-decoded ``screen``, ``camera`` and ``joints`` text chunks.

    Raises:
        ValueError: if ``name`` is invalid or the PNG lacks pose metadata.
        OSError: if the file cannot be opened.
    """
    filepath = name2path(name)
    required = ('name', 'screen', 'camera', 'joints')
    # ``with`` releases the file handle that Image.open keeps open lazily.
    with Image.open(filepath) as img:
        meta = getattr(img, 'text', None)
        # Pillow always gives PNG images a ``text`` mapping, so check the actual
        # pose keys; a plain PNG then gets the intended ValueError, not KeyError.
        if not meta or any(key not in meta for key in required):
            raise ValueError(f'not pose data: {filepath}')

        buffer = BytesIO()
        img.save(buffer, format='png')
        return {
            'name': meta['name'],
            'image': base64.b64encode(buffer.getvalue()).decode('ascii'),
            'screen': json.loads(meta['screen']),
            'camera': json.loads(meta['camera']),
            'joints': json.loads(meta['joints']),
        }


def base64_PIL(data: str):
    """Decode base64-encoded image file bytes into a PIL image.

    ``data`` is expected to be bare base64 with no ``data:`` URL header;
    callers in this module strip that header before calling. The image is
    opened from an in-memory buffer.
    """
    raw = base64.b64decode(data)
    return Image.open(BytesIO(raw))

def PIL_base64(data):
    """Return ``data.tobytes()`` encoded as base64 text.

    NOTE(review): for a PIL image, ``tobytes()`` returns raw pixel data, not
    an encoded PNG. The result therefore cannot be decoded by ``base64_PIL``.
    """
    raw_bytes = data.tobytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def resizeImg(image1, image2):
    """Return ``image2`` resized to the width and height of ``image1``."""
    # Take the target size from image1 and resize image2 to it.
    target_w, target_h = image1.size
    return image2.resize((target_w, target_h))

# def get_img(data):
#     #执行逻辑
#     if (data[0]):
#         bgImgBase64 = data[0]['bgImg'][len('data:image/png;base64,'):]
#         maskImgBase64 = data[0]['maskImg'][len('data:image/png;base64,'):]
#         image_cache['bgImgBase64'] = bgImgBase64
#         image_cache['maskImgBase64'] = maskImgBase64
#     return 'success'

def generate_img(data, image_prompt, image_n_prompt):
    """Generate an image from the OpenPose render and background in ``data[0]``.

    ``data[0]`` must provide ``bgImg`` (background) and ``maskImg`` (the
    OpenPose render) as base64 images, optionally prefixed with a
    ``data:...;base64,`` header. A mask is computed from the pose render with
    ``image_pose_mask`` (which returns base64). The background is resized to
    the mask's size. The pose image, background and mask are passed to
    ``generate_image``.

    Returns:
        A one-element list with the ``generate_image`` result, or ``None``
        when ``data`` is empty/``None`` or its first entry is falsy.
    """
    # ``not data`` also covers ``[]``, where the old ``data[0]`` check raised
    # IndexError.
    if not data or not data[0]:
        return None

    def _payload(url):
        # Drop a "data:...;base64," header if present; bare base64 is kept.
        return url.partition(',')[2] if url.startswith('data:') else url

    entry = data[0]
    bg_img = _payload(entry['bgImg'])
    mask_img_openpose = _payload(entry['maskImg'])
    print((len(bg_img), len(mask_img_openpose)))
    print((image_prompt, image_n_prompt))

    # OpenPose pipeline: derive a base64 mask from the pose render.
    maskImg_base64 = image_pose_mask(mask_img_openpose)

    controlnet_img_pil = base64_PIL(mask_img_openpose)
    bg_img_pil = base64_PIL(bg_img)
    mask_img_pil = base64_PIL(maskImg_base64)
    bg_img_pil = resizeImg(mask_img_pil, bg_img_pil)

    img = generate_image(image_prompt, image_n_prompt, controlnet_img_pil, bg_img_pil, mask_img_pil)
    return [img]


def get_image_sketch(image, image_prompt, image_n_prompt):
    """Generate an image from a sketch payload.

    ``image`` must provide ``'image'`` (the original picture) and ``'mask'``
    (presumably the drawn sketch; confirm against the caller), both accepted
    by ``Image.fromarray``. A mask is computed from ``'mask'`` by
    ``image_pose_mask_numpy`` (base64, decoded here). The sketch, original
    and mask images are passed to ``generate_image_sketch``, and its result is
    returned unchanged.
    """
    origin_array = image['image']
    stroke_array = image['mask']

    mask_pil = base64_PIL(image_pose_mask_numpy(stroke_array))
    origin_pil = Image.fromarray(origin_array)
    sketch_pil = Image.fromarray(stroke_array)
    return generate_image_sketch(image_prompt, image_n_prompt, sketch_pil, origin_pil, mask_pil)