File size: 4,073 Bytes
69a6cef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import json
import logging
import os
from typing import Optional

from hbutils.system import TemporaryDirectory
from huggingface_hub import hf_hub_url
from tqdm.auto import tqdm

from .draw import _DEFAULT_INFER_MODEL, draw_with_workdir
from ..dataset import save_recommended_tags
from ..utils import get_hf_fs, download_file


def draw_to_directory(workdir: str, export_dir: str, step: int, n_repeats: int = 2,
                      pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                      image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                      lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                      model_hash: Optional[str] = None, max_retries: Optional[int] = None):
    """
    Draw preview images for one training step of a work directory and export them.

    ``draw_with_workdir`` is called repeatedly; every time it raises :class:`RuntimeError`
    the call is retried with ``n_repeats`` increased by one.

    :param workdir: Work directory passed to ``draw_with_workdir``.
    :param export_dir: Directory the ``<name>.png`` images and ``<name>_info.txt`` files are
        written to. Created if it does not exist.
    :param step: Training step, forwarded as ``model_steps``.
    :param n_repeats: Initial ``n_repeats`` value forwarded to ``draw_with_workdir``.
    :param pretrained_model: Base model used for inference.
    :param clip_skip: CLIP skip forwarded to ``draw_with_workdir``.
    :param image_width: Output image width, forwarded as ``width``.
    :param image_height: Output image height, forwarded as ``height``.
    :param infer_steps: Number of inference steps.
    :param lora_alpha: LoRA weight forwarded to ``draw_with_workdir``.
    :param sample_method: Sampler name forwarded to ``draw_with_workdir``.
    :param model_hash: Hash of the base model. When falsy it is looked up in
        ``KNOWN_MODEL_HASHES`` by ``pretrained_model`` (and may stay ``None``).
    :param max_retries: Maximum number of retries after a :class:`RuntimeError`.
        ``None`` (the default) retries indefinitely, matching the historical behaviour.
    :return: Paths of the exported ``.png`` files, in the order they were drawn.
    :raises RuntimeError: Re-raised from ``draw_with_workdir`` once ``max_retries``
        retries have been used up.
    """
    from ..publish.export import KNOWN_MODEL_HASHES
    model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model)
    os.makedirs(export_dir, exist_ok=True)

    retries = 0
    while True:
        try:
            drawings = draw_with_workdir(
                workdir, model_steps=step, n_repeats=n_repeats,
                pretrained_model=pretrained_model,
                width=image_width, height=image_height, infer_steps=infer_steps,
                lora_alpha=lora_alpha, clip_skip=clip_skip, sample_method=sample_method,
                model_hash=model_hash,
            )
        except RuntimeError as err:
            # Bounded retry when requested; otherwise keep the original unbounded loop.
            if max_retries is not None and retries >= max_retries:
                raise
            retries += 1
            n_repeats += 1
            logging.warning(f'Drawing failed ({err!r}), retrying with n_repeats={n_repeats} ...')
        else:
            break

    all_image_files = []
    for draw in drawings:
        img_file = os.path.join(export_dir, f'{draw.name}.png')
        draw.image.save(img_file, pnginfo=draw.pnginfo)
        all_image_files.append(img_file)

        with open(os.path.join(export_dir, f'{draw.name}_info.txt'), 'w', encoding='utf-8') as f:
            print(draw.preview_info, file=f)

    return all_image_files


def draw_with_repo(repository: str, export_dir: str, step: Optional[int] = None, n_repeats: int = 2,
                   pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                   image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                   lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                   model_hash: Optional[str] = None):
    """
    Download one training step from a Hugging Face model repository and draw previews with it.

    The repository must contain a ``meta.json``. The raw files of the chosen step
    (``<repository>/<step>/raw/*``) are downloaded into ``<tmp workdir>/ckpts``. Recommended
    tags are regenerated, and the images are exported via :func:`draw_to_directory`.

    :param repository: Hugging Face repository id holding the trained model.
    :param export_dir: Directory the drawings are exported to.
    :param step: Training step to use. ``None`` selects ``meta['best_step']`` from ``meta.json``.
    :param n_repeats: Forwarded to :func:`draw_to_directory`.
    :param pretrained_model: Forwarded to :func:`draw_to_directory`.
    :param clip_skip: Forwarded to :func:`draw_to_directory`.
    :param image_width: Forwarded to :func:`draw_to_directory`.
    :param image_height: Forwarded to :func:`draw_to_directory`.
    :param infer_steps: Forwarded to :func:`draw_to_directory`.
    :param lora_alpha: Forwarded to :func:`draw_to_directory`.
    :param sample_method: Forwarded to :func:`draw_to_directory`.
    :param model_hash: Forwarded to :func:`draw_to_directory`.
    :raises ValueError: If ``meta.json`` is missing from the repository, or if no raw files
        exist for the selected step.
    """
    import posixpath
    from ..publish import find_steps_in_workdir

    hf_fs = get_hf_fs()
    meta_file = f'{repository}/meta.json'
    if not hf_fs.exists(meta_file):
        raise ValueError(f'Invalid repository or no model found - {repository!r}.')

    logging.info(f'Model repository {repository!r} found.')
    meta = json.loads(hf_fs.read_text(meta_file))
    # Explicit None check so that a legitimate ``step=0`` is not replaced by the best step.
    if step is None:
        step = meta['best_step']
    logging.info(f'Using step {step} ...')

    with TemporaryDirectory() as workdir:
        logging.info('Downloading models ...')
        remote_files = hf_fs.glob(f'{repository}/{step}/raw/*')
        if not remote_files:
            raise ValueError(f'No model files found for step {step!r} in repository {repository!r}.')

        ckpt_dir = os.path.join(workdir, 'ckpts')
        os.makedirs(ckpt_dir, exist_ok=True)
        for remote_file in tqdm(remote_files):
            # Hub paths are always '/'-separated; os.path.relpath would yield '\\' on Windows
            # and produce an invalid ``filename`` for hf_hub_url.
            rel_file = posixpath.relpath(remote_file, repository)
            download_file(
                hf_hub_url(repository, filename=rel_file),
                os.path.join(ckpt_dir, posixpath.basename(rel_file)),
            )

        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        # The last '_'-separated segment is treated as the game name, the rest as the
        # character name (which may itself contain '_').
        name, _, game_name = pt_name.rpartition('_')

        from gchar.games.dispatch.access import GAME_CHARS
        ch = GAME_CHARS[game_name].get(name) if game_name in GAME_CHARS else None
        # Fall back to the repository id when the character cannot be resolved.
        source = repository if ch is None else ch

        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        save_recommended_tags(source, name=pt_name, workdir=workdir, ds_size=meta["dataset"]['type'])

        logging.info('Drawing ...')
        draw_to_directory(
            workdir, export_dir, step,
            n_repeats=n_repeats, pretrained_model=pretrained_model, clip_skip=clip_skip,
            image_width=image_width, image_height=image_height, infer_steps=infer_steps,
            lora_alpha=lora_alpha, sample_method=sample_method, model_hash=model_hash,
        )