# Author: Joseph Catrambone
# First import. Add ControlNetSD21 Laion Face (full, pruned, and safetensors). Add README and samples.
# Add surrounding tools for example use.  (Commit: 568dc2c)
import json | |
import os | |
import sys | |
from dataclasses import dataclass, field | |
from glob import glob | |
from typing import Mapping | |
from PIL import Image | |
from tqdm import tqdm | |
from laion_face_common import generate_annotation | |
@dataclass
class RunProgress:
    """Checkpointable progress state for a dataset-generation run.

    Every list holds full input-image filenames.  `pending` is the remaining
    work queue; the other lists record the final disposition of each processed
    file so an interrupted run can be resumed without repeating work.

    NOTE: the `@dataclass` decorator is required — without it the
    `field(default_factory=list)` values stay raw `Field` objects, instances
    have an empty `__dict__`, and checkpoint save/restore silently does nothing.
    """
    pending: list = field(default_factory=list)            # Not yet processed.
    success: list = field(default_factory=list)            # Written to the output dataset.
    skipped_size: list = field(default_factory=list)       # Outside the min/max size bounds.
    skipped_nsfw: list = field(default_factory=list)       # Metadata flagged as NSFW.
    skipped_noface: list = field(default_factory=list)     # No faces detected at all.
    skipped_smallface: list = field(default_factory=list)  # All faces below the pixel threshold.
def main(
        status_filename: str,
        prompt_filename: str,
        input_glob: str,
        output_directory: str,
        annotated_output_directory: str = "",
        min_image_size: int = 384,
        max_image_size: int = 32766,
        min_face_size_pixels: int = 64,
        prompt_mapping: dict = None,  # If present, maps a filename to a text prompt.
):
    """Build a ControlNet training set of (face annotation, image, prompt) triples.

    Iterates over every file matched by `input_glob`, skipping images that are
    NSFW-flagged in their LAION sidecar metadata, whose short side is below
    `min_image_size` or long side above `max_image_size`, or that contain no
    (sufficiently large) faces.  For each kept image the rendered face
    annotation is saved under `output_directory` (and, optionally, a debug
    overlay under `annotated_output_directory`) and one JSONL record
    {"source", "target", "prompt"} is appended to `prompt_filename`.

    Progress is checkpointed to `status_filename` every 100 images and again
    on completion, so a cancelled run resumes where it left off.

    :param status_filename: JSON checkpoint file; loaded if it already exists.
    :param prompt_filename: JSONL label file; appended to when resuming.
    :param input_glob: Glob matching the input ("target") images.
    :param output_directory: Where rendered annotations ("source") are written.
    :param annotated_output_directory: Optional directory for debug overlays.
    :param min_image_size: Reject images whose shorter side is below this.
    :param max_image_size: Reject images whose longer side exceeds this.
    :param min_face_size_pixels: Faces smaller than this are filtered out.
    :param prompt_mapping: Optional filename -> caption fallback, used when an
        image has no "caption" in its sidecar metadata.
    """
    status = RunProgress()
    if os.path.exists(status_filename):
        print("Continuing from checkpoint.")
        # Restore a saved state:
        with open(status_filename, 'rt') as fin:
            status_temp = json.load(fin)
        for k in status.__dict__.keys():
            status.__setattr__(k, status_temp[k])
        # Output label file -- append so records from the earlier run survive:
        pout = open(prompt_filename, 'at')
    else:
        print("Starting run.")
        status = RunProgress()
        status.pending = list(glob(input_glob))
        # Output label file:
        pout = open(prompt_filename, 'wt')
        with open(status_filename, 'wt') as fout:
            json.dump(status.__dict__, fout)
    print(f"{len(status.pending)} images remaining")

    # If we don't have a preexisting set of labels (like for ImageNet/MSCOCO), just null-fill the mapping.
    # We will try on a per-image basis to see if there's a metadata .json.
    if prompt_mapping is None:
        prompt_mapping = dict()

    step = 0
    with tqdm(total=len(status.pending)) as pbar:
        while len(status.pending) > 0:
            full_filename = status.pending.pop()
            pbar.update(1)
            step += 1
            if step % 100 == 0:
                # Checkpoint save:
                with open(status_filename, 'wt') as fout:
                    json.dump(status.__dict__, fout)
            _fpath, fname = os.path.split(full_filename)

            # Make our output filenames.
            # We used to do this here so we could check if a file existed before writing, then skip it, but since we
            # have a 'status' that we cache and update, we no longer have to do this check.
            annotation_filename = ""
            if annotated_output_directory:
                annotation_filename = os.path.join(annotated_output_directory, fname)
            output_filename = os.path.join(output_directory, fname)

            # The LAION dataset has accompanying .json files with each image.
            partial_filename, extension = os.path.splitext(full_filename)
            candidate_json_fullpath = partial_filename + ".json"
            image_metadata = {}
            if os.path.exists(candidate_json_fullpath):
                try:
                    # Best-effort: a malformed sidecar .json should not kill the run.
                    with open(candidate_json_fullpath, 'rt') as fin:
                        image_metadata = json.load(fin)
                except Exception as e:
                    print(e)
            if "NSFW" in image_metadata:
                nsfw_marker = image_metadata.get("NSFW")  # This can be "", None, or other weird things.
                if nsfw_marker is not None and nsfw_marker.lower() != "unlikely":
                    # Skip NSFW images.
                    status.skipped_nsfw.append(full_filename)
                    continue

            # Try to get a prompt/caption from the metadata or the prompt mapping.
            image_prompt = image_metadata.get("caption", prompt_mapping.get(fname, ""))

            # Load image:
            img = Image.open(full_filename).convert("RGB")
            img_width = img.size[0]
            img_height = img.size[1]
            img_size = min(img.size[0], img.size[1])
            if img_size < min_image_size or max(img_width, img_height) > max_image_size:
                status.skipped_size.append(full_filename)
                continue

            # We re-initialize the detector every time because it has a habit of triggering weird race conditions.
            empty, annotated, faces_before_filtering, faces_after_filtering = generate_annotation(
                img,
                max_faces=5,
                min_face_size_pixels=min_face_size_pixels,
                return_annotation_data=True
            )
            if faces_before_filtering == 0:
                # Skip images with no faces.
                status.skipped_noface.append(full_filename)
                continue
            if faces_after_filtering == 0:
                # Skip images with no faces large enough
                status.skipped_smallface.append(full_filename)
                continue
            Image.fromarray(empty).save(output_filename)
            if annotation_filename:
                Image.fromarray(annotated).save(annotation_filename)

            # See https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md for the training file format.
            # prompt.json
            # a JSONL file with {"source": "source/0.jpg", "target": "target/0.jpg", "prompt": "..."}.
            # a source/xxxxx.jpg or source/xxxx.png file for each of the inputs.
            # a target/xxxxx.jpg for each of the outputs.
            pout.write(json.dumps({
                "source": os.path.join(output_directory, fname),
                "target": full_filename,
                "prompt": image_prompt,
            }) + "\n")
            # Flush per record so a crash loses at most the in-flight line.
            pout.flush()
            status.success.append(full_filename)

    # We do save every 100 iterations, but it's good to save on completion, too.
    with open(status_filename, 'wt') as fout:
        json.dump(status.__dict__, fout)
    pout.close()

    print("Done!")
    print(f"{len(status.success)} images added to dataset.")
    print(f"{len(status.skipped_size)} images rejected for size.")
    print(f"{len(status.skipped_smallface)} images rejected for having faces too small.")
    print(f"{len(status.skipped_noface)} images rejected for not having faces.")
    print(f"{len(status.skipped_nsfw)} images rejected for NSFW.")
if __name__ == "__main__":
    # Three positional arguments are required (prompt file, input glob, output
    # directory), so argv needs at least 4 entries counting the script name.
    # The original check of `>= 3` let a two-argument invocation reach
    # sys.argv[3] and die with IndexError instead of printing usage.
    if len(sys.argv) >= 4 and "-h" not in sys.argv:
        prompt_jsonl = sys.argv[1]
        in_glob = sys.argv[2]  # Should probably be in a directory called "target/*.jpg".
        output_dir = sys.argv[3]  # Should probably be a directory called "source".
        annotation_dir = ""
        if len(sys.argv) > 4:
            annotation_dir = sys.argv[4]
        main("generate_face_poses_checkpoint.json", prompt_jsonl, in_glob, output_dir, annotation_dir)
    else:
        print(f"""Usage:
python {sys.argv[0]} prompt.jsonl target/*.jpg source/ [annotated/]
source and target are slightly confusing in this context. We are writing the image names to prompt.jsonl, so
the naming system has to be consistent with what ControlNet expects. In ControlNet, the source is the input and
target is the output. We are generating source images from targets in this application, so the second argument
should be a folder full of images. The third argument should be 'source', where the images should be places.
Optionally, an 'annotated' directory can be provided. Augmented images will be placed here.
A checkpoint file named 'generate_face_poses_checkpoint.json' will be created in the place where the script is
run. If a run is cancelled, it can be resumed from this checkpoint.
If invoking the script from bash, do not forget to enclose globs with quotes. Example usage:
`python ./tool_generate_face_poses.py ./face_prompt.jsonl "/home/josephcatrambone/training_data/data-mscoco/images/train2017/*" /home/josephcatrambone/training_data/data-mscoco/images/source_2017/`
""")