File size: 4,870 Bytes
319886d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
from data.prefix_instruction import get_image_prompt, get_task_instruction, get_layout_instruction, get_content_instruction
import random
from PIL import Image
from .gradio_tasks import dense_prediction_data


# Style-transfer task specs. Each 'image_type' list gives the per-row image
# kinds in grid-column order: every entry except the last is a condition
# image, and the last entry is the image to be generated.
style_transfer = [
    {'name': 'Style Transfer', 'image_type': ["target", "style_source", "style_target"]},
]
# One single-element row per task name; the processing function for these
# tasks reads the task name from row[0] (presumably a Gradio examples table).
style_transfer_text = [[task['name']] for task in style_transfer]


# Condition+style fusion task specs. Each task uses one structural condition
# image (first column) plus a style reference pair; the last image type,
# "style_target", is the image to be generated.
style_condition_fusion = [
    dict(name=task_name, image_type=[condition, "style_source", "style_target"])
    for task_name, condition in (
        ('Canny+Style to Image', "canny"),
        ('Depth+Style to Image', "depth"),
        ('Hed+Style to Image', "hed"),
        ('Normal+Style to Image', "normal"),
        ('Pose+Style to Image', "openpose"),
        ('SAM2+Style to Image', "sam2_mask"),
        ('Mask+Style to Image', "mask"),
    )
]
# One single-element row per task name; the processing function for these
# tasks reads the task name from row[0].
style_condition_fusion_text = [[spec['name']] for spec in style_condition_fusion]


def _load_example_images(example, image_type):
    """Open one example's images in ``image_type`` (grid-column) order.

    The "style_source" image is resized to the size of the example's
    "style_target" image. Every other image is returned exactly as produced
    by ``Image.open``.
    """
    images = []
    for t in image_type:
        if t == "style_source":
            # Only the target's size is needed, so close it immediately.
            # resize() reads the source pixels while its file is still open,
            # and the result no longer depends on that file.
            with Image.open(example["style_target"]) as target:
                target_size = target.size
            with Image.open(example[t]) as source:
                images.append(source.resize(target_size))
        else:
            images.append(Image.open(example[t]))
    return images


def _build_style_task_outputs(task_list, x):
    """Build prompts and a sampled example grid for the task named ``x[0]``.

    Args:
        task_list: task specs, each a dict with a 'name', an 'image_type'
            list (condition image types first, target image type last) and
            an optional 'mask'.
        x: a row whose first element is the task name.

    Returns:
        ``[mask, grid_h, grid_w, layout_prompt, task_prompt, content_prompt,
        upsampling_noise, steps]`` followed by ``grid_h * grid_w`` images in
        row-major order (one row per sampled example).

    Raises:
        ValueError: if no task in ``task_list`` is named ``x[0]``, or if
            fewer than two entries of ``dense_prediction_data`` provide an
            existing file for every required image type.
    """
    task = next((t for t in task_list if t['name'] == x[0]), None)
    if task is None:
        raise ValueError(f"Unknown task name: {x[0]!r}")
    image_type = task['image_type']

    # "[IMAGEk] <prompt>" per image; all but the last are conditions and the
    # last one is the target.
    image_prompt_list = [
        f"[IMAGE{idx + 1}] {get_image_prompt(t)[0]}"
        for idx, t in enumerate(image_type)
    ]
    condition_prompt = ", ".join(image_prompt_list[:-1])
    target_prompt = image_prompt_list[-1]
    task_prompt = get_task_instruction(condition_prompt, target_prompt)

    # Only examples that have an existing file for every image type are usable.
    valid_data = [
        example for example in dense_prediction_data
        if all(example.get(t) is not None and os.path.exists(example[t]) for t in image_type)
    ]
    if len(valid_data) < 2:
        # random.randint(2, n) below would fail with an opaque "empty range".
        raise ValueError(
            f"Task {x[0]!r} needs at least 2 examples providing {image_type}, "
            f"found {len(valid_data)}"
        )
    n_samples = random.randint(2, min(len(valid_data), 3))
    examples = random.sample(valid_data, k=n_samples)
    rets = [
        img
        for example in examples
        for img in _load_example_images(example, image_type)
    ]

    content_prompt = ""

    grid_h = n_samples
    grid_w = len(image_type)
    # Per-column flags; by default only the last (target) column is set.
    mask = task.get('mask', [0] * (grid_w - 1) + [1])
    layout_prompt = get_layout_instruction(grid_w, grid_h)

    upsampling_noise = None
    steps = None
    return [mask, grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps] + rets


def process_style_transfer_tasks(x):
    """Build model inputs for the style-transfer task named ``x[0]``.

    Returns the list produced by ``_build_style_task_outputs`` for the
    ``style_transfer`` task table; raises ValueError for an unknown task name
    or when too few usable examples exist.
    """
    return _build_style_task_outputs(style_transfer, x)


def process_style_condition_fusion_tasks(x):
    """Build model inputs for the condition+style fusion task named ``x[0]``.

    Returns the list produced by ``_build_style_task_outputs`` for the
    ``style_condition_fusion`` task table; raises ValueError for an unknown
    task name or when too few usable examples exist.
    """
    return _build_style_task_outputs(style_condition_fusion, x)