Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from PIL import Image | |
import os | |
from types import SimpleNamespace | |
import random | |
from torchvision.utils import save_image | |
import gradio as gr | |
import numpy as np | |
import io | |
import tempfile # Importar tempfile | |
import math | |
# Make sure the required helper functions are defined (if they are not already).
def resize(img, size):
    """Resample an image tensor to ``size`` with bilinear interpolation.

    Args:
        img: Tensor handed to ``F.interpolate``.
        size: Target spatial size, forwarded unchanged as ``size``.

    Returns:
        The interpolated tensor.
    """
    interpolation = dict(mode='bilinear', align_corners=False)
    return F.interpolate(img, size=size, **interpolation)
def denormalize(x):
    """Map values from the [-1, 1] range to the [0, 1] range.

    Computes ``(x + 1) / 2``; inputs outside [-1, 1] are not clamped.
    """
    shifted = x + 1
    return shifted / 2
# Model class definitions (Generator, StyleEncoder, MappingNetwork, ResBlk, AdaIN, AdainResBlk).
class ResBlk(nn.Module):
    """Residual block: two 3x3 convolutions plus a 1x1-convolution shortcut.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        normalize: When true, an affine ``InstanceNorm2d`` follows each 3x3
            convolution; otherwise an ``Identity`` takes its place.
        downsample: When true, both branches are average-pooled by a factor
            of 2 before they are summed.
    """

    def __init__(self, dim_in, dim_out, normalize=False, downsample=False):
        super().__init__()
        self.normalize = normalize
        self.downsample = downsample

        def make_norm():
            if normalize:
                return nn.InstanceNorm2d(dim_out, affine=True)
            return nn.Identity()

        # The position of each layer fixes the ``main.<index>`` parameter
        # names in the state dict, so the order must stay as is.
        residual_layers = [
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
            make_norm(),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
            make_norm(),
        ]
        self.main = nn.Sequential(*residual_layers)
        self.downsample_layer = nn.AvgPool2d(2) if downsample else nn.Identity()
        # The shortcut is always a bias-free 1x1 convolution, even when
        # dim_in == dim_out.
        self.skip = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return ``(residual + shortcut) / sqrt(2)`` for input ``x``."""
        residual = self.downsample_layer(self.main(x))
        shortcut = self.downsample_layer(self.skip(x))
        return (residual + shortcut) / math.sqrt(2)
class AdaIN(nn.Module):
    """Style-conditioned per-channel scale and shift.

    A linear layer maps the style code to ``2 * num_features`` values that
    are split into a scale (gamma) and a shift (beta). Note that this block
    applies the modulation directly to ``x``; it performs no normalization
    of ``x`` itself.

    Args:
        num_features: Number of channels of the feature map being modulated.
        style_dim: Size of the style code.
    """

    def __init__(self, num_features, style_dim):
        super().__init__()
        self.fc = nn.Linear(style_dim, 2 * num_features)

    def forward(self, x, s):
        """Return ``(1 + gamma) * x + beta`` with gamma/beta derived from ``s``."""
        scale, shift = self.fc(s).chunk(2, dim=1)
        # Append two singleton axes so the per-channel values broadcast over
        # the two trailing dimensions of x.
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return (1 + scale) * x + shift
class AdainResBlk(nn.Module):
    """Residual block whose two convolutions are each preceded by AdaIN.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        style_dim: Size of the style code given to both AdaIN layers.
        w_hpf: Stored on the instance; not used by ``forward``.
        upsample: When true, the input is upsampled by 2 (nearest neighbour)
            before either branch runs.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=1, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.w_hpf = w_hpf
        self.norm1 = AdaIN(dim_in, style_dim)
        self.norm2 = AdaIN(dim_out, style_dim)
        self.actv = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1)
        # A 1x1 convolution is only needed on the shortcut when the channel
        # count changes.
        same_width = dim_in == dim_out
        self.skip = nn.Identity() if same_width else nn.Conv2d(dim_in, dim_out, 1, 1, 0)

    def forward(self, x, s):
        """Return ``(residual + shortcut) / sqrt(2)`` for features ``x`` and style ``s``."""
        if self.upsample:
            # Both branches consume the same upsampled tensor.
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        shortcut = self.skip(x)

        residual = self.conv1(self.actv(self.norm1(x, s)))
        residual = self.conv2(self.actv(self.norm2(residual, s)))
        return (residual + shortcut) / math.sqrt(2)
class Generator(nn.Module):
    """Encoder/decoder generator whose decoder is modulated by a style code.

    The encoder is a 3x3 convolution followed by ``log2(img_size) - 4``
    down-sampling ``ResBlk`` stages; the decoder mirrors it with the same
    number of up-sampling ``AdainResBlk`` stages, halving the channel count
    at each one.

    Args:
        img_size: Image side length used to derive the number of stages.
        style_dim: Size of the style code consumed by the decoder blocks.
        max_conv_dim: Upper bound on the encoder channel count.
    """

    def __init__(self, img_size=256, style_dim=64, max_conv_dim=512):
        super().__init__()
        channels = 64
        num_stages = int(np.log2(img_size)) - 4

        encoder = [nn.Conv2d(3, channels, 3, 1, 1)]
        for _ in range(num_stages):
            widened = min(channels * 2, max_conv_dim)
            encoder.append(ResBlk(channels, widened, normalize=True, downsample=True))
            channels = widened
        self.encode = nn.Sequential(*encoder)

        self.decode = nn.ModuleList()
        for _ in range(num_stages):
            narrowed = channels // 2
            self.decode.append(AdainResBlk(channels, narrowed, style_dim, upsample=True))
            channels = narrowed

        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, 3, 1, 1, 0),
        )

    def forward(self, x, s):
        """Encode ``x``, decode it under style ``s`` and project to 3 channels."""
        features = self.encode(x)
        for stage in self.decode:
            features = stage(features, s)
        return self.to_rgb(features)
class MappingNetwork(nn.Module):
    """MLP that maps a latent code to a style code for a chosen domain.

    Four shared ``Linear`` + ``ReLU`` layers feed one output ``Linear`` head
    per domain.

    Args:
        latent_dim: Size of the latent code ``z``.
        style_dim: Size of the produced style code.
        num_domains: Number of per-domain output heads.
        hidden_dim: Width of the shared hidden layers.
    """

    def __init__(self, latent_dim=16, style_dim=64, num_domains=2, hidden_dim=512):
        super().__init__()
        trunk = [nn.Linear(latent_dim, hidden_dim), nn.ReLU()]
        for _ in range(3):
            trunk.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
        self.shared = nn.Sequential(*trunk)
        self.unshared = nn.ModuleList(
            [nn.Linear(hidden_dim, style_dim) for _ in range(num_domains)]
        )

    def forward(self, z, y):
        """Return, for each row of ``z``, the style from the head indexed by ``y``."""
        hidden = self.shared(z)
        # Stack every head's output along dim 1, then pick one head per row.
        per_domain = torch.stack([head(hidden) for head in self.unshared], dim=1)
        rows = torch.arange(y.size(0)).to(y.device)
        return per_domain[rows, y]
class StyleEncoder(nn.Module):
    """Convolutional encoder that extracts a per-domain style code from an image.

    A 3x3 convolution is followed by ``log2(img_size) - 2`` down-sampling
    ``ResBlk`` stages and a LeakyReLU; the result is globally average-pooled
    and fed to one ``Linear`` head per domain.

    Args:
        img_size: Image side length used to derive the number of stages.
        style_dim: Size of the produced style code.
        num_domains: Number of per-domain output heads.
        max_conv_dim: Upper bound on the channel count.
    """

    def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
        super().__init__()
        channels = 64
        num_stages = int(np.log2(img_size)) - 2

        layers = [nn.Conv2d(3, channels, 3, 1, 1)]
        for _ in range(num_stages):
            widened = min(channels * 2, max_conv_dim)
            layers.append(ResBlk(channels, widened, normalize=True, downsample=True))
            channels = widened
        layers.append(nn.LeakyReLU(0.2))
        self.shared = nn.Sequential(*layers)

        self.unshared = nn.ModuleList(
            [nn.Linear(channels, style_dim) for _ in range(num_domains)]
        )

    def forward(self, x, y):
        """Return, for each image in ``x``, the style from the head indexed by ``y``."""
        features = F.adaptive_avg_pool2d(self.shared(x), (1, 1))
        features = features.view(features.size(0), -1)
        # Stack every head's output along dim 1, then pick one head per row.
        per_domain = torch.stack([head(features) for head in self.unshared], dim=1)
        rows = torch.arange(y.size(0)).to(y.device)
        return per_domain[rows, y]
# Dataset class for loading images.
class ImageFolder(Dataset):
    """Dataset of ``(image, label)`` pairs read from one sub-directory per domain.

    Every sub-directory of ``root`` is a domain. Domains are sorted by name
    and labelled 0, 1, 2, ... in that order; entries of ``root`` that are not
    directories are ignored and do not consume a label. Within a domain the
    regular files are listed in sorted order so indexing is deterministic.

    Args:
        root: Directory containing one sub-directory per domain.
        transform: Callable applied to each loaded RGB ``PIL.Image``.
        mode: Dataset mode; only the value ``'train'`` has an effect (below).
        which: ``'source'`` or ``'reference'``. When ``mode == 'train'`` and
            ``which == 'reference'`` the sample list is shuffled once at
            construction time.
    """

    def __init__(self, root, transform, mode, which='source'):
        self.transform = transform
        self.paths = []

        # Filter to directories *before* numbering so stray files in ``root``
        # (README, .DS_Store, ...) cannot shift or skip domain labels.
        domains = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        for label, domain in enumerate(domains):
            domain_dir = os.path.join(root, domain)
            for name in sorted(os.listdir(domain_dir)):
                path = os.path.join(domain_dir, name)
                # Nested directories cannot be opened as images; skip them.
                if os.path.isfile(path):
                    self.paths.append((path, label))

        if mode == 'train' and which == 'reference':
            random.shuffle(self.paths)

    def __getitem__(self, index):
        """Load sample ``index`` as RGB and return ``(transformed_image, label)``."""
        path, label = self.paths[index]
        img = Image.open(path).convert('RGB')
        return self.transform(img), label

    def __len__(self):
        """Return the number of samples."""
        return len(self.paths)
# Functions that build the data loaders.
def get_transform(img_size, mode='train', prob=0.5):
    """Build the image preprocessing pipeline.

    The image is always resized to ``img_size`` x ``img_size``, converted to
    a tensor and normalized with mean 0.5 / std 0.5 per channel. In
    ``'train'`` mode a random horizontal flip and, with probability ``prob``,
    a random resized crop (scale 0.8-1.0) are inserted after the resize.

    Args:
        img_size: Target side length.
        mode: ``'train'`` enables the random augmentations.
        prob: Probability of applying the random resized crop.

    Returns:
        A ``transforms.Compose`` instance.
    """
    steps = [transforms.Resize((img_size, img_size))]
    if mode == 'train':
        random_crop = transforms.RandomApply(
            [transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0))],
            p=prob,
        )
        steps += [transforms.RandomHorizontalFlip(), random_crop]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)
def get_train_loader(root, which='source', img_size=256, batch_size=8, prob=0.5, num_workers=4):
    """Build the shuffled training ``DataLoader`` over ``root``.

    Images are resized to ``img_size`` x ``img_size``, randomly flipped
    horizontally, converted to tensors and normalized with mean 0.5 /
    std 0.5 per channel.

    Args:
        root: Directory with one sub-directory per domain.
        which: ``'source'`` or ``'reference'``; forwarded to ``ImageFolder``.
        img_size: Target side length.
        batch_size: Samples per batch; incomplete final batches are dropped.
        prob: Probability of the random horizontal flip.
        num_workers: Number of loader worker processes.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``drop_last=True``.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=prob),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # ImageFolder takes the dataset mode and the source/reference role as two
    # separate arguments. Passing ``which`` as ``mode`` (as before) meant its
    # ``mode == 'train' and which == 'reference'`` branch could never run.
    dataset = ImageFolder(root=root, transform=transform, mode='train', which=which)
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    return loader
def get_test_loader(root, img_size=256, batch_size=8, shuffle=False, num_workers=4, mode='reference'):
    """Build the evaluation ``DataLoader`` over ``root``.

    Images are resized to ``img_size`` x ``img_size``, converted to tensors
    and normalized with mean 0.5 / std 0.5 per channel; no random
    augmentation is applied.

    Args:
        root: Directory with one sub-directory per domain.
        img_size: Target side length.
        batch_size: Samples per batch; the final partial batch is kept.
        shuffle: Whether the loader shuffles samples.
        num_workers: Number of loader worker processes.
        mode: Forwarded to ``ImageFolder`` as its ``mode`` argument.

    Returns:
        A ``DataLoader`` with ``drop_last=False``.
    """
    preprocessing = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    images = ImageFolder(root=root, transform=transforms.Compose(preprocessing), mode=mode)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=False,
    )
    return DataLoader(dataset=images, **loader_options)
# Solver class (adapted for inference).
class Solver(object):
    """Inference-only wrapper around the generator, mapping network and style encoder.

    Args:
        args: Namespace providing ``img_size``, ``style_dim``, ``latent_dim``
            and ``num_domains``.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Build the three networks on the selected device.
        self.G = Generator(args.img_size, args.style_dim).to(self.device)
        self.M = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains).to(self.device)
        self.S = StyleEncoder(args.img_size, args.style_dim, args.num_domains).to(self.device)

    def load_checkpoint(self, checkpoint_path):
        """Load network weights from ``checkpoint_path``.

        The checkpoint must contain the keys ``'generator'``,
        ``'mapping_network'`` and ``'style_encoder'``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            Exception: For any other failure while reading or applying the
                checkpoint; the original error is chained as ``__cause__``.
        """
        try:
            # NOTE(security): torch.load unpickles the file, which can run
            # arbitrary code. Only load checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.G.load_state_dict(checkpoint['generator'])
            self.M.load_state_dict(checkpoint['mapping_network'])
            self.S.load_state_dict(checkpoint['style_encoder'])
        except FileNotFoundError as err:
            print(f"Error: No se encontró el checkpoint en {checkpoint_path}.")
            raise FileNotFoundError(f"No se encontró el checkpoint en {checkpoint_path}") from err
        except Exception as e:
            print(f"Error al cargar el checkpoint: {e}.")
            raise Exception(f"Error al cargar el checkpoint: {e}") from e
        else:
            print(f"Checkpoint cargado exitosamente desde {checkpoint_path}.")

    def transfer_style(self, source_image, reference_image, ref_domain=0):
        """Render ``source_image`` with the style extracted from ``reference_image``.

        Args:
            source_image: Array accepted by ``PIL.Image.fromarray``.
            reference_image: Array accepted by ``PIL.Image.fromarray``.
            ref_domain: Index of the style-encoder head used for the
                reference image; must be in ``[0, args.num_domains)``.
                Defaults to 0, the value that was previously hard-coded.

        Returns:
            A ``uint8`` NumPy array laid out as height x width x channels.

        Raises:
            ValueError: If ``ref_domain`` is outside the valid range.
        """
        if not 0 <= ref_domain < self.args.num_domains:
            raise ValueError(
                f"ref_domain must be in [0, {self.args.num_domains}), got {ref_domain}"
            )

        # Inference only: switch the networks used below to evaluation mode.
        self.G.eval()
        self.S.eval()

        transform = transforms.Compose([
            transforms.Resize((self.args.img_size, self.args.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        def to_batch(array):
            # Force three channels so grayscale or RGBA input still matches
            # the networks' 3-channel first convolution.
            pil_image = Image.fromarray(array).convert('RGB')
            return transform(pil_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            source = to_batch(source_image)
            reference = to_batch(reference_image)

            # Encode the style of the reference image.
            domain = torch.tensor([ref_domain]).to(self.device)
            s_ref = self.S(reference, domain)

            # Generate the source image rendered with that style.
            generated = self.G(source, s_ref)

        # Map from [-1, 1] to [0, 1], then to 8-bit channel-last for display.
        generated = denormalize(generated.squeeze(0)).cpu()
        return (generated * 255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
# Main inference function.
# Solvers with their checkpoint already loaded, keyed by checkpoint path and
# argument values, so the networks are not rebuilt on every request.
_SOLVER_CACHE = {}


def _get_solver(checkpoint_path, args):
    """Return a Solver with ``checkpoint_path`` loaded, reusing a cached one if possible."""
    try:
        key = (checkpoint_path, tuple(sorted(vars(args).items())))
        hash(key)
    except TypeError:
        # args has no __dict__ or holds unhashable values: skip caching.
        key = None

    solver = _SOLVER_CACHE.get(key) if key is not None else None
    if solver is None:
        solver = Solver(args)
        solver.load_checkpoint(checkpoint_path)
        # Cache only after a successful load so a failed load is retried.
        if key is not None:
            _SOLVER_CACHE[key] = solver
    return solver


def main(source_image, reference_image, checkpoint_path, args):
    """Run style transfer for one pair of images.

    The solver (networks plus loaded checkpoint) is created on first use and
    reused for later calls with the same ``checkpoint_path`` and ``args``
    values. A checkpoint file replaced on disk is therefore not picked up
    until the process restarts.

    Args:
        source_image: Image whose content is kept.
        reference_image: Image whose style is transferred.
        checkpoint_path: Path of the checkpoint to load.
        args: Namespace of model hyper-parameters passed to ``Solver``.

    Returns:
        The generated image as returned by ``Solver.transfer_style``.

    Raises:
        gr.Error: If either image is missing.
    """
    if source_image is None or reference_image is None:
        raise gr.Error("Por favor, proporciona ambas imágenes (fuente y referencia).")

    solver = _get_solver(checkpoint_path, args)
    return solver.transfer_style(source_image, reference_image)
def gradio_interface():
    """Build the Gradio interface for the car style-transfer demo.

    The model hyper-parameters and the checkpoint path are fixed here and
    bound into the callback that Gradio invokes with the two uploaded images.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    # Inference-time arguments.
    args = SimpleNamespace(
        img_size=128,
        num_domains=3,
        latent_dim=16,
        style_dim=64,
        num_workers=0,
        seed=8365,
    )
    # Path to the checkpoint.
    checkpoint_path = "iter/27000_nets_ema.ckpt"

    def run_transfer(source_image, reference_image):
        return main(source_image, reference_image, checkpoint_path, args)

    return gr.Interface(
        fn=run_transfer,
        inputs=[
            gr.Image(label="Source Image (Car to change style)"),
            gr.Image(label="Reference Image (Style to transfer)"),
        ],
        outputs=gr.Image(label="Generated Image (Car with transferred style)"),
        title="AutoStyleGAN: Car Style Transfer",
        description="Transfer the style of one car to another. Upload a source car image and a reference car image.",
    )
if __name__ == "__main__":
    # Script entry point: build the demo and serve it. share=True asks Gradio
    # to also create a public share link.
    iface = gradio_interface()
    iface.launch(share=True)