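# File: AppleJupyter-test/cyberharem/publish/cyberharem_publish_huggingface.py
# Last commit: 69a6cef ("upload") by LittleApple_fp16 -- 5.31 kB, Hub scan: no virus.
# (Header recovered from the Hugging Face file view; the raw / history / blame /
#  contribute / delete entries were page links, not file content.)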
import datetime
import os
import pathlib
import pytz
from typing import Optional
from ditk import logging
from hbutils.system import TemporaryDirectory
from huggingface_hub import CommitOperationAdd, CommitOperationDelete
from huggingface_hub.utils import RepositoryNotFoundError
from .export import export_workdir, _GITLFS
from .steps import find_steps_in_workdir
from ..infer.draw import _DEFAULT_INFER_MODEL
from ..utils import get_hf_client, get_hf_fs
def _tokyo_time_str() -> str:
    """Return the current Asia/Tokyo wall-clock time formatted for commit messages."""
    return datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y-%m-%d %H:%M:%S %Z')


def _ensure_lfs_attributes(hf_client, hf_fs, repository: str, name: str, revision: str) -> None:
    """Make sure the model repository exists and its ``.gitattributes`` contains the PNG LFS rule.

    The repository is created when it has no ``.gitattributes`` yet. If ``.gitattributes`` is
    missing, or lacks the ``*.png filter=lfs ...`` line, ``_GITLFS`` (presumably the LFS
    attribute rules defined in ``.export`` -- verify) is committed as ``.gitattributes``
    on ``revision``.
    """
    attr_path = f'{repository}/.gitattributes'
    if not hf_fs.exists(attr_path):
        hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)

    # NOTE(review): an hf_fs path without an ``@revision`` suffix presumably reads the default
    # branch, while the commit below targets ``revision`` -- confirm when revision != 'main'.
    if hf_fs.exists(attr_path) and \
            '*.png filter=lfs diff=lfs merge=lfs -text' in hf_fs.read_text(attr_path):
        return

    logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
    with TemporaryDirectory() as td:
        git_attr_file = os.path.join(td, '.gitattributes')
        with open(git_attr_file, 'w', encoding='utf-8') as f:
            print(_GITLFS, file=f)

        operations = [
            CommitOperationAdd(
                path_in_repo='.gitattributes',
                path_or_fileobj=git_attr_file,
            )
        ]
        commit_message = f'Update {name}\'s .gitattributes, on {_tokyo_time_str()}'
        logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )


def _prepend_model_card_header(readme_path: str, dataset_repo: Optional[str]) -> None:
    """Rewrite the README at ``readme_path`` with a YAML front-matter header prepended.

    The header declares ``license: mit``, ``pipeline_tag: text-to-image`` and the ``art`` tag.
    When ``dataset_repo`` is not None, it is also listed under ``datasets``.
    """
    readme_text = pathlib.Path(readme_path).read_text(encoding='utf-8')
    header = ['---', 'license: mit']
    if dataset_repo is not None:
        header += ['datasets:', f'- {dataset_repo}']
    header += ['pipeline_tag: text-to-image', 'tags:', '- art', '---', '']
    with open(readme_path, 'w', encoding='utf-8') as f:
        for line in header:
            print(line, file=f)
        print(readme_text, file=f)


def _build_publish_operations(local_dir: str, existing_files) -> list:
    """Build commit operations that make the repository mirror ``local_dir``.

    Every file under ``local_dir`` becomes a ``CommitOperationAdd``. Every entry of
    ``existing_files`` (repo-relative, '/'-separated paths) that is not re-uploaded
    becomes a ``CommitOperationDelete``, in sorted order.
    """
    operations = []
    uploaded = set()
    for directory, _, files in os.walk(local_dir):
        for file in files:
            # os.walk already yields directories rooted at local_dir, so no extra join is needed.
            local_path = os.path.abspath(os.path.join(directory, file))
            # Hub paths are '/'-separated regardless of the local OS separator.
            path_in_repo = pathlib.Path(os.path.relpath(local_path, local_dir)).as_posix()
            operations.append(CommitOperationAdd(
                path_in_repo=path_in_repo,
                path_or_fileobj=local_path,
            ))
            uploaded.add(path_in_repo)

    stale_files = sorted(set(existing_files) - uploaded)
    logging.info(f'Useless files: {stale_files} ...')
    operations.extend(CommitOperationDelete(path_in_repo=file) for file in stale_files)
    return operations


def deploy_to_huggingface(workdir: str, repository: Optional[str] = None, revision: str = 'main',
                          n_repeats: int = 3,
                          pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                          image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                          lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                          model_hash: Optional[str] = None, ds_dir: Optional[str] = None):
    """Export a training workdir and publish it as a Hugging Face model repository.

    :param workdir: Training work directory; the character name is taken from
        ``find_steps_in_workdir(workdir)``.
    :param repository: Target model repo id. Defaults to ``AppleHarem/{name}``.
    :param revision: Branch/revision that both commits are made to.
    :param n_repeats, pretrained_model, clip_skip, image_width, image_height, infer_steps,
        lora_alpha, sample_method, model_hash: Forwarded unchanged to ``export_workdir``.
    :param ds_dir: Forwarded to ``export_workdir`` as ``ds_repo``; this is either a local
        dataset directory or a remote dataset repository.

    Side effects:

    * The model repo may be created.
    * ``.gitattributes`` may be committed.
    * The exported files are committed to ``revision``.
    * In that same commit, files on ``revision`` that are absent from the export are deleted.
    """
    name, _ = find_steps_in_workdir(workdir)
    repository = repository or f'AppleHarem/{name}'

    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    _ensure_lfs_attributes(hf_client, hf_fs, repository, name, revision)

    with TemporaryDirectory() as td:
        export_workdir(
            workdir, td, n_repeats, pretrained_model,
            clip_skip, image_width, image_height, infer_steps,
            lora_alpha, sample_method, model_hash,
            ds_repo=ds_dir,  # ds_repo: local dataset directory or remote dataset repository
        )

        # Reference the dataset in the model card only if a dataset repo with the same id exists.
        try:
            hf_client.repo_info(repo_id=repository, repo_type='dataset')
        except RepositoryNotFoundError:
            dataset_repo = None
        else:
            dataset_repo = repository
        _prepend_model_card_header(os.path.join(td, 'README.md'), dataset_repo)

        # List files on the branch being committed to (not the default branch), so that
        # deletions only reference files that actually exist on ``revision``.
        existing_files = hf_client.list_repo_files(
            repo_id=repository, repo_type='model', revision=revision,
        )
        operations = _build_publish_operations(td, existing_files)

        commit_message = f'Publish {name}\'s lora, on {_tokyo_time_str()}'
        logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )