# Hack for Spaces: swap the preinstalled gradio for the pinned requirements.
import os

os.system("pip uninstall -y gradio")
os.system("pip install -r requirements.txt")

# Real code begins
from typing import Union, List

import gradio as gr
import matplotlib
import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

# Use a non-interactive backend so plotting works on a headless server.
matplotlib.use("Agg")

import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix
from pl_bolts.models.gans.pix2pix.components import PatchGAN
import torchvision.models as models

""" Class """
class OverpoweredPix2Pix(Pix2Pix):
    def validation_step(self, batch, batch_idx):
        """Validation step"""
        real, condition = batch
        with torch.no_grad():
            loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", loss)

            loss = self._gen_step(real, condition)
            self.log("val_generator_loss", loss)

        # condition is the sketch the generator consumes; real is the colour target.
        return {"sketch": condition, "colour": real}

    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
        grid_image = torchvision.utils.make_grid(
            [
                sketch[0],
                colour[0],
                gen_coloured[0],
            ],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f"Image Grid {str(self.current_epoch)}", grid_image, self.current_epoch
        )
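# Note: make_grid(normalize=True) min-max rescales the grid to [0, 1], so the
# generator's [-1, 1] outputs display correctly in TensorBoard.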
class PatchGanChanged(OverpoweredPix2Pix):
    def __init__(self, in_channels, out_channels):
        super(PatchGanChanged, self).__init__(
            in_channels=in_channels, out_channels=out_channels
        )
        self.patch_gan = self.get_dense_PatchGAN(self.patch_gan)

    @staticmethod
    def get_dense_PatchGAN(disc: PatchGAN) -> PatchGAN:
        """Add a final dense layer so the GAN outputs one score instead of a patch map."""
        disc.final = torch.nn.Sequential(
            disc.final,
            torch.nn.Flatten(),
            torch.nn.Linear(16 * 16, 1),
        )
        return disc
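# A minimal shape sanity check (a sketch, not used by the app): the
# Linear(16 * 16, 1) head assumes the PatchGAN emits a 16x16 patch map for
# 256x256 inputs, so the wrapped discriminator should return a single logit:
#
#     disc = PatchGanChanged(in_channels=3, out_channels=3).patch_gan
#     fake = torch.randn(1, 3, 256, 256)
#     cond = torch.randn(1, 3, 256, 256)
#     assert disc(fake, cond).shape == (1, 1)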
""" Load the model """ | |
# train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt" | |
train_64_val_16_plbolts_model_chkpt = ( | |
"model/lightning_bolts_model/epoch=99-step=44600.ckpt" | |
) | |
train_16_val_1_plbolts_model_chkpt = ( | |
"model/lightning_bolts_model/epoch=99-step=89000.ckpt" | |
) | |
modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt" | |
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt" | |
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth" | |
# Load the models | |
train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint( | |
train_64_val_16_plbolts_model_chkpt | |
) | |
train_64_val_16_plbolts_model.eval() | |
# | |
train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint( | |
train_16_val_1_plbolts_model_chkpt | |
) | |
train_16_val_1_plbolts_model.eval() | |
# | |
modified_patchgan_model = PatchGanChanged.load_from_checkpoint(modified_patchgan_chkpt) | |
modified_patchgan_model.eval() | |
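# eval() switches off dropout and batch-norm statistic updates; inference calls
# further below are additionally wrapped in torch.no_grad().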
# Pix2Pix variant that adds a cosine-similarity feature loss to the generator
# objective (given its own name so it does not clash with OverpoweredPix2Pix above).
class CosineSimPix2Pix(Pix2Pix):
    def __init__(self, in_channels, out_channels):
        super(CosineSimPix2Pix, self).__init__(
            in_channels=in_channels, out_channels=out_channels
        )
        self._create_inception_score()
    def _gen_step(self, real_images, conditioned_images):
        # Pix2Pix has an adversarial and a reconstruction loss.
        # First calculate the adversarial loss.
        fake_images = self.gen(conditioned_images)
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(
            disc_logits, torch.ones_like(disc_logits)
        )

        # Calculate the reconstruction loss.
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        # Calculate the cosine similarity between real and fake features.
        representations_real = self.feature_extractor(real_images).flatten(1)
        representations_fake = self.feature_extractor(fake_images).flatten(1)
        similarity_score_list = self.cosine_similarity(
            representations_real, representations_fake
        )
        cosine_sim = similarity_score_list.mean()
        self.log("Gen Cosine Sim Loss", 1 - cosine_sim)
        # print(adversarial_loss, 1 - cosine_sim, lambda_recon, recon_loss)
        return (
            adversarial_loss
            + (lambda_recon * recon_loss)
            + (lambda_recon * (1 - cosine_sim))
        )
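    # The full generator objective, as composed above:
    #     L_G = L_adv + lambda_recon * L1(fake, real)
    #                 + lambda_recon * (1 - mean cosine similarity of features)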
    def _create_inception_score(self):
        # Despite the name, this builds a pretrained ResNet-50 feature extractor
        # for the cosine-similarity loss, not an Inception score.
        backbone = models.resnet50(pretrained=True)
        layers = list(backbone.children())[:-1]  # drop the final fc layer
        self.feature_extractor = torch.nn.Sequential(*layers)
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
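    # Shape sketch: ResNet-50 without its fc head maps a [B, 3, 256, 256] batch
    # through the average pool to [B, 2048, 1, 1]; flatten(1) then yields the
    # [B, 2048] vectors that CosineSimilarity(dim=1) compares row-wise.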
    def validation_step(self, batch, batch_idx):
        """Validation step"""
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("Valid PatchGAN Loss", disc_loss)

            gan_loss = self._gen_step(real, condition)
            self.log("Valid Generator Loss", gan_loss)

            fake_images = self.gen(condition)
            representations_real = self.feature_extractor(real).flatten(1)
            representations_fake = self.feature_extractor(fake_images).flatten(1)
            similarity_score_list = self.cosine_similarity(
                representations_real, representations_fake
            )
            cosine_sim = similarity_score_list.mean()
            self.log("Valid Cosine Sim", cosine_sim)
        return {"sketch": condition, "colour": real}
    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        self.feature_extractor.eval()
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
            representations_gen = self.feature_extractor(gen_coloured).flatten(1)
            representations_real = self.feature_extractor(colour).flatten(1)
            similarity_score_list = self.cosine_similarity(
                representations_gen, representations_real
            )
            similarity_score = similarity_score_list.mean()
        grid_image = torchvision.utils.make_grid(
            [
                sketch[0],
                colour[0],
                gen_coloured[0],
            ],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f"Image Grid {self.current_epoch} __ {float(similarity_score):.4f}",
            grid_image,
            self.current_epoch,
        )
cosine_sim_model_chkpt = "model/lightning_bolts_model/cosine_sim_model.ckpt"
cosine_sim_model = CosineSimPix2Pix.load_from_checkpoint(cosine_sim_model_chkpt)
cosine_sim_model.eval()
def predict(img: Image.Image, type_of_model: str):
    """Run one of the loaded generators on a sketch and save the result."""
    # Transform the image for inference.
    image = np.asarray(img)
    inference_transform = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0
            ),
            al_pytorch.ToTensorV2(),
        ]
    )
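    # Normalizing with mean = std = 0.5 and max_pixel_value = 255 maps pixels
    # from [0, 255] to [-1, 1], the tanh output range typical of Pix2Pix generators.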
    inference_img = inference_transform(image=image)["image"].unsqueeze(0)

    # Choose the model; any unmatched string falls through to the modified
    # (single-logit) PatchGAN model.
    if type_of_model == "train batch size 16, val batch size 1":
        model = train_16_val_1_plbolts_model
    elif type_of_model == "train batch size 64, val batch size 16":
        model = train_64_val_16_plbolts_model
    elif type_of_model == "cosine similarity":
        model = cosine_sim_model
    else:
        model = modified_patchgan_model

    with torch.no_grad():
        result = model.gen(inference_img)
    torchvision.utils.save_image(result, "inference_image.png", normalize=True)
    return "inference_image.png"
def predict1(img: Image.Image):
    return predict(img=img, type_of_model="train batch size 16, val batch size 1")


def predict2(img: Image.Image):
    return predict(img=img, type_of_model="train batch size 64, val batch size 16")


def predict3(img: Image.Image):
    # This string has no explicit branch in predict(), so it selects the
    # modified PatchGAN model via the else clause.
    return predict(
        img=img,
        type_of_model="train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
    )


def predict4(img: Image.Image):
    return predict(
        img=img,
        type_of_model="cosine similarity",
    )
# Standalone components (not wired into the Blocks UI below).
model_input = gr.Radio(
    [
        "train batch size 16, val batch size 1",
        "train batch size 64, val batch size 16",
        "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
    ],
    label="Type of Pix2Pix model to use : ",
)
image_input = gr.Image(type="pil")
img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]
with gr.Blocks() as demo:
    gr.Markdown(" # Colour your sketches!")
    gr.Markdown(" ## Description :")
    gr.Markdown(" There are 4 Pix2Pix models in this example:")
    gr.Markdown(" 1. Training batch size 16, validation batch size 1")
    gr.Markdown(" 2. Training batch size 64, validation batch size 16")
    gr.Markdown(
        " 3. PatchGAN changed to emit a single value instead of 16*16; "
        "training batch size 64, validation batch size 16"
    )
    gr.Markdown(
        " 4. Cosine similarity added as an extra generator loss term in this experiment."
    )
    with gr.Tabs():
        with gr.TabItem("tr_16_val_1"):
            with gr.Row():
                image_input1 = gr.Image(type="pil")
                image_output1 = gr.Image()
            colour_1 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input1,
                    outputs=image_output1,
                    fn=predict1,
                )
        with gr.TabItem("tr_64_val_16"):
            with gr.Row():
                image_input2 = gr.Image(type="pil")
                image_output2 = gr.Image()
            colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input2,
                    outputs=image_output2,
                    fn=predict2,
                )
        with gr.TabItem("Single Value Discriminator"):
            with gr.Row():
                image_input3 = gr.Image(type="pil")
                image_output3 = gr.Image()
            colour_3 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input3,
                    outputs=image_output3,
                    fn=predict3,
                )
        with gr.TabItem("Cosine similarity loss"):
            with gr.Row():
                image_input4 = gr.Image(type="pil")
                image_output4 = gr.Image()
            colour_4 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input4,
                    outputs=image_output4,
                    fn=predict4,
                )
    colour_1.click(
        fn=predict1,
        inputs=image_input1,
        outputs=image_output1,
    )
    colour_2.click(
        fn=predict2,
        inputs=image_input2,
        outputs=image_output2,
    )
    colour_3.click(
        fn=predict3,
        inputs=image_input3,
        outputs=image_output3,
    )
    colour_4.click(
        fn=predict4,
        inputs=image_input4,
        outputs=image_output4,
    )

demo.title = "Colour your sketches!"
demo.launch()