File size: 3,551 Bytes
8483373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import sys
import os 
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import torchvision
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from lora_w2w import LoRAw2w
from utils import load_models, inference, save_model_w2w, save_model_for_diffusers
from inversion import invert
import argparse


def _accepted_kwargs(func, **candidates):
    """Return the subset of *candidates* that *func* can receive by keyword."""
    import inspect  # local import: leaves the file's top-level import block untouched

    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidates)
    by_keyword = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    return {
        name: value
        for name, value in candidates.items()
        if name in params and params[name].kind in by_keyword
    }


def main(argv=None):
    """Command-line entry point: invert an image folder into a LoRAw2w network and save it.

    Steps, in order: parse the command line, load the base models with
    ``load_models``, load the mean / std / V tensors, build a ``LoRAw2w``
    network whose coefficient tensor starts at zero, run ``invert`` with the
    prompt ``"sks person"``, and write the result with either
    ``save_model_for_diffusers`` (``--diffusers_format``) or ``save_model_w2w``.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the original
            behavior; passing a list allows calling ``main`` programmatically.

    Raises:
        SystemExit: Raised by argparse for ``--help`` (code 0) or for invalid
            arguments, including a non-positive ``--dim`` (code 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--mean_path", default="/files/mean.pt", type=str, help="Path to file with parameter means")
    parser.add_argument("--std_path", default="/files/std.pt", type=str, help="Path to file with parameter standard deviations.")
    parser.add_argument("--v_path", default="/files/V.pt", type=str, help="Path to V orthogonal projection/unprojection matrix.")
    parser.add_argument("--dim_path", default="/files/weight_dimensions.pt", type=str, help="Path to file with dimensions of LoRA layers. Used for saving in Diffusers pipeline format.")
    parser.add_argument("--imfolder", default="/inversion/images/real_image/real/", type=str, help="Path to folder containing image.")
    parser.add_argument("--mask_path", default=None, type=str, help="Path to mask file.")
    parser.add_argument("--epochs", default=400, type=int)
    parser.add_argument("--lr", default=1e-1, type=float)
    parser.add_argument("--weight_decay", default=1e-10, type=float)
    parser.add_argument("--dim", default=10000, type=int, help="Number of principal component coefficients to optimize.")
    parser.add_argument("--diffusers_format", default=False, action="store_true", help="Whether to save in mode that can be loaded in Diffusers pipeline")
    parser.add_argument("--save_name", default="/files/inversion1.pt", type=str, help="Output path + filename.")

    args = parser.parse_args(argv)

    # Fail fast, before any model is loaded: the coefficient tensor below has
    # shape (1, dim), so a non-positive value cannot describe a usable network.
    if args.dim < 1:
        parser.error("--dim must be a positive integer")

    device = args.device
    dim = args.dim

    ### load models
    unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

    ### load files
    # Deserialize on CPU first so checkpoints saved on a CUDA device still load
    # when that device is unavailable (e.g. --device cpu); casting to bfloat16
    # before the transfer also means only the half-size tensor is moved.
    mean = torch.load(args.mean_path, map_location="cpu").bfloat16().to(device)
    std = torch.load(args.std_path, map_location="cpu").bfloat16().to(device)
    v = torch.load(args.v_path, map_location="cpu").bfloat16().to(device)
    # Loaded exactly as before: its contents are only handed to
    # save_model_for_diffusers, whose expectations are not visible in this file.
    weight_dimensions = torch.load(args.dim_path)

    ### initialize network
    # Start from all-zero coefficients and keep only the first `dim` columns of V.
    proj = torch.zeros(1, dim).bfloat16().to(device)
    network = LoRAw2w(
        proj,
        mean,
        std,
        v[:, :dim],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    ### run inversion
    # --lr / --weight_decay used to be parsed and then dropped. `invert` lives
    # in another module, so forward them only if its signature accepts them;
    # otherwise the call is identical to the previous one.
    tuning = _accepted_kwargs(invert, lr=args.lr, weight_decay=args.weight_decay)
    network = invert(
        network=network,
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt="sks person",
        noise_scheduler=noise_scheduler,
        epochs=args.epochs,
        image_path=args.imfolder,
        mask_path=args.mask_path,
        device=device,
        **tuning,
    )

    ### save model
    if args.diffusers_format:
        save_model_for_diffusers(network, std, mean, v, weight_dimensions,
                                 path=args.save_name)
    else:
        save_model_w2w(network, path=args.save_name)



# Run the inversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()