# AutoStyleGAN / app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from types import SimpleNamespace
import random
from torchvision.utils import save_image
import gradio as gr
import numpy as np
import io
import tempfile
import math
# Make sure the required helper functions are defined (if they are not already)
def resize(img, size):
return F.interpolate(img, size=size, mode='bilinear', align_corners=False)
def denormalize(x):
    return (x + 1) / 2  # maps values from [-1, 1] to [0, 1]
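# Sanity check of the normalize/denormalize round trip (pure arithmetic, no new API):
#   transforms.Normalize(mean=0.5, std=0.5) maps a pixel p in [0, 1] to 2*p - 1 in [-1, 1];
#   denormalize maps it back: ((2*p - 1) + 1) / 2 == p.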
# Model class definitions (Generator, StyleEncoder, MappingNetwork, ResBlk, AdaIN, AdainResBlk)
class ResBlk(nn.Module):
def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
super().__init__()
self.normalize = normalize
self.downsample = downsample
self.main = nn.Sequential(
nn.Conv2d(dim_in, dim_out, 3, 1, 1),
nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity(),
nn.ReLU(inplace=True),
nn.Conv2d(dim_out, dim_out, 3, 1, 1),
nn.InstanceNorm2d(dim_out, affine=True) if normalize else nn.Identity()
)
self.downsample_layer = nn.AvgPool2d(2) if downsample else nn.Identity()
self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
def forward(self, x):
out = self.main(x)
out = self.downsample_layer(out)
skip = self.skip(x)
skip = self.downsample_layer(skip)
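        # Scaling the branch sum by 1/sqrt(2) (below) keeps the output variance roughly
        # constant when the residual and skip paths are approximately independent.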
return (out + skip) / math.sqrt(2)
class AdaIN(nn.Module):
def __init__(self, num_features, style_dim):
super(AdaIN, self).__init__()
self.fc = nn.Linear(style_dim, num_features * 2)
def forward(self, x, s):
h = self.fc(s)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
gamma = gamma.unsqueeze(2).unsqueeze(3)
beta = beta.unsqueeze(2).unsqueeze(3)
return (1 + gamma) * x + beta
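# AdaIN shape sketch (hypothetical sizes): for x of shape (B, C, H, W) and a style code
# s of shape (B, style_dim), fc(s) is (B, 2C); gamma and beta are each reshaped to
# (B, C, 1, 1) and broadcast over H and W, so the output is again (B, C, H, W).
#   ada = AdaIN(num_features=256, style_dim=64)
#   y = ada(torch.randn(4, 256, 32, 32), torch.randn(4, 64))  # y.shape == (4, 256, 32, 32)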
class AdainResBlk(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
super().__init__()
self.upsample = upsample
self.w_hpf = w_hpf
self.norm1 = AdaIN(dim_in, style_dim)
self.norm2 = AdaIN(dim_out, style_dim)
self.actv = nn.LeakyReLU(0.2)
self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
if dim_in != dim_out:
self.skip = nn.Conv2d(dim_in, dim_out, 1, 1, 0)
else:
self.skip = nn.Identity()
def forward(self, x, s):
x_orig = x
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x_orig = F.interpolate(x_orig, scale_factor=2, mode='nearest')
h = self.norm1(x, s)
h = self.actv(h)
h = self.conv1(h)
h = self.norm2(h, s)
h = self.actv(h)
h = self.conv2(h)
skip = self.skip(x_orig)
out = (h + skip) / math.sqrt(2)
return out
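# Note: w_hpf is stored but never used in this inference port; in StarGAN v2, from
# which this block appears to be adapted, it weights a high-pass skip connection.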
class Generator(nn.Module):
def __init__(self, img_size=256, style_dim=64, max_conv_dim=512):
super().__init__()
dim_in = 64
blocks = []
blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
repeat_num = int(np.log2(img_size)) - 4
for _ in range(repeat_num):
dim_out = min(dim_in*2, max_conv_dim)
blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
dim_in = dim_out
self.encode = nn.Sequential(*blocks)
self.decode = nn.ModuleList()
for _ in range(repeat_num):
dim_out = dim_in // 2
self.decode += [AdainResBlk(dim_in, dim_out, style_dim, upsample=True)]
dim_in = dim_out
self.to_rgb = nn.Sequential(
nn.InstanceNorm2d(dim_in, affine=True),
nn.ReLU(inplace=True),
nn.Conv2d(dim_in, 3, 1, 1, 0)
)
def forward(self, x, s):
x = self.encode(x)
for block in self.decode:
x = block(x, s)
out = self.to_rgb(x)
return out
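# Generator shape sketch (assuming img_size=128, so repeat_num = 3):
#   x: (B, 3, 128, 128) --encode--> (B, 512, 16, 16)
#   --decode (3 AdainResBlks, each upsampling and conditioned on s)--> (B, 64, 128, 128)
#   --to_rgb--> (B, 3, 128, 128)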
class MappingNetwork(nn.Module):
def __init__(self, latent_dim=16, style_dim=64, num_domains=2, hidden_dim=512):
super(MappingNetwork, self).__init__()
layers = [
nn.Linear(latent_dim, hidden_dim),
nn.ReLU()
]
for _ in range(3):
layers += [
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
]
self.shared = nn.Sequential(*layers)
self.unshared = nn.ModuleList()
for _ in range(num_domains):
self.unshared.append(nn.Linear(hidden_dim, style_dim))
def forward(self, z, y):
h = self.shared(z)
out = []
for layer in self.unshared:
out.append(layer(h))
out = torch.stack(out, dim=1)
idx = torch.arange(y.size(0)).to(y.device)
s = out[idx, y]
return s
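# MappingNetwork sketch (hypothetical sizes): a latent z of shape (B, latent_dim) is
# mapped to one style code per domain, stacked as (B, num_domains, style_dim); the
# domain labels y of shape (B,) then select one code per sample:
#   z = torch.randn(4, 16); y = torch.randint(0, 2, (4,))
#   s = MappingNetwork()(z, y)  # s.shape == (4, 64)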
class StyleEncoder(nn.Module):
def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
super().__init__()
dim_in = 64
blocks = []
blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
repeat_num = int(np.log2(img_size)) - 2
for _ in range(repeat_num):
dim_out = min(dim_in*2, max_conv_dim)
blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
dim_in = dim_out
blocks += [nn.LeakyReLU(0.2)]
self.shared = nn.Sequential(*blocks)
self.unshared = nn.ModuleList()
for _ in range(num_domains):
self.unshared += [nn.Linear(dim_in, style_dim)]
def forward(self, x, y):
h = self.shared(x)
h = F.adaptive_avg_pool2d(h, (1,1))
h = h.view(h.size(0), -1)
out = []
for layer in self.unshared:
out += [layer(h)]
out = torch.stack(out, dim=1)
idx = torch.arange(y.size(0)).to(y.device)
s = out[idx, y]
return s
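# StyleEncoder sketch: like MappingNetwork, but the style code comes from an image.
# Shared conv features are average-pooled to (B, dim_in), each per-domain linear head
# yields (B, style_dim), the heads are stacked to (B, num_domains, style_dim), and the
# labels y select the final (B, style_dim) code.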
# Dataset class for loading images
class ImageFolder(Dataset):
def __init__(self, root, transform, mode, which='source'):
self.transform = transform
self.paths = []
        # Treat each subdirectory of `root` as one domain; labels follow sorted order
        domains = sorted(d for d in os.listdir(root)
                         if os.path.isdir(os.path.join(root, d)))
        for label, domain in enumerate(domains):
            files = os.listdir(os.path.join(root, domain))
            self.paths += [(os.path.join(root, domain, f), label) for f in files]
if mode == 'train' and which == 'reference':
random.shuffle(self.paths)
def __getitem__(self, index):
path, label = self.paths[index]
img = Image.open(path).convert('RGB')
return self.transform(img), label
def __len__(self):
return len(self.paths)
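# Expected directory layout (hypothetical names):
#   root/
#     domain_a/  img0.jpg, img1.jpg, ...
#     domain_b/  ...
# Each subdirectory becomes one domain; labels follow the sorted directory order.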
# Functions that build the data loaders
def get_transform(img_size, mode='train', prob=0.5):
transform = []
transform.append(transforms.Resize((img_size, img_size)))
if mode == 'train':
transform.append(transforms.RandomHorizontalFlip())
transform.append(transforms.RandomApply([
transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0))
], p=prob))
transform.append(transforms.ToTensor())
transform.append(transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]))
return transforms.Compose(transform)
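# Note: get_train_loader below builds its transform inline rather than calling
# get_transform; get_transform is kept as the reusable variant with random crops.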
def get_train_loader(root, which='source', img_size=256, batch_size=8, prob=0.5, num_workers=4):
transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.RandomHorizontalFlip(p=prob),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
    # Pass `which` through so reference datasets get shuffled by ImageFolder
    dataset = ImageFolder(root=root, transform=transform, mode='train', which=which)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
return loader
def get_test_loader(root, img_size=256, batch_size=8, shuffle=False, num_workers=4, mode='reference'):
transform = transforms.Compose([
transforms.Resize((img_size, img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = ImageFolder(root=root, transform=transform, mode=mode)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=False)
return loader
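# Loader usage sketch (hypothetical paths; the Gradio app below does not use these):
#   train_loader = get_train_loader('data/train', which='reference', img_size=128, batch_size=8)
#   test_loader = get_test_loader('data/val', img_size=128, shuffle=False)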
# Solver class (adapted for inference)
class Solver(object):
def __init__(self, args):
self.args = args
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Instantiate the models (the mapping network M is kept so its weights can be
        # loaded from the checkpoint, even though this UI only uses reference styles)
        self.G = Generator(args.img_size, args.style_dim).to(self.device)
        self.M = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains).to(self.device)
        self.S = StyleEncoder(args.img_size, args.style_dim, args.num_domains).to(self.device)
def load_checkpoint(self, checkpoint_path):
try:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.G.load_state_dict(checkpoint['generator'])
self.M.load_state_dict(checkpoint['mapping_network'])
self.S.load_state_dict(checkpoint['style_encoder'])
print(f"Checkpoint cargado exitosamente desde {checkpoint_path}.")
except FileNotFoundError:
print(f"Error: No se encontr贸 el checkpoint en {checkpoint_path}.")
raise FileNotFoundError(f"No se encontr贸 el checkpoint en {checkpoint_path}")
except Exception as e:
print(f"Error al cargar el checkpoint: {e}.")
raise Exception(f"Error al cargar el checkpoint: {e}")
def transfer_style(self, source_image, reference_image):
        # Make sure the models are in evaluation mode
self.G.eval()
self.S.eval()
with torch.no_grad():
            # Preprocess the input images
transform = transforms.Compose([
transforms.Resize((self.args.img_size, self.args.img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
            # Convert to PIL images before applying the transform
source_image_pil = Image.fromarray(source_image)
reference_image_pil = Image.fromarray(reference_image)
source_image = transform(source_image_pil).unsqueeze(0).to(self.device)
reference_image = transform(reference_image_pil).unsqueeze(0).to(self.device)
            # Encode the style of the reference image (note: domain 0 is hardcoded)
            s_ref = self.S(reference_image, torch.tensor([0]).to(self.device))
            # Generate the image with the transferred style
            generated_image = self.G(source_image, s_ref)
            # Denormalize the image for display in the interface
            generated_image = denormalize(generated_image.squeeze(0)).cpu()
            # Convert to a uint8 HWC NumPy array in the valid [0, 255] range
            return (generated_image * 255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
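# Solver usage sketch (checkpoint path taken from gradio_interface below):
#   solver = Solver(args)
#   solver.load_checkpoint("iter/27000_nets_ema.ckpt")
#   out = solver.transfer_style(src_np, ref_np)  # uint8 HWC arrays in and out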
# Main inference entry point
def main(source_image, reference_image, checkpoint_path, args):
    if source_image is None or reference_image is None:
        raise gr.Error("Please provide both images (source and reference).")
    # Build the solver
    solver = Solver(args)
    # Load the checkpoint
    solver.load_checkpoint(checkpoint_path)
    # Run the style transfer
    generated_image = solver.transfer_style(source_image, reference_image)
return generated_image
def gradio_interface():
    # Arguments (adjusted for inference)
args = SimpleNamespace(
img_size=128,
num_domains=3,
latent_dim=16,
style_dim=64,
num_workers=0,
seed=8365,
)
    # Path to the checkpoint
checkpoint_path = "iter/27000_nets_ema.ckpt"
    # Build the Gradio interface
inputs = [
gr.Image(label="Source Image (Car to change style)"),
gr.Image(label="Reference Image (Style to transfer)"),
]
outputs = gr.Image(label="Generated Image (Car with transferred style)")
title = "AutoStyleGAN: Car Style Transfer"
description = "Transfer the style of one car to another. Upload a source car image and a reference car image."
iface = gr.Interface(
fn=lambda source_image, reference_image: main(source_image, reference_image, checkpoint_path, args),
inputs=inputs,
outputs=outputs,
title=title,
description=description,
)
return iface
if __name__ == '__main__':
iface = gradio_interface()
iface.launch(share=True)