# TR0N / app.py
# Last change: "Update app.py" by NoelVouitsis (commit d940272)
import os
import re
import gradio as gr
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torchvision.transforms import transforms as T
import clip
from tr0n.config import parse_args
from tr0n.modules.models.model_stylegan import Model
from tr0n.modules.models.loss import AugCosineSimLatent
from tr0n.modules.optimizers.sgld import SGLD
from bad_words import bad_words
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hosted translator checkpoints, one per conditioning mode shown in the UI.
model_modes = {
    "text": {"checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-text.pth"},
    "image": {"checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-image.pth"},
}
# Presumably silences the HuggingFace `tokenizers` parallelism/fork warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Set config params: start from the demo defaults and override the generator
# and translator-mixture settings. vars() hands back the object's attribute
# dict, so updating it mutates ``config`` itself (assumes parse_args returns an
# argparse.Namespace-like object; confirm in tr0n.config).
config = parse_args(is_demo=True)
config_vars = vars(config)
config_vars.update({
    "stylegan_gen": "sg2-ffhq-1024",
    "with_gmm": True,
    "num_mixtures": 10,
})
# Build the TR0N model on the selected device and put it in inference mode.
# (The third positional argument is passed as None; its meaning is defined in
# tr0n.modules.models.model_stylegan.)
model = Model(config, device, None)
model.to(device)
model.eval()
# Freeze the translator: the optimization loops only update latent tensors,
# and the translator's weights are swapped wholesale via load_state_dict.
model.translator.requires_grad_(False)
# Objective comparing the target and predicted CLIP latents during refinement.
loss = AugCosineSimLatent()
# Preprocessing for uploaded faces before model.clip.encode_image: bicubic
# resize + 224x224 center crop, then normalization with the same mean/std
# constants used by OpenAI CLIP's own preprocessing.
transforms_image = T.Compose(
    [
        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
# Fetch both translator checkpoints once at startup (mapped to CPU) so that
# switching modes later is just a load_state_dict call.
checkpoint_text = torch.hub.load_state_dict_from_url(
    model_modes["text"]["checkpoint"],
    map_location="cpu",
)
translator_state_dict_text = checkpoint_text["translator_state_dict"]

checkpoint_image = torch.hub.load_state_dict_from_url(
    model_modes["image"]["checkpoint"],
    map_location="cpu",
)
translator_state_dict_image = checkpoint_image["translator_state_dict"]

# Default: the text translator is active when the app starts.
model.translator.load_state_dict(translator_state_dict_text)
# Custom page CSS passed to gr.Blocks: links are rendered black with no
# underline, and elements with elem_id="image-gen" (the output images) are
# fixed at 256x256 px and centered horizontally.
css = """
a {
display: inline-block;
color: black !important;
text-decoration: none !important;
}
#image-gen {
height: 256px;
width: 256px;
margin-left: auto;
margin-right: auto;
}
"""
def _slerp(val, low, high):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm*high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
return res
def model_mode_text_select():
    """Load the text-conditioned translator weights into the shared model.

    Registered as the ``select`` handler of the text-to-face tab.
    """
    weights = translator_state_dict_text
    model.translator.load_state_dict(weights)
def model_mode_image_select():
    """Load the image-conditioned translator weights into the shared model.

    Registered as the ``select`` handler of the face-to-face tab.
    """
    weights = translator_state_dict_image
    model.translator.load_state_dict(weights)
def text_to_face_generate(text):
    """Generate a face image for a text prompt.

    The text translator proposes mixture logits and mixture means over
    StyleGAN latents for the tokenized prompt. The means are then refined with
    SGLD and the logits with Adam for 100 steps, minimizing ``loss`` between
    the prompt's target CLIP latent and the CLIP latent predicted for the
    generated image (``times_augment_pred_image=50``). The final mixture
    latent is rendered once more without gradients.

    Args:
        text: Prompt from the Gradio textbox.

    Returns:
        The generated image as a PIL image.

    Raises:
        gr.Error: If the prompt is empty/blank or contains a word from
            ``bad_words``.
    """
    if not text or not text.strip():
        raise gr.Error("You need to provide a prompt.")
    for word in bad_words:
        # Whole-word, case-insensitive match; re.escape makes each entry match
        # literally even if it contains regex metacharacters.
        if re.search(rf"\b{re.escape(word)}\b", text, flags=re.IGNORECASE):
            raise gr.Error("Unsafe content found. Please try again with a different prompt.")

    # The translator weights are module-level state swapped by the tabs'
    # select handlers; load the text translator explicitly so this request
    # does not depend on which tab (in any session) was selected last.
    model.translator.load_state_dict(translator_state_dict_text)

    text_tok = clip.tokenize([text], truncate=True).to(device)
    # initialize optimization from the translator's output
    with torch.no_grad():
        target_clip_latent, w_mixture_logits, w_means = model(x=text_tok, x_type='text', return_after_translator=True, no_sample=True)
    # Repeat each mixture logit across the latent dimension so every latent
    # coordinate gets its own optimizable mixture weight.
    pi = w_mixture_logits.unsqueeze(-1).repeat(1, 1, w_means.shape[-1])  # 1 x num_mixtures x w_dim
    w = w_means  # 1 x num_mixtures x w_dim
    w.requires_grad = True
    pi.requires_grad = True
    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)
    optimizer_pi = Adam((pi,), lr=5e-3)

    # optimization
    # NOTE(review): backward() also fills .grad on any other model parameters
    # that require grad (only the translator is frozen in this file); confirm
    # the generator/CLIP weights are frozen inside Model.
    for _ in range(100):
        # Latent = softmax(pi)-weighted sum of the mixture means.
        w_prime = (F.softmax(pi, dim=1) * w).sum(dim=1)
        _, _, pred_clip_latent, _, _ = model(x=w_prime, x_type='gan_latent', times_augment_pred_image=50)
        step_loss = loss(target_clip_latent, pred_clip_latent)
        step_loss.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        torch.nn.utils.clip_grad_norm_((pi,), 1.)
        optimizer_w.step()
        optimizer_pi.step()
        optimizer_w.zero_grad()
        optimizer_pi.zero_grad()

    # generate final image
    with torch.no_grad():
        w_prime = (F.softmax(pi, dim=1) * w).sum(dim=1)
        _, _, _, _, pred_image_raw = model(x=w_prime, x_type='gan_latent')
        # Map the generator output from [-1, 1] to [0, 1]. Clamp because the
        # raw output is not guaranteed to stay in range, and out-of-range
        # floats would overflow when ToPILImage casts them to uint8.
        pred_image = ((pred_image_raw[0] + 1.) / 2.).clamp(0., 1.).cpu()
    return T.ToPILImage()(pred_image)
def face_to_face_interpolate(image1, image2, interp_lambda=0.5):
    """Generate a face whose CLIP image latent interpolates two input faces.

    Both images are encoded with ``model.clip.encode_image``. The two
    embeddings are spherically interpolated with ``interp_lambda``, and the
    image translator maps the result to an initial GAN latent. That latent is
    refined for 100 SGLD steps against ``loss``, then rendered once more
    without gradients.

    Args:
        image1: First face as a PIL image, or ``None`` if not provided.
        image2: Second face as a PIL image, or ``None`` if not provided.
        interp_lambda: Coefficient passed to ``_slerp``; 0 targets image1's
            embedding and 1 targets image2's.

    Returns:
        The generated image as a PIL image.

    Raises:
        gr.Error: If either image is missing.
    """
    if image1 is None or image2 is None:
        raise gr.Error("You need to provide two images as input.")

    # The translator weights are module-level state swapped by the tabs'
    # select handlers; load the image translator explicitly so this request
    # does not depend on which tab (in any session) was selected last.
    model.translator.load_state_dict(translator_state_dict_image)

    # Force 3 channels: RGBA or grayscale input would not match the
    # 3-channel Normalize in transforms_image.
    image1_pt = transforms_image(image1.convert("RGB")).to(device)
    image2_pt = transforms_image(image2.convert("RGB")).to(device)

    # initialize optimization from the translator's output
    with torch.no_grad():
        images_pt = torch.stack([image1_pt, image2_pt])
        target_clip_latents = model.clip.encode_image(images_pt).detach().float()
        target_clip_latent = _slerp(interp_lambda, target_clip_latents[0].unsqueeze(0), target_clip_latents[1].unsqueeze(0))
        _, _, w = model(x=target_clip_latent, x_type='clip_latent', return_after_translator=True)
    w.requires_grad = True
    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)

    # optimization
    # NOTE(review): backward() also fills .grad on any other model parameters
    # that require grad; confirm the generator/CLIP weights are frozen in Model.
    for _ in range(100):
        _, _, pred_clip_latent, _, _ = model(x=w, x_type='gan_latent', times_augment_pred_image=50)
        step_loss = loss(target_clip_latent, pred_clip_latent)
        step_loss.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        optimizer_w.step()
        optimizer_w.zero_grad()

    # generate final image
    with torch.no_grad():
        _, _, _, _, pred_image_raw = model(x=w, x_type='gan_latent')
        # Map [-1, 1] -> [0, 1] and clamp so out-of-range generator values do
        # not overflow when ToPILImage casts to uint8.
        pred_image = ((pred_image_raw[0] + 1.) / 2.).clamp(0., 1.).cpu()
    return T.ToPILImage()(pred_image)
# Prompts offered as clickable examples in the text-to-face tab.
examples_text = [
    "Muhammad Ali",
    "Tinker Bell",
    "A man with glasses, long black hair with sideburns and a goatee",
    "A child with blue eyes and straight brown hair in the sunshine",
    "A hairdresser",
    "A young boy with glasses and an angry face",
    "Denzel Washington",
    "A portrait of Angela Merkel",
    "President Emmanuel Macron",
    "President Xi Jinping",
]

# Image pairs for the face-to-face tab; each entry is a [image1, image2] list
# matching the two image inputs (the slider input is not included).
examples_image = [
    list(pair)
    for pair in (
        ("./examples/example_1_1.jpg", "./examples/example_1_2.jpg"),
        ("./examples/example_2_1.jpg", "./examples/example_2_2.jpg"),
        ("./examples/example_3_1.jpg", "./examples/example_3_2.jpg"),
        ("./examples/example_4_1.jpg", "./examples/example_4_2.jpg"),
    )
]
# Build the two-tab Gradio UI and register its event handlers.
# NOTE(review): SOURCE lost its indentation; the nesting below is reconstructed
# from syntax. In particular, the slider/button/output are placed after (not
# inside) the image Row; confirm against the upstream layout.
with gr.Blocks(css=css) as demo:
    # Header: title, affiliation link, paper/code badges, short description.
    gr.Markdown("<h1><center>TR0N Face Generation Demo</center></h1>")
    gr.Markdown("<h3><center><a href='https://layer6.ai/'>by Layer 6 AI</a></center></h3>")
    gr.Markdown("""<p align='middle'>
<a href='https://arxiv.org/abs/2304.13742'><img src='https://img.shields.io/badge/arXiv-2304.13742-b31b1b.svg' /></a>
<a href='https://github.com/layer6ai-labs/tr0n'><img src='https://badgen.net/badge/icon/github?icon=github&label' /></a>
</p>""")
    gr.Markdown("We introduce TR0N, a simple and efficient method to add any type of conditioning to pre-trained generative models. For this demo, we add two types of conditioning to a StyleGAN2 model pre-trained on images of human faces. First, we add text-conditioning to turn StyleGAN2 into a text-to-face model. Second, we add image semantic conditioning to StyleGAN2 to enable face-to-face interpolation. For more details and results on many other generative models, please refer to our paper linked above.")
    # Tab 1: text prompt -> generated face (handled by text_to_face_generate).
    with gr.Tab("Text-to-face generation") as text_to_face_generation_demo:
        text_to_face_generation_input = gr.Textbox(label="Enter your prompt", placeholder="e.g. A man with a beard and glasses", max_lines=1)
        text_to_face_generation_button = gr.Button("Generate")
        # NOTE(review): elem_id "image-gen" is also used by the interpolation
        # output below, giving duplicate DOM ids; the shared CSS rule still applies.
        text_to_face_generation_output = gr.Image(label="Generated image", elem_id="image-gen")
        text_to_face_generation_examples = gr.Examples(examples=examples_text, fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
    # Tab 2: two faces + coefficient -> interpolated face (face_to_face_interpolate).
    with gr.Tab("Face-to-face interpolation") as face_to_face_interpolation_demo:
        gr.Markdown("We note that interpolations are not expected to recover the given images, even when the coefficient is 0 or 1.")
        # The two input faces, side by side; type="pil" delivers PIL images to the handler.
        with gr.Row():
            face_to_face_interpolation_input1 = gr.Image(label="Image 1", type="pil")
            face_to_face_interpolation_input2 = gr.Image(label="Image 2", type="pil")
        face_to_face_interpolation_lambda = gr.Slider(label="Interpolation coefficient", minimum=0, maximum=1, value=0.5, step=0.01)
        face_to_face_interpolation_button = gr.Button("Interpolate")
        face_to_face_interpolation_output = gr.Image(label="Interpolated image", elem_id="image-gen")
        face_to_face_interpolation_examples = gr.Examples(examples=examples_image, fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)
    # Event wiring. Selecting a tab loads that mode's translator weights into
    # the shared model; prompt submit and both buttons run the generators.
    text_to_face_generation_demo.select(fn=model_mode_text_select)
    text_to_face_generation_input.submit(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
    text_to_face_generation_button.click(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
    face_to_face_interpolation_demo.select(fn=model_mode_image_select)
    face_to_face_interpolation_button.click(fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)
# Enable Gradio's request queue (Blocks.queue returns the Blocks instance),
# then start the web server.
demo.queue().launch()