taim-gan / app.py
Dmmc's picture
replace some examples
e910cd6
raw
history blame
No virus
3.9 kB
import numpy as np # this should come first to mitigate mlk-service bug
from src.models.utils import get_image_arr, load_model
from src.data import TAIMGANTokenizer
from torchvision import transforms
from src.config import config_dict
from pathlib import Path
from enum import IntEnum, auto
from PIL import Image
import gradio as gr
import torch
from src.models.modules import (
VGGEncoder,
InceptionEncoder,
TextEncoder,
Generator
)
##########
# PARAMS #
##########
# Image geometry fed to the encoders.
IMG_CHANS = 3  # RGB channels for image
IMG_HW = 256  # images are resized to IMG_HW x IMG_HW

# Text-encoder sizes. HIDDEN_DIM is the LSTM width in one direction; the
# embedding length C is twice that (presumably both directions concatenated).
HIDDEN_DIM = 128
C = HIDDEN_DIM * 2

# Generator hyper-parameters taken from the project configuration.
Ng, cond_dim, z_dim = (config_dict[key] for key in ("Ng", "condition_dim", "noise_dim"))
###############
# LOAD MODELS #
###############
# One entry per selectable dataset. Each "dir" contains captions.pickle (read
# by the tokenizer) and is the directory handed to load_model as output_dir.
models = {
    "COCO": {"dir": "weights/coco"},
    "Bird": {"dir": "weights/bird"},
    "UTKFace": {"dir": "weights/utkface"},
}


def _build_models(weights_dir: str) -> dict:
    """Create the tokenizer and networks for one dataset directory and load its weights.

    All networks are put in eval mode and loaded onto the CPU.

    :param str weights_dir: Directory holding captions.pickle and the saved weights
    :return: Mapping with keys "tokenizer", "generator", "lstm", "vgg", "inception"
    """
    tokenizer = TAIMGANTokenizer(captions_path=f"{weights_dir}/captions.pickle")
    generator = Generator(Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim).eval()
    text_encoder = TextEncoder(
        vocab_size=len(tokenizer.word_to_ix), emb_dim=C, hidden_dim=HIDDEN_DIM
    ).eval()
    # VGGEncoder is not passed to load_model below, so its weights come only from
    # its constructor. NOTE(review): confirm it loads pretrained weights itself.
    vgg = VGGEncoder().eval()
    inception = InceptionEncoder(D=C).eval()
    load_model(
        generator=generator,
        discriminator=None,
        image_encoder=inception,
        text_encoder=text_encoder,
        output_dir=Path(weights_dir),
        device=torch.device("cpu"),
    )
    return {
        "tokenizer": tokenizer,
        "generator": generator,
        "lstm": text_encoder,
        "vgg": vgg,
        "inception": inception,
    }


for entry in models.values():
    entry.update(_build_models(entry["dir"]))
def change_image_with_text(image: Image.Image, text: str, model_name: str) -> Image.Image:
    """
    Generate a new image from an input image, guided by a text caption.

    The image is converted to RGB, resized to IMG_HW x IMG_HW and normalized
    from [0, 1] to [-1, 1]. The caption is tokenized and encoded by the selected
    dataset's text encoder. The generator then combines the text embeddings,
    the image features and random noise to produce the output. Nothing is
    written to disk; the generated image is returned.

    :param PIL.Image.Image image: Input image (as delivered by ``gr.Image(type="pil")``)
    :param str text: Desired caption
    :param str model_name: Key of ``models`` choosing the tokenizer, encoders and generator
    :return: Generated image
    :rtype: PIL.Image.Image
    """
    bundle = models[model_name]
    tokenizer = bundle["tokenizer"]
    G = bundle["generator"]
    lstm = bundle["lstm"]
    inception = bundle["inception"]
    vgg = bundle["vgg"]

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((IMG_HW, IMG_HW)),
        transforms.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5)
        )
    ])

    # Inference only: without no_grad, autograd records a graph (and keeps all
    # intermediate activations alive) on every request.
    with torch.no_grad():
        # Uniform noise in [0, 1) with a batch dimension of one.
        # NOTE(review): confirm the generator was trained with torch.rand, not torch.randn.
        noise = torch.rand(z_dim).unsqueeze(0)
        # Tokenize the caption (batch of one); mask is True at padding positions.
        tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
        mask = (tokens == tokenizer.pad_token_id)
        word_embs, sent_embs = lstm(tokens)
        # Normalize uses 3-channel statistics, so RGBA / grayscale uploads
        # must be converted to RGB first.
        image_tensor = preprocess(image.convert("RGB")).unsqueeze(0)
        # Visual features of the input image.
        vgg_features = vgg(image_tensor)
        local_features, global_features = inception(image_tensor)
        # Generate the new image from the old one.
        fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
                             local_features, vgg_features, mask)
    # get_image_arr presumably denormalizes the batch to image arrays;
    # index 0 is the single generated image.
    return Image.fromarray(get_image_arr(fake_image)[0])
##########
# GRADIO #
##########
# Inputs map positionally onto change_image_with_text(image, text, model_name).
demo = gr.Interface(
    fn=change_image_with_text,
    inputs=[
        gr.Image(type="pil"),
        "text",
        # gr.Dropdown replaces the deprecated gr.inputs.Dropdown. Defaulting to
        # the first model avoids submitting None, which would raise KeyError
        # when change_image_with_text looks up models[model_name].
        gr.Dropdown(choices=list(models), value=next(iter(models))),
    ],
    outputs=gr.Image(type="pil"),
    # Each example row is [image path, caption, model name].
    examples=[
        ["src/data/stubs/bird.jpg", "black bird with blue wings", "Bird"],
        ["src/data/stubs/lady.jpg", "lady with blue eyes", "UTKFace"],
        ["src/data/stubs/bird.jpg", "white bird with black wings", "Bird"]
    ]
)
demo.launch(debug=True)