# PTI/models/StyleCLIP/optimization/run_optimization.py
import argparse
import math
import os
import pickle
import torch
import torchvision
from torch import optim
from tqdm import tqdm
from StyleCLIP.criteria.clip_loss import CLIPLoss
from StyleCLIP.models.stylegan2.model import Generator
import clip
from StyleCLIP.utils import ensure_checkpoint_exists
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Cosine learning-rate schedule with linear warmup and tail ramp-down.

    Args:
        t: normalized optimization progress in [0, 1].
        initial_lr: peak learning rate.
        rampdown: fraction of the run over which the rate decays to 0 at the end.
        rampup: fraction of the run used for linear warmup from 0.

    Returns:
        The scaled learning rate for progress ``t``.
    """
    warmup = min(1, t / rampup)
    decay = min(1, (1 - t) / rampdown)
    # Half-cosine easing of the ramp-down factor.
    decay = 0.5 - 0.5 * math.cos(decay * math.pi)
    scale = decay * warmup
    return initial_lr * scale
def main(args, use_old_G=True):
    """Optimize a StyleGAN latent so the rendered image matches a text prompt.

    Args:
        args: parsed CLI namespace (description, ckpt, lr, step, mode,
            l2_lambda, latent_path, save_intermediate_image_every,
            results_dir, ...).
        use_old_G: when True, load the pre-trained FFHQ pickle; when False,
            load a PTI-tuned generator, which additionally requires
            ``args.run_id`` and ``args.image_name``. Defaults to True so the
            CLI entry point (which defines neither flag) works.

    Returns:
        In "edit" mode, a batch tensor of [original image, edited image];
        otherwise the generated image tensor.
    """
    ensure_checkpoint_exists(args.ckpt)
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    os.makedirs(args.results_dir, exist_ok=True)

    if use_old_G:
        old_generator_path = '/disk2/danielroich/Sandbox/pretrained_models/ffhq.pkl'
        with open(old_generator_path, 'rb') as f:
            # NOTE(review): pickle.load executes arbitrary code — only load
            # trusted checkpoint files.
            G = pickle.load(f)['G_ema'].cuda().eval()
    else:
        # Built lazily inside the branch: args.run_id / args.image_name are
        # not defined by the CLI parser, so evaluating this f-string
        # unconditionally (as before) raised AttributeError even when the
        # old generator was requested.
        new_generator_path = (
            f'/disk2/danielroich/Sandbox/stylegan2_ada_pytorch/'
            f'checkpoints/model_{args.run_id}_{args.image_name}.pt')
        with open(new_generator_path, 'rb') as f:
            G = torch.load(f).cuda().eval()

    if args.latent_path:
        latent_code_init = torch.load(args.latent_path).cuda()
    else:
        # Previously this init existed only for mode == "edit";
        # "free_generation" without --latent_path hit a NameError below.
        # Both modes now start from a random z mapped into W space.
        latent_code_init_not_trunc = torch.randn(1, 512).cuda()
        with torch.no_grad():
            latent_code_init = G.mapping(latent_code_init_not_trunc, None)

    latent = latent_code_init.detach().clone()
    latent.requires_grad = True

    clip_loss = CLIPLoss(args)
    optimizer = optim.Adam([latent], lr=args.lr)

    pbar = tqdm(range(args.step))
    for i in pbar:
        # Cosine schedule over normalized progress.
        t = i / args.step
        optimizer.param_groups[0]["lr"] = get_lr(t, args.lr)

        img_gen = G.synthesis(latent, noise_mode='const')
        c_loss = clip_loss(img_gen, text_inputs)

        if args.mode == "edit":
            # Penalize drifting away from the starting latent so the edit
            # stays close to the original identity.
            l2_loss = ((latent_code_init - latent) ** 2).sum()
            loss = c_loss + args.l2_lambda * l2_loss
        else:
            loss = c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(f"loss: {loss.item():.4f};")

        if args.save_intermediate_image_every > 0 and i % args.save_intermediate_image_every == 0:
            with torch.no_grad():
                img_gen = G.synthesis(latent, noise_mode='const')
            torchvision.utils.save_image(
                img_gen,
                f"/disk2/danielroich/Sandbox/StyleCLIP/results/inference_results/{str(i).zfill(5)}.png",
                normalize=True, range=(-1, 1))

    if args.mode == "edit":
        with torch.no_grad():
            img_orig = G.synthesis(latent_code_init, noise_mode='const')
        final_result = torch.cat([img_orig, img_gen])
    else:
        final_result = img_gen

    return final_result
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--description", type=str, default="a person with purple hair",
help="the text that guides the editing/generation")
parser.add_argument("--ckpt", type=str, default="../pretrained_models/stylegan2-ffhq-config-f.pt",
help="pretrained StyleGAN2 weights")
parser.add_argument("--stylegan_size", type=int, default=1024, help="StyleGAN resolution")
parser.add_argument("--lr_rampup", type=float, default=0.05)
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--step", type=int, default=300, help="number of optimization steps")
parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"],
help="choose between edit an image an generate a free one")
parser.add_argument("--l2_lambda", type=float, default=0.008,
help="weight of the latent distance (used for editing only)")
parser.add_argument("--latent_path", type=str, default=None,
help="starts the optimization from the given latent code if provided. Otherwose, starts from"
"the mean latent in a free generation, and from a random one in editing. "
"Expects a .pt format")
parser.add_argument("--truncation", type=float, default=0.7,
help="used only for the initial latent vector, and only when a latent code path is"
"not provided")
parser.add_argument("--save_intermediate_image_every", type=int, default=20,
help="if > 0 then saves intermidate results during the optimization")
parser.add_argument("--results_dir", type=str, default="results")
args = parser.parse_args()
result_image = main(args)
torchvision.utils.save_image(result_image.detach().cpu(), os.path.join(args.results_dir, "final_result.jpg"),
normalize=True, scale_each=True, range=(-1, 1))