LittleApple_fp16
upload
69a6cef
raw
history blame
1.2 kB
import glob
import os.path
from typing import List, Tuple
def find_steps_in_workdir(workdir: str) -> Tuple[str, List[int]]:
ckpts_dir = os.path.join(workdir, 'ckpts')
pt_steps = []
pt_name = None
for pt in glob.glob(os.path.join(ckpts_dir, '*-*.pt')):
name = os.path.basename(pt)
segs = os.path.splitext(name)[0].split('-')
if pt_name is None:
pt_name = '-'.join(segs[:-1])
else:
if pt_name != '-'.join(segs[:-1]):
raise NameError(f'Name not match, {pt_name!r} vs {"-".join(segs[:-1])!r}.')
pt_steps.append(int(segs[-1]))
unet_steps = []
for unet in glob.glob(os.path.join(ckpts_dir, 'unet-*.safetensors')):
name = os.path.basename(unet)
segs = os.path.splitext(name)[0].split('-')
unet_steps.append(int(segs[-1]))
text_encoder_steps = []
for text_encoder in glob.glob(os.path.join(ckpts_dir, 'text_encoder-*.safetensors')):
name = os.path.basename(text_encoder)
segs = os.path.splitext(name)[0].split('-')
text_encoder_steps.append(int(segs[-1]))
return pt_name, sorted(set(pt_steps) & set(unet_steps) & set(text_encoder_steps))