# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
import math
import os
import sys
from collections import defaultdict

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

from scepter.modules.data.dataset.base_dataset import BaseDataset
from scepter.modules.data.dataset.registry import DATASETS
from scepter.modules.transform.io import pillow_convert
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

# Allow arbitrarily large images by disabling PIL's decompression-bomb check.
Image.MAX_IMAGE_PIXELS = None


@DATASETS.register_class()
class ACEDemoDataset(BaseDataset):
    para_dict = {
        'MS_DATASET_NAME': {
            'value': '',
            'description': 'Modelscope dataset name.'
        },
        'MS_DATASET_NAMESPACE': {
            'value': '',
            'description': 'Modelscope dataset namespace.'
        },
        'MS_DATASET_SUBNAME': {
            'value': '',
            'description': 'Modelscope dataset subname.'
        },
        'MS_DATASET_SPLIT': {
            'value': '',
            'description':
            'Modelscope dataset split set name, default is train.'
        },
        'MS_REMAP_KEYS': {
            'value':
            None,
            'description':
            'Column remapping for the ModelScope list file; the expected header is Target:FILE. '
            'If your file uses a different header, set this field to a mapping dict. '
            "For example, { 'Image:FILE': 'Target:FILE' } renames the column Image:FILE to Target:FILE."
        },
        'MS_REMAP_PATH': {
            'value':
            None,
            'description':
            'Root path prepended to every *:FILE column, default is None. Set this field '
            'when the data list comes from ModelScope but the image files live on the '
            'local device, so that relative paths resolve against your local image root.'
        },
        'TRIGGER_WORDS': {
            'value':
            '',
            'description':
            'Words describing the common features of your data, especially useful when you '
            'customize a tuner. Prepending these words to a prompt helps reproduce the '
            'tuned style or content at inference time.'
        },
        'HIGHLIGHT_KEYWORDS': {
            'value':
            '',
            'description':
            'The keywords to highlight in the prompt; each occurrence is replaced by '
            '<{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>.'
        },
        'KEYWORDS_SIGN': {
            'value':
            '',
            'description':
            'The sign appended to highlighted keywords, producing tokens like '
            '<{HIGHLIGHT_KEYWORDS}{KEYWORDS_SIGN}>.'
        },
    }
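
    # A minimal illustrative YAML config for this dataset (values are
    # hypothetical; adjust names and paths to your setup). It exercises the
    # fields documented in para_dict above:
    #
    #   TRAIN_DATA:
    #     NAME: ACEDemoDataset
    #     MODE: train
    #     MS_DATASET_NAME: ./data/my_edit_dataset   # or a ModelScope dataset id
    #     MS_DATASET_SPLIT: train
    #     MS_REMAP_KEYS: { 'Image:FILE': 'Target:FILE' }
    #     MS_REMAP_PATH: ./data/my_edit_dataset
    #     TRIGGER_WORDS: 'photo of'
    #     HIGHLIGHT_KEYWORDS: 'sks'
    #     KEYWORDS_SIGN: '@'
    #     MAX_SEQ_LEN: 1024
    #     DOWNSAMPLE_RATIO: 16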

    def __init__(self, cfg, logger=None):
        super().__init__(cfg=cfg, logger=logger)
        from modelscope import MsDataset
        from modelscope.utils.constant import DownloadMode
        ms_dataset_name = cfg.get('MS_DATASET_NAME', None)
        ms_dataset_namespace = cfg.get('MS_DATASET_NAMESPACE', None)
        ms_dataset_subname = cfg.get('MS_DATASET_SUBNAME', None)
        ms_dataset_split = cfg.get('MS_DATASET_SPLIT', 'train')
        ms_remap_keys = cfg.get('MS_REMAP_KEYS', None)
        ms_remap_path = cfg.get('MS_REMAP_PATH', None)

        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        self.max_aspect_ratio = cfg.get('MAX_ASPECT_RATIO', 4)
        self.d = cfg.get('DOWNSAMPLE_RATIO', 16)  # latent downsample factor used for sizing
        self.replace_style = cfg.get('REPLACE_STYLE', False)
        self.trigger_words = cfg.get('TRIGGER_WORDS', '')
        self.replace_keywords = cfg.get('HIGHLIGHT_KEYWORDS', '')
        self.keywords_sign = cfg.get('KEYWORDS_SIGN', '')
        self.add_indicator = cfg.get('ADD_INDICATOR', False)
        # Use a ModelScope dataset (a remote id or a local directory organized
        # in ModelScope format).
        if not ms_dataset_name:
            raise ValueError(
                'You must set MS_DATASET_NAME to a ModelScope dataset id or to a '
                'local dataset organized in ModelScope format.')
        # Always keep the prefix used later by load_image; without this, a
        # remote dataset id would leave self.ms_dataset_name undefined.
        self.ms_dataset_name = ms_dataset_name
        if FS.exists(ms_dataset_name):
            ms_dataset_name = FS.get_dir_to_local_dir(ms_dataset_name)
            self.ms_dataset_name = ms_dataset_name
            # ms_remap_path = ms_dataset_name
        try:
            self.data = MsDataset.load(str(ms_dataset_name),
                                       namespace=ms_dataset_namespace,
                                       subset_name=ms_dataset_subname,
                                       split=ms_dataset_split)
        except Exception:
            self.logger.info(
                "Load Modelscope dataset failed, retry with download_mode='force_redownload'."
            )
            try:
                self.data = MsDataset.load(
                    str(ms_dataset_name),
                    namespace=ms_dataset_namespace,
                    subset_name=ms_dataset_subname,
                    split=ms_dataset_split,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD)
            except Exception as sec_e:
                raise ValueError(
                    f'Loading ModelScope dataset failed: {sec_e}.') from sec_e
        if ms_remap_keys:
            self.data = self.data.remap_columns(ms_remap_keys.get_dict())

        if ms_remap_path:
            # Prepend the local root path to every *:FILE column so that
            # relative entries in the list file resolve on the local device.
            def map_func(example):
                return {
                    k: os.path.join(ms_remap_path, v)
                    if k.endswith(':FILE') else v
                    for k, v in example.items()
                }

            self.data = self.data.ds_instance.map(map_func)

        self.transforms = T.Compose([
            T.ToTensor(),
            # Map pixel values from [0, 1] to [-1, 1].
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __len__(self):
        # Training iterates indefinitely (indices wrap around in _get);
        # otherwise report the true number of records.
        if self.mode == 'train':
            return sys.maxsize
        else:
            return len(self.data)

    def _get(self, index: int):
        current_data = self.data[index % len(self.data)]

        tar_image_path = current_data.get('Target:FILE', '')
        src_image_path = current_data.get('Source:FILE', '')

        style = current_data.get('Style', '')
        prompt = current_data.get('Prompt', current_data.get('prompt', ''))
        if self.replace_style and style != '':
            # Swap the style phrase for the keywords-sign token.
            prompt = prompt.replace(style, f'<{self.keywords_sign}>')
        elif self.replace_keywords.strip() != '':
            # Wrap each highlighted keyword as <keyword + sign>.
            prompt = prompt.replace(
                self.replace_keywords,
                '<' + self.replace_keywords + f'{self.keywords_sign}>')

        if self.trigger_words != '':
            prompt = self.trigger_words.strip() + ' ' + prompt
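        # A worked example with hypothetical values: with HIGHLIGHT_KEYWORDS='sks',
        # KEYWORDS_SIGN='@' and TRIGGER_WORDS='photo of', the prompt
        # 'a sks dog on grass' becomes 'photo of a <sks@> dog on grass'.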

        src_image = self.load_image(self.ms_dataset_name,
                                    src_image_path,
                                    cvt_type='RGB')
        tar_image = self.load_image(self.ms_dataset_name,
                                    tar_image_path,
                                    cvt_type='RGB')
        src_image = self.image_preprocess(src_image)
        tar_image = self.image_preprocess(tar_image)

        tar_image = self.transforms(tar_image)
        src_image = self.transforms(src_image)
        # Single-channel all-ones masks matching each image's spatial size.
        src_mask = torch.ones_like(src_image[[0]])
        tar_mask = torch.ones_like(tar_image[[0]])
        if self.add_indicator:
            # Ensure the prompt carries an '{image}' placeholder referring to
            # the edit (source) image.
            if '{image}' not in prompt:
                prompt = '{image}, ' + prompt

        # edit_* entries hold the (lists of) source images and masks; image and
        # image_mask hold the target.
        return {
            'edit_image': [src_image],
            'edit_image_mask': [src_mask],
            'image': tar_image,
            'image_mask': tar_mask,
            'prompt': [prompt],
        }

    def load_image(self, prefix, img_path, cvt_type=None):
        if img_path is None or img_path == '':
            return None
        img_path = os.path.join(prefix, img_path)
        with FS.get_object(img_path) as image_bytes:
            image = Image.open(io.BytesIO(image_bytes))
            if cvt_type is not None:
                image = pillow_convert(image, cvt_type)
        return image

    def image_preprocess(self,
                         img,
                         size=None,
                         interpolation=InterpolationMode.BILINEAR):
        H, W = img.height, img.width
        # Center-crop overly elongated images so the aspect ratio stays within
        # max_aspect_ratio.
        if H / W > self.max_aspect_ratio:
            img = T.CenterCrop((self.max_aspect_ratio * W, W))(img)
        elif W / H > self.max_aspect_ratio:
            img = T.CenterCrop((H, self.max_aspect_ratio * H))(img)

        if size is None:
            # Resize so that (H / d) * (W / d) <= max_seq_len while keeping the
            # aspect ratio.
            H, W = img.height, img.width
            scale = min(
                1.0,
                math.sqrt(self.max_seq_len / ((H / self.d) * (W / self.d))))
            rH = int(
                H * scale) // self.d * self.d  # ensure divisible by self.d
            rW = int(W * scale) // self.d * self.d
        else:
            rH, rW = size
        img = T.Resize((rH, rW), interpolation=interpolation,
                       antialias=True)(img)
        return np.array(img, dtype=np.uint8)
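
    # A worked sizing example (assumed inputs): with max_seq_len=1024, d=16 and
    # a 1024x768 image, the token count is (1024/16) * (768/16) = 3072 > 1024,
    # so scale = sqrt(1024/3072), about 0.577. Then rH = int(1024*0.577) // 16
    # * 16 = 576 and rW = int(768*0.577) // 16 * 16 = 432, giving
    # (576/16) * (432/16) = 972 tokens, within max_seq_len.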

    @staticmethod
    def get_config_template():
        return dict_to_yaml('DATASET',
                            __class__.__name__,
                            ACEDemoDataset.para_dict,
                            set_name=True)

    @staticmethod
    def collate_fn(batch):
        # Collect per-key lists rather than stacking tensors: samples may have
        # different spatial sizes, so stacking is left to downstream code.
        collect = defaultdict(list)
        for sample in batch:
            for k, v in sample.items():
                collect[k].append(v)

        new_batch = dict()
        for k, v in collect.items():
            # Keep a key only if at least one sample provides a value.
            if all(i is None for i in v):
                new_batch[k] = None
            else:
                new_batch[k] = v

        return new_batch
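

# A minimal usage sketch (assumptions, not part of the original module): it
# builds the dataset through scepter's registry and fetches one sample. The
# dataset path is hypothetical, and the exact Config / DATASETS.build calling
# conventions may differ across scepter versions; adapt to your installation.
if __name__ == '__main__':
    from scepter.modules.utils.config import Config

    cfg = Config(
        load=False,
        cfg_dict={
            'NAME': 'ACEDemoDataset',
            'MODE': 'train',
            'MS_DATASET_NAME': './data/my_edit_dataset',  # hypothetical path
            'MS_DATASET_SPLIT': 'train',
        })
    dataset = DATASETS.build(cfg)
    sample = dataset[0]
    print(sample['prompt'], sample['image'].shape)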