File size: 5,309 Bytes
69a6cef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import datetime
import os
import pathlib
import pytz
from typing import Optional

from ditk import logging
from hbutils.system import TemporaryDirectory
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
from huggingface_hub.utils import RepositoryNotFoundError

from .export import export_workdir, _GITLFS
from .steps import find_steps_in_workdir
from ..infer.draw import _DEFAULT_INFER_MODEL
from ..utils import get_hf_client, get_hf_fs


def _tokyo_timestamp() -> str:
    """Return the current wall-clock time in Asia/Tokyo, formatted for commit messages."""
    tokyo_tz = pytz.timezone('Asia/Tokyo')
    return datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')


def _ensure_lfs_attributes(hf_client, hf_fs, repository: str, revision: str, name: str) -> None:
    """Commit ``_GITLFS`` as ``.gitattributes`` unless the repository's copy already tracks PNGs with LFS.

    :param hf_client: Client returned by ``get_hf_client()``; used for ``create_commit``.
    :param hf_fs: File system returned by ``get_hf_fs()``; used to read the current ``.gitattributes``.
    :param repository: Model repository id.
    :param revision: Branch the fix-up commit is pushed to.
    :param name: Character name, used only in log and commit messages.
    """
    gitattributes_path = f'{repository}/.gitattributes'
    # NOTE(review): an hf_fs path without an ``@revision`` component refers to the default branch, so this
    # check reads that branch even when ``revision`` names another one — confirm that is acceptable.
    if hf_fs.exists(gitattributes_path) and \
            '*.png filter=lfs diff=lfs merge=lfs -text' in hf_fs.read_text(gitattributes_path):
        return

    logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
    with TemporaryDirectory() as td:
        git_attr_file = os.path.join(td, '.gitattributes')
        with open(git_attr_file, 'w', encoding='utf-8') as f:
            print(_GITLFS, file=f)

        operations = [CommitOperationAdd(path_in_repo='.gitattributes', path_or_fileobj=git_attr_file)]
        commit_message = f'Update {name}\'s .gitattributes, on {_tokyo_timestamp()}'
        logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )


def _has_dataset_repo(hf_client, repository: str) -> bool:
    """Return whether a *dataset* repository with the same id as ``repository`` exists on the Hub."""
    try:
        hf_client.repo_info(repo_id=repository, repo_type='dataset')
    except RepositoryNotFoundError:
        return False
    return True


def _prepend_model_card_header(readme_file: str, repository: str, has_dataset_repo: bool) -> None:
    """Rewrite ``readme_file`` in place with a YAML front-matter header placed above its existing text.

    The header declares the MIT license, the ``text-to-image`` pipeline tag and the ``art`` tag, and links
    ``repository`` as a dataset when ``has_dataset_repo`` is true.
    """
    readme_text = pathlib.Path(readme_file).read_text(encoding='utf-8')
    with open(readme_file, 'w', encoding='utf-8') as f:
        print('---', file=f)
        print('license: mit', file=f)
        if has_dataset_repo:
            print('datasets:', file=f)
            print(f'- {repository}', file=f)
        print('pipeline_tag: text-to-image', file=f)
        print('tags:', file=f)
        print('- art', file=f)
        print('---', file=f)
        print('', file=f)
        print(readme_text, file=f)


def _collect_upload_operations(local_dir: str) -> list:
    """Return one ``CommitOperationAdd`` for every file found (recursively) under ``local_dir``.

    ``path_in_repo`` is the file's path relative to ``local_dir`` with ``/`` separators, so it compares
    equal to the paths returned by ``list_repo_files`` even when running on Windows.
    """
    operations = []
    for directory, _, files in os.walk(local_dir):
        for file in files:
            # ``directory`` already starts with ``local_dir``; joining ``local_dir`` in again would
            # double it whenever ``local_dir`` is a relative path.
            filename = os.path.abspath(os.path.join(directory, file))
            file_in_repo = pathlib.PurePath(os.path.relpath(filename, local_dir)).as_posix()
            operations.append(CommitOperationAdd(path_in_repo=file_in_repo, path_or_fileobj=filename))
    return operations


def deploy_to_huggingface(workdir: str, repository: Optional[str] = None, revision: str = 'main',
                          n_repeats: int = 3,
                          pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                          image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                          lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                          model_hash: Optional[str] = None, ds_dir: Optional[str] = None):
    """Export a training workdir and publish it as a Hugging Face model repository.

    Steps: create the repository if needed, make sure its ``.gitattributes`` tracks PNGs with LFS, export
    ``workdir`` into a temporary directory, prepend model-card front matter to the exported ``README.md``,
    then push a single commit to ``revision``. That commit adds every exported file and deletes every file
    on ``revision`` that the export did not produce.

    :param workdir: Training work directory; the character name is taken from ``find_steps_in_workdir``.
    :param repository: Target model repo id. Defaults to ``AppleHarem/<name>``.
    :param revision: Branch to commit to.
    :param n_repeats: Forwarded to ``export_workdir``.
    :param pretrained_model: Forwarded to ``export_workdir``.
    :param clip_skip: Forwarded to ``export_workdir``.
    :param image_width: Forwarded to ``export_workdir``.
    :param image_height: Forwarded to ``export_workdir``.
    :param infer_steps: Forwarded to ``export_workdir``.
    :param lora_alpha: Forwarded to ``export_workdir``.
    :param sample_method: Forwarded to ``export_workdir``.
    :param model_hash: Forwarded to ``export_workdir``.
    :param ds_dir: Forwarded to ``export_workdir`` as ``ds_repo``.
    """
    name, _ = find_steps_in_workdir(workdir)
    repository = repository or f'AppleHarem/{name}'

    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    if not hf_fs.exists(f'{repository}/.gitattributes'):
        hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)

    _ensure_lfs_attributes(hf_client, hf_fs, repository, revision, name)

    with TemporaryDirectory() as td:
        # ds_repo: either a local dataset directory or a remote dataset repository.
        export_workdir(
            workdir, td, n_repeats, pretrained_model,
            clip_skip, image_width, image_height, infer_steps,
            lora_alpha, sample_method, model_hash, ds_repo=ds_dir,
        )
        _prepend_model_card_header(
            os.path.join(td, 'README.md'), repository,
            _has_dataset_repo(hf_client, repository),
        )

        # Files currently on the branch being committed to. ``list_repo_files`` returns files only (no
        # directory entries) as ``/``-separated repo-relative paths at ``revision``. Deletions are
        # therefore computed against the same branch that ``create_commit`` below targets.
        existing_files = set(hf_client.list_repo_files(repo_id=repository, repo_type='model', revision=revision))
        operations = _collect_upload_operations(td)
        stale_files = sorted(existing_files - {op.path_in_repo for op in operations})
        logging.info(f'Useless files: {stale_files} ...')
        operations.extend(CommitOperationDelete(path_in_repo=file) for file in stale_files)

        commit_message = f'Publish {name}\'s lora, on {_tokyo_timestamp()}'
        logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )