LittleApple_fp16
upload
69a6cef
raw
history blame contribute delete
No virus
4.07 kB
import json
import logging
import os
from typing import Optional
from hbutils.system import TemporaryDirectory
from huggingface_hub import hf_hub_url
from tqdm.auto import tqdm
from .draw import _DEFAULT_INFER_MODEL, draw_with_workdir
from ..dataset import save_recommended_tags
from ..utils import get_hf_fs, download_file
def draw_to_directory(workdir: str, export_dir: str, step: int, n_repeats: int = 2,
                      pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                      image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                      lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                      model_hash: Optional[str] = None) -> None:
    """
    Draw preview images from a training work directory and save them into ``export_dir``.

    For every drawing returned by ``draw_with_workdir`` two files are written:
    ``<name>.png`` (saved with the drawing's ``pnginfo``) and ``<name>_info.txt``
    (containing the drawing's ``preview_info``).

    :param workdir: Work directory handed to ``draw_with_workdir``.
    :param export_dir: Output directory; created if it does not exist.
    :param step: Passed to ``draw_with_workdir`` as ``model_steps``.
    :param n_repeats: Initial ``n_repeats`` for ``draw_with_workdir``. It is increased
        by one after every ``RuntimeError`` before the call is retried.
    :param pretrained_model: Base model used for inference.
    :param clip_skip: Forwarded to ``draw_with_workdir``.
    :param image_width: Forwarded to ``draw_with_workdir`` as ``width``.
    :param image_height: Forwarded to ``draw_with_workdir`` as ``height``.
    :param infer_steps: Forwarded to ``draw_with_workdir``.
    :param lora_alpha: Forwarded to ``draw_with_workdir``.
    :param sample_method: Forwarded to ``draw_with_workdir``.
    :param model_hash: Hash of the base model. When falsy, it is looked up in
        ``KNOWN_MODEL_HASHES`` by ``pretrained_model`` (and may remain ``None``).
    """
    from ..publish.export import KNOWN_MODEL_HASHES
    model_hash = model_hash or KNOWN_MODEL_HASHES.get(pretrained_model)
    os.makedirs(export_dir, exist_ok=True)

    # NOTE(review): the retry is unbounded - a RuntimeError that does not depend on
    # n_repeats will loop forever. The failure is at least logged now so that such
    # a loop is visible; consider adding an upper bound once the cause is confirmed.
    while True:
        try:
            drawings = draw_with_workdir(
                workdir, model_steps=step, n_repeats=n_repeats,
                pretrained_model=pretrained_model,
                width=image_width, height=image_height, infer_steps=infer_steps,
                lora_alpha=lora_alpha, clip_skip=clip_skip, sample_method=sample_method,
                model_hash=model_hash,
            )
        except RuntimeError as err:
            logging.warning(f'Drawing failed with n_repeats={n_repeats} ({err!r}), '
                            f'retrying with n_repeats={n_repeats + 1} ...')
            n_repeats += 1
        else:
            break

    for draw in drawings:
        img_file = os.path.join(export_dir, f'{draw.name}.png')
        draw.image.save(img_file, pnginfo=draw.pnginfo)

        info_file = os.path.join(export_dir, f'{draw.name}_info.txt')
        with open(info_file, 'w', encoding='utf-8') as f:
            print(draw.preview_info, file=f)
def draw_with_repo(repository: str, export_dir: str, step: Optional[int] = None, n_repeats: int = 2,
                   pretrained_model: str = _DEFAULT_INFER_MODEL, clip_skip: int = 2,
                   image_width: int = 512, image_height: int = 768, infer_steps: int = 30,
                   lora_alpha: float = 0.85, sample_method: str = 'DPM++ 2M Karras',
                   model_hash: Optional[str] = None) -> None:
    """
    Download a model from a HuggingFace repository and draw preview images with it.

    The repository's ``meta.json`` is read, the files under ``<step>/raw/`` are
    downloaded into the ``ckpts`` directory of a temporary work directory, the
    recommended tags are regenerated there, and finally ``draw_to_directory`` is
    called to write the images into ``export_dir``.

    :param repository: HuggingFace repository id; it must contain ``meta.json``.
    :param export_dir: Output directory handed to ``draw_to_directory``.
    :param step: Step to use. When falsy, ``meta['best_step']`` is used instead.
    :param n_repeats: Forwarded to ``draw_to_directory``.
    :param pretrained_model: Forwarded to ``draw_to_directory``.
    :param clip_skip: Forwarded to ``draw_to_directory``.
    :param image_width: Forwarded to ``draw_to_directory``.
    :param image_height: Forwarded to ``draw_to_directory``.
    :param infer_steps: Forwarded to ``draw_to_directory``.
    :param lora_alpha: Forwarded to ``draw_to_directory``.
    :param sample_method: Forwarded to ``draw_to_directory``.
    :param model_hash: Forwarded to ``draw_to_directory``.
    :raises ValueError: If ``meta.json`` does not exist in the repository.
    """
    # Paths on the hub are always '/'-separated, so they have to be handled with
    # posixpath; os.path would produce backslashes on Windows and break the URL.
    import posixpath

    from ..publish import find_steps_in_workdir

    hf_fs = get_hf_fs()
    if not hf_fs.exists(f'{repository}/meta.json'):
        raise ValueError(f'Invalid repository or no model found - {repository!r}.')

    logging.info(f'Model repository {repository!r} found.')
    meta = json.loads(hf_fs.read_text(f'{repository}/meta.json'))
    step = step or meta['best_step']
    logging.info(f'Using step {step} ...')

    with TemporaryDirectory() as workdir:
        logging.info('Downloading models ...')
        ckpt_dir = os.path.join(workdir, 'ckpts')
        for remote_file in tqdm(hf_fs.glob(f'{repository}/{step}/raw/*')):
            # Path of the file inside the repository, used as the hub filename.
            rel_file = posixpath.relpath(remote_file, repository)
            local_file = os.path.join(ckpt_dir, posixpath.basename(rel_file))
            os.makedirs(ckpt_dir, exist_ok=True)
            download_file(
                hf_hub_url(repository, filename=rel_file),
                local_file
            )

        logging.info(f'Regenerating tags for {workdir!r} ...')
        pt_name, _ = find_steps_in_workdir(workdir)
        # pt_name is split at its last underscore: '<name>_<game_name>'.
        name, _sep, game_name = pt_name.rpartition('_')

        from gchar.games.dispatch.access import GAME_CHARS
        ch = GAME_CHARS[game_name].get(name) if game_name in GAME_CHARS else None
        # Fall back to the repository id when no known character matches.
        source = repository if ch is None else ch

        logging.info(f'Regenerate tags for {source!r}, on {workdir!r}.')
        save_recommended_tags(source, name=pt_name, workdir=workdir, ds_size=meta["dataset"]['type'])

        # Drawing must happen while the temporary work directory still exists.
        logging.info('Drawing ...')
        draw_to_directory(
            workdir, export_dir, step,
            n_repeats=n_repeats, pretrained_model=pretrained_model, clip_skip=clip_skip,
            image_width=image_width, image_height=image_height, infer_steps=infer_steps,
            lora_alpha=lora_alpha, sample_method=sample_method, model_hash=model_hash,
        )