Spaces:
Runtime error
Runtime error
import datetime | |
import os | |
import pathlib | |
import pytz | |
from typing import Optional | |
from ditk import logging | |
from hbutils.system import TemporaryDirectory | |
from huggingface_hub import CommitOperationAdd, CommitOperationDelete | |
from huggingface_hub.utils import RepositoryNotFoundError | |
from .export import export_workdir, _GITLFS | |
from .steps import find_steps_in_workdir | |
from ..infer.draw import _DEFAULT_INFER_MODEL | |
from ..utils import get_hf_client, get_hf_fs | |
def _tokyo_timestamp() -> str:
    """Return the current time in Asia/Tokyo, formatted for commit messages."""
    # Build an aware datetime directly instead of converting a naive local one.
    return datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y-%m-%d %H:%M:%S %Z')


def _posix_relpath(path: str, start: str) -> str:
    """Return ``path`` relative to ``start``, always using ``/`` as separator.

    Paths inside a Hugging Face repository use forward slashes, whereas
    ``os.path.relpath`` yields ``os.sep`` (a backslash on Windows).
    """
    return os.path.relpath(path, start).replace(os.sep, '/')


def _leaf_paths(paths) -> set:
    """Return the entries of ``paths`` that are not a prefix of another entry.

    The listing may contain directories as well as files; an entry whose
    segments are a prefix of the entry sorted right after it is treated as a
    directory and dropped. The ``'.'`` entry (the listing root) is dropped too.
    """
    ordered = sorted(((path, path.split('/')) for path in paths), key=lambda item: item[1])
    leaves = set()
    for index, (path, segments) in enumerate(ordered):
        if index + 1 < len(ordered) and ordered[index + 1][1][:len(segments)] == segments:
            continue  # the next entry lives underneath this one
        if path != '.':
            leaves.add(path)
    return leaves


def deploy_to_huggingface(workdir: str, repository: Optional[str] = None, revision: str = 'main',
                          n_repeats: int = 3,
                          pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                          image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                          lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                          model_hash: Optional[str] = None, ds_dir: Optional[str] = None):
    """Export a training workdir and publish the result to a Hugging Face model repository.

    The function makes sure the repository exists and carries the LFS rules
    from ``_GITLFS``, exports ``workdir`` into a temporary directory, prepends
    a YAML front matter to the exported ``README.md``, and then pushes a single
    commit that uploads every exported file and deletes every file previously
    in the repository that the export no longer contains.

    :param workdir: Working directory, handed to ``find_steps_in_workdir`` and ``export_workdir``.
    :param repository: Target repository id; defaults to ``AppleHarem/<name>`` where ``<name>``
        is the first value returned by ``find_steps_in_workdir``.
    :param revision: Revision the commits are created on.
    :param n_repeats: Forwarded unchanged to ``export_workdir``.
    :param pretrained_model: Forwarded unchanged to ``export_workdir``.
    :param clip_skip: Forwarded unchanged to ``export_workdir``.
    :param image_width: Forwarded unchanged to ``export_workdir``.
    :param image_height: Forwarded unchanged to ``export_workdir``.
    :param infer_steps: Forwarded unchanged to ``export_workdir``.
    :param lora_alpha: Forwarded unchanged to ``export_workdir``.
    :param sample_method: Forwarded unchanged to ``export_workdir``.
    :param model_hash: Forwarded unchanged to ``export_workdir``.
    :param ds_dir: Forwarded to ``export_workdir`` as ``ds_repo`` (a local or a remote dataset).
    """
    name, _ = find_steps_in_workdir(workdir)
    repository = repository or f'AppleHarem/{name}'

    logging.info(f'Initializing repository {repository!r} ...')
    hf_client = get_hf_client()
    hf_fs = get_hf_fs()
    gitattributes_path = f'{repository}/.gitattributes'

    # NOTE(review): the existence checks and the file listing below do not pass
    # ``revision`` while the commits do -- presumably they inspect the default
    # branch; confirm this is intended when ``revision`` is not 'main'.
    if not hf_fs.exists(gitattributes_path):
        hf_client.create_repo(repo_id=repository, repo_type='model', exist_ok=True)

    if not hf_fs.exists(gitattributes_path) or \
            '*.png filter=lfs diff=lfs merge=lfs -text' not in hf_fs.read_text(gitattributes_path):
        logging.info(f'Preparing for lfs attributes of repository {repository!r}.')
        with TemporaryDirectory() as td:
            _git_attr_file = os.path.join(td, '.gitattributes')
            with open(_git_attr_file, 'w', encoding='utf-8') as f:
                print(_GITLFS, file=f)

            operations = [
                CommitOperationAdd(
                    path_in_repo='.gitattributes',
                    path_or_fileobj=_git_attr_file,
                )
            ]
            commit_message = f'Update {name}\'s .gitattributes, on {_tokyo_timestamp()}'
            logging.info(f'Updating {name}\'s .gitattributes to repository {repository!r} ...')
            hf_client.create_commit(
                repository,
                operations,
                commit_message=commit_message,
                repo_type='model',
                revision=revision,
            )

    with TemporaryDirectory() as td:
        export_workdir(
            workdir, td, n_repeats, pretrained_model,
            clip_skip, image_width, image_height, infer_steps,
            lora_alpha, sample_method, model_hash, ds_repo=ds_dir,  # ds_repo: local or remote dataset
        )

        # A dataset repository with the same id gets linked in the model card.
        try:
            hf_client.repo_info(repo_id=repository, repo_type='dataset')
        except RepositoryNotFoundError:
            has_dataset_repo = False
        else:
            has_dataset_repo = True

        # Prepend the YAML front matter to the exported README.
        readme_file = os.path.join(td, 'README.md')
        readme_text = pathlib.Path(readme_file).read_text(encoding='utf-8')
        front_matter = ['---', 'license: mit']
        if has_dataset_repo:
            front_matter += ['datasets:', f'- {repository}']
        front_matter += ['pipeline_tag: text-to-image', 'tags:', '- art', '---', '']
        with open(readme_file, 'w', encoding='utf-8') as f:
            for line in front_matter:
                print(line, file=f)
            print(readme_text, file=f)

        # Files currently in the repository; whatever is not re-uploaded below gets deleted.
        pre_exist_files = _leaf_paths(
            _posix_relpath(file, repository) for file in hf_fs.glob(f'{repository}/**')
        )

        operations = []
        for directory, _, files in os.walk(td):
            for file in files:
                # ``directory`` already starts with ``td``; do not prefix it again.
                filename = os.path.abspath(os.path.join(directory, file))
                file_in_repo = _posix_relpath(filename, td)
                operations.append(CommitOperationAdd(
                    path_in_repo=file_in_repo,
                    path_or_fileobj=filename,
                ))
                pre_exist_files.discard(file_in_repo)

        logging.info(f'Useless files: {sorted(pre_exist_files)} ...')
        for file in sorted(pre_exist_files):
            operations.append(CommitOperationDelete(path_in_repo=file))

        commit_message = f'Publish {name}\'s lora, on {_tokyo_timestamp()}'
        logging.info(f'Publishing {name}\'s lora to repository {repository!r} ...')
        hf_client.create_commit(
            repository,
            operations,
            commit_message=commit_message,
            repo_type='model',
            revision=revision,
        )