from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from src.config import config_dict
from src.data import TAIMGANTokenizer
from src.models.modules import (
    Generator,
    InceptionEncoder,
    TextEncoder,
    VGGEncoder,
)
from src.models.utils import get_image_arr, load_model

IMG_CHANS = 3
IMG_HW = 256
HIDDEN_DIM = 128
# C is the shared embedding width; 2 * HIDDEN_DIM presumably reflects a
# bidirectional text encoder whose forward and backward hidden states are
# concatenated.
C = 2 * HIDDEN_DIM

Ng = config_dict["Ng"]
cond_dim = config_dict["condition_dim"]
z_dim = config_dict["noise_dim"]
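
# Ng, cond_dim and z_dim come from src.config and are assumed to match the
# settings the released checkpoints were trained with; mismatched sizes
# would cause load_model to fail when loading the state dicts.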

models = {
    "COCO": {"dir": "weights/coco"},
    "Bird": {"dir": "weights/bird"},
    "UTKFace": {"dir": "weights/utkface"},
}
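
# Each weights directory is assumed to hold the tokenizer's captions.pickle
# plus the trained checkpoints that load_model reads from output_dir; the
# exact file layout is defined by load_model itself.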

for model_name in models:
    models[model_name]["tokenizer"] = TAIMGANTokenizer(
        captions_path=f"{models[model_name]['dir']}/captions.pickle"
    )
    vocab_size = len(models[model_name]["tokenizer"].word_to_ix)

    models[model_name]["generator"] = Generator(
        Ng=Ng, D=C, conditioning_dim=cond_dim, noise_dim=z_dim
    ).eval()
    models[model_name]["lstm"] = TextEncoder(
        vocab_size=vocab_size, emb_dim=C, hidden_dim=HIDDEN_DIM
    ).eval()
    models[model_name]["vgg"] = VGGEncoder().eval()
    models[model_name]["inception"] = InceptionEncoder(D=C).eval()

    load_model(
        generator=models[model_name]["generator"],
        discriminator=None,
        image_encoder=models[model_name]["inception"],
        text_encoder=models[model_name]["lstm"],
        output_dir=Path(models[model_name]["dir"]),
        device=torch.device("cpu"),
    )
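
# All networks are restored on the CPU and kept in eval() mode so the demo
# can run without a GPU; the discriminator is skipped (passed as None)
# because it is not needed at inference time.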


def change_image_with_text(
    image: Image.Image, text: str, model_name: str
) -> Image.Image:
    """
    Create a new image from the original one, modified to match the text.

    :param Image.Image image: Original image
    :param str text: Desired caption
    :param str model_name: Name of the model to use ("COCO", "Bird" or "UTKFace")
    :return: The modified image
    """
    tokenizer = models[model_name]["tokenizer"]
    G = models[model_name]["generator"]
    lstm = models[model_name]["lstm"]
    inception = models[model_name]["inception"]
    vgg = models[model_name]["vgg"]

    # No gradients are needed at inference time.
    with torch.no_grad():
        noise = torch.rand(z_dim).unsqueeze(0)

        # Tokenize the caption and mask out the padding positions.
        tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
        mask = (tokens == tokenizer.pad_token_id)
        word_embs, sent_embs = lstm(tokens)

        # Resize and normalize the input image to the format the encoders expect.
        image = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((IMG_HW, IMG_HW)),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])(image).unsqueeze(0)

        # Extract perceptual (VGG) and local/global (Inception) image features.
        vgg_features = vgg(image)
        local_features, global_features = inception(image)

        fake_image, _, _ = G(noise, sent_embs, word_embs, global_features,
                             local_features, vgg_features, mask)

    return Image.fromarray(get_image_arr(fake_image)[0])
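
# The function can also be exercised without the UI; a minimal sketch, using
# one of the stub images shipped with the repo (the output filename here is
# arbitrary):
#
#     out = change_image_with_text(
#         Image.open("src/data/stubs/bird.jpg"),
#         "white bird with black wings",
#         "Bird",
#     )
#     out.save("bird_modified.jpg")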


demo = gr.Interface(
    fn=change_image_with_text,
    inputs=[
        gr.Image(type="pil"),
        "text",
        gr.Dropdown(choices=list(models.keys())),
    ],
    outputs=gr.Image(type="pil"),
    examples=[
        ["src/data/stubs/car.jpeg", "black car on the green road", "COCO"],
        ["src/data/stubs/lady.jpg", "lady with blue eyes", "UTKFace"],
        ["src/data/stubs/bird.jpg", "white bird with black wings", "Bird"],
    ],
)
demo.launch(debug=True)
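
# When running on a remote host, demo.launch(share=True) would additionally
# create a temporary public URL for the interface.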