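"""
Export helpers for packaging trained character LoRA models.

Given a training work directory, this module renders previews for every trained
step, scores each step against the source dataset with CCIP, converts the
checkpoints into WebUI-compatible files, and writes a publish-ready export
directory (meta.json, .gitattributes for Git LFS, and a README.md preview table).
"""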
import json
import logging
import os.path
import shutil
import time
import zipfile
from textwrap import dedent
from typing import Optional

import numpy as np
import pandas as pd
from imgutils.metrics import ccip_extract_feature, ccip_batch_same
from tqdm.auto import tqdm
from waifuc.source import LocalSource

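# torch is optional for this module; fall back to None when it is not installed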
try:
    import torch
except (ImportError, ModuleNotFoundError):
    torch = None

from .convert import convert_to_webui_lora
from .steps import find_steps_in_workdir
from ..dataset import load_dataset_for_character
from ..dataset.tags import sort_draw_names
from ..infer.draw import _DEFAULT_INFER_MODEL, draw_with_workdir
from ..utils import repr_tags, load_tags_from_directory

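# Known short hashes of the preview base models (as displayed by the SD WebUI);
# None means no hash is recorded for that model.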
KNOWN_MODEL_HASHES = {
    'AIARTCHAN/anidosmixV2': 'EB49192009',
    'stablediffusionapi/anything-v5': None,
    'stablediffusionapi/cetusmix': 'B42B09FF12',
    'Meina/MeinaMix_V10': 'D967BCAE4A',
    'Meina/MeinaMix_V11': '54EF3E3610',
    'Lykon/DreamShaper': 'C33104F6',
    'digiplay/majicMIX_realistic_v6': 'EBDB94D4',
    'stablediffusionapi/abyssorangemix2nsfw': 'D6992792',
    'AIARTCHAN/expmixLine_v2': 'D91B18D1',
    'Yntec/CuteYuki2': 'FBE372BA',
    'stablediffusionapi/counterfeit-v30': '12047227',
    'jzli/XXMix_9realistic-v4': '5D22F204',
    'stablediffusionapi/flat-2d-animerge': 'F279CF76',
    'redstonehero/cetusmix_v4': '838408E0',
    'Meina/Unreal_V4.1': '0503BFAD',
    'Meina/MeinaHentai_V4': '39C0C3B6',
    'Meina/MeinaPastel_V6': 'DA1D535E',
    'KBlueLeaf/kohaku-v4-rev1.2': '87F9E45D',
    'stablediffusionapi/night-sky-yozora-sty': 'D31F707A',
}

EXPORT_MARK = 'v1.4.1'  # export format version stamped into meta.json

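# Contents written to .gitattributes so that binary artifacts (including preview
# PNGs) are tracked with Git LFS in the exported repository.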
_GITLFS = dedent("""
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
""").strip()


def export_workdir(workdir: str, export_dir: str, n_repeats: int = 2,
                   pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                   image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                   lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                   model_hash: Optional[str] = None, ds_repo: Optional[str] = None):
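    """
    Export the trained artifacts of a work directory into ``export_dir``.

    For each trained step this renders preview images, scores them against the
    character dataset via CCIP, copies the raw checkpoints, converts them into a
    WebUI-compatible LoRA bundle, and finally writes ``meta.json``,
    ``.gitattributes`` and ``README.md`` for publishing.

    A minimal usage sketch (the paths are illustrative, not from the source):

        export_workdir('runs/some_character', 'exports/some_character')
    """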
    name, steps = find_steps_in_workdir(workdir)
    logging.info(f'Starting export of trained artifacts for {name!r}, with steps: {steps!r}')
    model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model, None)
    if model_hash:
        logging.info(f'Model hash {model_hash!r} detected for model {pretrained_model!r}.')

    if os.path.exists(os.path.join(workdir, 'meta.json')):
        with open(os.path.join(workdir, 'meta.json'), 'r', encoding='utf-8') as f:
            dataset_info = json.load(f)['dataset']
    else:
        dataset_info = None

    ds_repo = ds_repo or f'AppleHarem/{name}'
    # use the dataset's recorded 'type' (its size/variant spec) when available,
    # otherwise fall back to the default (384, 512) crops
    ds_size = (384, 512) if not dataset_info or not dataset_info['type'] else dataset_info['type']
    logging.info(f'Loading dataset {ds_repo!r}, {ds_size!r} ...')
    with load_dataset_for_character(ds_repo, ds_size) as (ch, ds_dir):
        core_tags, _ = load_tags_from_directory(ds_dir)
        ds_source = LocalSource(ds_dir)
        ds_feats = []
        for item in tqdm(list(ds_source), desc='Extract Dataset Feature'):
            ds_feats.append(ccip_extract_feature(item.image))

    d_names = set()
    all_drawings = {}
    nsfw_count = {}
    all_scores = {}
    all_scores_lst = []
    for step in steps:
        logging.info(f'Exporting for {name}-{step} ...')
        step_dir = os.path.join(export_dir, f'{step}')
        os.makedirs(step_dir, exist_ok=True)

        preview_dir = os.path.join(step_dir, 'previews')
        os.makedirs(preview_dir, exist_ok=True)

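        # Render previews for this step, retrying with an increased n_repeats
        # whenever a RuntimeError is raised.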
        while True:
            try:
                drawings = draw_with_workdir(
                    workdir, model_steps=step, n_repeats=n_repeats,
                    pretrained_model=pretrained_model,
                    width=image_width, height=image_height, infer_steps=infer_steps,
                    lora_alpha=lora_alpha, clip_skip=clip_skip, sample_method=sample_method,
                    model_hash=model_hash,
                )
            except RuntimeError:
                n_repeats += 1
            else:
                break

        all_image_files = []
        image_feats = []
        for draw in drawings:
            img_file = os.path.join(preview_dir, f'{draw.name}.png')
            image_feats.append(ccip_extract_feature(draw.image))
            draw.image.save(img_file, pnginfo=draw.pnginfo)
            all_image_files.append(img_file)

            with open(os.path.join(preview_dir, f'{draw.name}_info.txt'), 'w', encoding='utf-8') as f:
                print(draw.preview_info, file=f)
            d_names.add(draw.name)
            all_drawings[(draw.name, step)] = draw
            if not draw.sfw:
                nsfw_count[draw.name] = nsfw_count.get(draw.name, 0) + 1

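        # Copy the raw checkpoints (embedding .pt plus unet / text-encoder LoRA
        # weights), then repackage them into WebUI-compatible files and a zip bundle.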
        pt_file = os.path.join(workdir, 'ckpts', f'{name}-{step}.pt')
        unet_file = os.path.join(workdir, 'ckpts', f'unet-{step}.safetensors')
        text_encoder_file = os.path.join(workdir, 'ckpts', f'text_encoder-{step}.safetensors')
        raw_dir = os.path.join(step_dir, 'raw')
        os.makedirs(raw_dir, exist_ok=True)
        shutil.copyfile(pt_file, os.path.join(raw_dir, os.path.basename(pt_file)))
        shutil.copyfile(unet_file, os.path.join(raw_dir, os.path.basename(unet_file)))
        shutil.copyfile(text_encoder_file, os.path.join(raw_dir, os.path.basename(text_encoder_file)))

        shutil.copyfile(pt_file, os.path.join(step_dir, f'{name}.pt'))
        convert_to_webui_lora(unet_file, text_encoder_file, os.path.join(step_dir, f'{name}.safetensors'))
        with zipfile.ZipFile(os.path.join(step_dir, f'{name}.zip'), 'w') as zf:
            zf.write(os.path.join(step_dir, f'{name}.pt'), f'{name}.pt')
            zf.write(os.path.join(step_dir, f'{name}.safetensors'), f'{name}.safetensors')
            for img_file in all_image_files:
                zf.write(img_file, os.path.basename(img_file))

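        # The step score is the mean of the preview-vs-dataset block of the CCIP
        # "same character" matrix, i.e. the fraction of pairs judged the same character.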
        same_matrix = ccip_batch_same([*image_feats, *ds_feats])
        score = same_matrix[:len(image_feats), len(image_feats):].mean()
        all_scores[step] = score
        all_scores_lst.append(score)
        logging.info(f'Score of step {step} is {score}.')

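    # Pick the best step by score; when the dataset size is known, prefer steps
    # with at least 6 * dataset_size training steps, falling back to all steps
    # if none qualify.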
    lst_scores = np.array(all_scores_lst)
    lst_steps = np.array(steps)
    if dataset_info and 'size' in dataset_info:
        min_best_steps = 6 * dataset_info['size']
        _lst_scores = lst_scores[lst_steps >= min_best_steps]
        _lst_steps = lst_steps[lst_steps >= min_best_steps]
        if _lst_scores.shape[0] > 0:
            lst_steps, lst_scores = _lst_steps, _lst_scores

    best_idx = np.argmax(lst_scores)
    best_step = lst_steps[best_idx].item()
    nsfw_ratio = {d_name: count / len(steps) for d_name, count in nsfw_count.items()}
    with open(os.path.join(export_dir, 'meta.json'), 'w', encoding='utf-8') as f:
        json.dump({
            'name': name,
            'steps': steps,
            'mark': EXPORT_MARK,
            'time': time.time(),
            'dataset': dataset_info,
            'scores': [
                {
                    'step': step,
                    'score': score,
                } for step, score in sorted(all_scores.items())
            ],
            'best_step': best_step,
        }, f, ensure_ascii=False, indent=4)
    with open(os.path.join(export_dir, '.gitattributes'), 'w', encoding='utf-8') as f:
        print(_GITLFS, file=f)
    with open(os.path.join(export_dir, 'README.md'), 'w', encoding='utf-8') as f:
        print(f'# Lora of {name}', file=f)
        print('', file=f)

        print('This model is trained with [HCP-Diffusion](https://github.com/7eu7d7/HCP-Diffusion). '
              'The auto-training framework is maintained by the '
              '[DeepGHS Team](https://huggingface.co/deepghs), and the WebUI panel is provided '
              'by [LittleAppleWebUI](https://github.com/LittleApple-fp16/LittleAppleWebUI).', file=f)
        print('', file=f)

        print('The base model used during training is [NAI](https://huggingface.co/deepghs/animefull-latest), '
              f'and the base model used for generating preview images is '
              f'[{pretrained_model}](https://huggingface.co/{pretrained_model}).', file=f)
        print('', file=f)

        print('After downloading the pt and safetensors files for the specified step, '
              'you need to use them simultaneously: the pt file is used as an embedding, '
              'while the safetensors file is loaded as a LoRA.', file=f)
        print('', file=f)
        print(f'For example, if you want to use the model from step {best_step}, '
              f'download `{best_step}/{name}.pt` as the embedding and '
              f'`{best_step}/{name}.safetensors` as the LoRA. '
              f'Using both files together, you can generate images of the desired character.', file=f)
        print('', file=f)

        print(dedent(f"""
**The best step we recommend is {best_step}**, with a score of {all_scores[best_step]:.3f}. The trigger words are:
1. `{name}`
2. `{repr_tags([key for key, _ in sorted(core_tags.items(), key=lambda x: -x[1])])}`
        """).strip(), file=f)
        print('', file=f)

        print(dedent("""
We regret that this model is not recommended for the following groups:
1. Individuals who cannot tolerate any deviation from the original character design, even in the slightest detail.
2. Individuals whose application scenarios demand high accuracy in recreating character outfits.
3. Individuals who cannot accept the randomness inherent in AI-generated images based on the Stable Diffusion algorithm.
4. Individuals who are not comfortable with the fully automated process of training character models with LoRA, or who believe that character models must be trained purely by manual operation out of respect for the characters.
5. Individuals who find the generated image content offensive to their values.
        """).strip(), file=f)
        print('', file=f)

        print('The following steps are available:', file=f)
        print('', file=f)

        d_names = sort_draw_names(list(d_names))
        columns = ['Steps', 'Score', 'Download', *d_names]
        t_data = []

        for step in steps[::-1]:
            d_mds = []
            for dname in d_names:
                file = os.path.join(str(step), 'previews', f'{dname}.png')
                if (dname, step) in all_drawings:
                    if nsfw_ratio.get(dname, 0.0) < 0.35:
                        d_mds.append(f'![{dname}-{step}]({file})')
                    else:
                        d_mds.append(f'[<NSFW, click to see>]({file})')
                else:
                    d_mds.append('')

            t_data.append((
                str(step) if step != best_step else f'**{step}**',
                f'{all_scores[step]:.3f}' if step != best_step else f'**{all_scores[step]:.3f}**',
                f'[Download]({step}/{name}.zip)' if step != best_step else f'[**Download**]({step}/{name}.zip)',
                *d_mds,
            ))

        df = pd.DataFrame(columns=columns, data=t_data)
        print(df.to_markdown(index=False), file=f)
        print('', file=f)