# AppleJupyter-test / cyberharem/publish/cyberharem_publish_huggingface.py
# Author: LittleApple_fp16 · commit 69a6cef ("upload") · original size 5.31 kB
import datetime
import os
import pathlib
import pytz
from typing import Optional
from ditk import logging
from hbutils.system import TemporaryDirectory
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
from huggingface_hub.utils import RepositoryNotFoundError
from .export import export_workdir, _GITLFS
from .steps import find_steps_in_workdir
from ..infer.draw import _DEFAULT_INFER_MODEL
from ..utils import get_hf_client, get_hf_fs
def _tokyo_timestamp() -> str:
    """Return the current time in the Asia/Tokyo zone formatted as 'YYYY-mm-dd HH:MM:SS TZ'."""
    tokyo_tz = pytz.timezone('Asia/Tokyo')
    return datetime.datetime.now().astimezone(tokyo_tz).strftime('%Y-%m-%d %H:%M:%S %Z')


def _ensure_png_lfs_attributes(hf_client, hf_fs, repository: str, name: str, revision: str) -> None:
    """Create model repo *repository* if needed and make its .gitattributes track ``*.png`` via LFS.

    When the remote ``.gitattributes`` is missing, or does not contain the ``*.png`` LFS rule,
    it is overwritten with ``_GITLFS`` in a dedicated commit on *revision*.
    """
    attr_path = f'{repository}/.gitattributes'
    if not hf_fs.exists(attr_path):
        hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)

    png_lfs_rule = '*.png filter=lfs diff=lfs merge=lfs -text'
    if hf_fs.exists(attr_path) and png_lfs_rule in hf_fs.read_text(attr_path):
        return  # already configured, nothing to commit

    logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
    with TemporaryDirectory() as td:
        git_attr_file = os.path.join(td, '.gitattributes')
        with open(git_attr_file, 'w', encoding='utf-8') as f:
            print(_GITLFS, file=f)

        operations = [
            CommitOperationAdd(
                path_in_repo='.gitattributes',
                path_or_fileobj=git_attr_file,
            )
        ]
        commit_message = f'Update {name}\'s .gitattributes, on {_tokyo_timestamp()}'
        logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )


def _list_remote_files(hf_fs, repository: str) -> set:
    """Return the repo-relative, '/'-separated paths of the files currently in *repository*.

    The ``**`` glob may also return directory entries (and the repository root itself,
    which relativizes to '.'). After sorting by path segments, an entry whose segments are
    a prefix of the next entry's segments has children, so it is treated as a directory
    and skipped.
    """
    relative_paths = [
        os.path.relpath(path, repository).replace(os.sep, '/')
        for path in hf_fs.glob(f'{repository}/**')
    ]
    entries = sorted(((path, path.split('/')) for path in relative_paths), key=lambda item: item[1])

    files = set()
    for i, (path, segments) in enumerate(entries):
        has_children = i + 1 < len(entries) and entries[i + 1][1][:len(segments)] == segments
        if not has_children and path != '.':
            files.add(path)
    return files


def deploy_to_huggingface(workdir: str, repository: Optional[str] = None, revision: str = 'main',
                          n_repeats: int = 3, pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                          image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                          lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                          model_hash: Optional[str] = None, ds_dir: Optional[str] = None):
    """Export the artifacts of *workdir* and publish them to a Hugging Face model repository.

    Steps:
      1. make sure the model repository exists and its .gitattributes tracks ``*.png`` via LFS;
      2. export *workdir* into a temporary directory with ``export_workdir``;
      3. prepend YAML front matter to the exported README.md, listing a dataset repository
         with the same id when one exists;
      4. push a single commit that adds every exported file and deletes every remote file
         the export did not produce.

    :param workdir: Training work directory; the name used for the repo and commit messages
        comes from ``find_steps_in_workdir(workdir)``.
    :param repository: Target model repo id. Defaults to ``AppleHarem/<name>``.
    :param revision: Git revision (branch) the commits are created on.
    :param n_repeats, pretrained_model, clip_skip, image_width, image_height, infer_steps,
        lora_alpha, sample_method, model_hash: Forwarded unchanged to ``export_workdir``.
    :param ds_dir: Passed to ``export_workdir`` as ``ds_repo`` (a local dataset or a remote dataset).
    """
    name, _ = find_steps_in_workdir(workdir)
    repository = repository or f'AppleHarem/{name}'

    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    _ensure_png_lfs_attributes(hf_client, hf_fs, repository, name, revision)

    with TemporaryDirectory() as td:
        export_workdir(
            workdir, td, n_repeats, pretrained_model,
            clip_skip, image_width, image_height, infer_steps,
            lora_alpha, sample_method, model_hash, ds_repo=ds_dir,  # ds_repo: local or remote dataset
        )

        # Only reference a dataset repo with the same id in the model card if it actually exists.
        try:
            hf_client.repo_info(repo_id=repository, repo_type='dataset')
        except RepositoryNotFoundError:
            has_dataset_repo = False
        else:
            has_dataset_repo = True

        readme_file = os.path.join(td, 'README.md')
        readme_text = pathlib.Path(readme_file).read_text(encoding='utf-8')
        front_matter = ['---', 'license: mit']
        if has_dataset_repo:
            front_matter += ['datasets:', f'- {repository}']
        front_matter += ['pipeline_tag: text-to-image', 'tags:', '- art', '---', '']
        with open(readme_file, 'w', encoding='utf-8') as f:
            for line in front_matter:
                print(line, file=f)
            print(readme_text, file=f)

        pre_exist_files = _list_remote_files(hf_fs, repository)
        operations = []
        for directory, _, files in os.walk(td):
            for file in files:
                # os.walk(td) already yields directories rooted at td; don't prepend td again.
                filename = os.path.abspath(os.path.join(directory, file))
                # Hub paths use '/' separators regardless of the local OS.
                file_in_repo = os.path.relpath(filename, td).replace(os.sep, '/')
                operations.append(CommitOperationAdd(
                    path_in_repo=file_in_repo,
                    path_or_fileobj=filename,
                ))
                pre_exist_files.discard(file_in_repo)

        # Anything left exists remotely but was not produced by this export: delete it.
        logging.info(f'Useless files: {sorted(pre_exist_files)} ...')
        operations.extend(CommitOperationDelete(path_in_repo=file) for file in sorted(pre_exist_files))

        commit_message = f'Publish {name}\'s lora, on {_tokyo_timestamp()}'
        logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )