from adaface.adaface_wrapper import AdaFaceWrapper
import torch
from PIL import Image
import numpy as np
import os, argparse, glob, re

def save_images(images, num_images_per_row, subject_name, prompt, noise_level, save_dir="samples-ada"):
    """Tile `images` into a grid with `num_images_per_row` columns and save it under `save_dir`."""
    if num_images_per_row > len(images):
        num_images_per_row = len(images)

    os.makedirs(save_dir, exist_ok=True)

    num_columns = int(np.ceil(len(images) / num_images_per_row))
    # Paste the images into a num_images_per_row x num_columns grid; each cell is 512x512.
    grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
    for i, image in enumerate(images):
        image = image.resize((512, 512))
        grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))

    prompt_sig = prompt.replace(" ", "_").replace(",", "_")
    grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}.png")
    # If the file already exists, append an increasing count to the filename instead of overwriting.
    grid_count = 2
    while os.path.exists(grid_filepath):
        grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.png")
        grid_count += 1

    grid_image.save(grid_filepath)
    print(f"Saved to {grid_filepath}")

def seed_everything(seed):
    """Seed NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PL_GLOBAL_SEED"] = str(seed)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5', 
                        help="Type of checkpoints to use (default: SD 1.5)")
    parser.add_argument("--embman_ckpt", type=str, required=True,
                        help="Path to the checkpoint of the embedding manager")
    parser.add_argument("--subject", type=str, required=True)
    parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
    parser.add_argument("--out_image_count",     type=int, default=4,  help="Number of images to generate")
    parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
    parser.add_argument("--noise", dest='noise_level', type=float, default=0)
    parser.add_argument("--randface", action="store_true")
    parser.add_argument("--scale", dest='guidance_scale', type=float, default=4, 
                        help="Guidance scale for the diffusion model")
    parser.add_argument("--id_cfg_scale", type=float, default=1, 
                        help="CFG scale when generating the identity embeddings")
    
    parser.add_argument("--subject_string", 
                        type=str, default="z",
                        help="Subject placeholder string used in prompts to denote the concept.")
    parser.add_argument("--num_vectors", type=int, default=16,
                        help="Number of vectors used to represent the subject.")
    parser.add_argument("--num_images_per_row", type=int, default=4,
                        help="Number of images to display in a row in the output grid image.")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of DDIM inference steps")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--seed", type=int, default=42, 
                        help="the seed (for reproducible sampling). Set to -1 to disable.")
    args = parser.parse_args()
    
    return args
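
# Example invocation (illustrative; the checkpoint and subject paths are hypothetical
# placeholders, and the script filename may differ):
#   python test_adaface.py --embman_ckpt ckpts/embman.ckpt --subject subjects/alice \
#       --prompt "a woman z in superman costume" --noise 0.04 --scale 4 --seed 42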

if __name__ == "__main__":
    args = parse_args()
    if args.seed != -1:
        seed_everything(args.seed)

    # Allow passing a bare GPU index: "--device 0" becomes "cuda:0".
    if re.match(r"^\d+$", args.device):
        args.device = f"cuda:{args.device}"
    print(f"Using device {args.device}")

    adaface = AdaFaceWrapper("text2img", args.base_model_path, args.embman_ckpt, args.device, 
                             args.subject_string, args.num_vectors, args.num_inference_steps)

    if not args.randface:
        image_folder = args.subject.rstrip("/")

        if os.path.isfile(image_folder):
            # A single image file: use its parent folder's name as the subject name.
            subject_name = os.path.basename(os.path.dirname(image_folder))
            image_paths = [image_folder]

        else:
            subject_name = os.path.basename(image_folder)
            image_types = ["*.jpg", "*.png", "*.jpeg"]
            alltype_image_paths = []
            for image_type in image_types:
                # glob returns the full path.
                image_paths = glob.glob(os.path.join(image_folder, image_type))
                if len(image_paths) > 0:
                    alltype_image_paths.extend(image_paths)

            # Filter out images of "*_mask.png"
            alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]

            # Keep at most args.example_image_count image paths (-1 keeps all).
            if args.example_image_count > 0:
                image_paths = alltype_image_paths[:args.example_image_count]
            else:
                image_paths = alltype_image_paths
    else:
        subject_name = None
        image_paths = None
        image_folder = None
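    # Example (illustrative paths): passing --subject subjects/alice with images 1.jpg and 2.jpg
    # yields subject_name == "alice" and image_paths == ["subjects/alice/1.jpg", "subjects/alice/2.jpg"].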

    subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
    rand_face_embs = torch.randn(1, 512)

    pre_face_embs = rand_face_embs if args.randface else None
    # Latent noise for the sampler: 4x64x64 latents correspond to 512x512 images in SD 1.5.
    noise = torch.randn(args.out_image_count, 4, 64, 64).to(args.device)
    # args.noise_level: the *relative* std of the noise added to the face embeddings.
    # A noise level of 0.08 could change gender, but 0.06 is usually safe.
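    # For intuition only, a relative-std noise injection might look like this
    # sketch (hypothetical; AdaFaceWrapper's actual implementation may differ):
    #   face_embs += noise_level * face_embs.std() * torch.randn_like(face_embs)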
    # adaface_subj_embs is not used directly below; the call matters for its side effect of
    # updating the text encoder with the subject embeddings (update_text_encoder=True).
    adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, image_folder, pre_face_embs, args.randface, 
                                                            out_id_embs_scale=args.id_cfg_scale, noise_level=args.noise_level, 
                                                            update_text_encoder=True)    
    images = adaface(noise, args.prompt, args.guidance_scale, args.out_image_count, verbose=True)
    # The saved filename encodes the guidance scale (rather than the full prompt) and the noise level.
    save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.noise_level)