# Hack for Hugging Face Spaces: force-reinstall the pinned requirements
import os
os.system("pip uninstall -y gradio")
os.system("pip install -r requirements.txt")
# Real code begins
from typing import Union, List
import gradio as gr
import matplotlib
import torch
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
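# Use the non-interactive Agg backend so matplotlib works on a headless server.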
matplotlib.use("Agg")
import numpy as np
from PIL import Image
import albumentations as A
import albumentations.pytorch as al_pytorch
import torchvision
from pl_bolts.models.gans import Pix2Pix
from pl_bolts.models.gans.pix2pix.components import PatchGAN
import torchvision.models as models
""" Class """
class OverpoweredPix2Pix(Pix2Pix):
    def validation_step(self, batch, batch_idx):
        """Validation step: log discriminator and generator losses."""
        real, condition = batch
        with torch.no_grad():
            loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", loss)

            loss = self._gen_step(real, condition)
            self.log("val_generator_loss", loss)

        # The condition is the sketch fed to the generator; real is the colour target.
        return {"sketch": condition, "colour": real}

    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
            # Log sketch / target / generated side by side.
            grid_image = torchvision.utils.make_grid(
                [sketch[0], colour[0], gen_coloured[0]],
                normalize=True,
            )
            self.logger.experiment.add_image(
                f"Image Grid {self.current_epoch}", grid_image, self.current_epoch
            )
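# Note: validation_epoch_end is the PyTorch Lightning 1.x hook; Lightning
# passes it the list of dicts returned by validation_step, and the class
# above logs one image grid per epoch from the first batch.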
class PatchGanChanged(OverpoweredPix2Pix):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels=in_channels, out_channels=out_channels)
        self.patch_gan = self.get_dense_PatchGAN(self.patch_gan)

    @staticmethod
    def get_dense_PatchGAN(disc: PatchGAN) -> PatchGAN:
        """Replace the 16x16 patch output with a single realism score."""
        disc.final = torch.nn.Sequential(
            disc.final,
            torch.nn.Flatten(),
            torch.nn.Linear(16 * 16, 1),
        )
        return disc
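# Illustrative sketch (not called by the app; the helper name is ours): for
# 256x256 inputs the pl_bolts PatchGAN produces a grid of patch logits
# (assumed (N, 1, 16, 16) here). Flattening and applying Linear(16 * 16, 1)
# collapses that grid to one realism score per image, which is what
# get_dense_PatchGAN wires in above.
def _demo_dense_head_shapes() -> torch.Size:
    patch_logits = torch.randn(2, 1, 16, 16)  # dummy (N, C, H, W) patch grid
    dense_head = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(16 * 16, 1),
    )
    return dense_head(patch_logits).shape  # torch.Size([2, 1])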
""" Load the model """
# train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
train_64_val_16_plbolts_model_chkpt = (
"model/lightning_bolts_model/epoch=99-step=44600.ckpt"
)
train_16_val_1_plbolts_model_chkpt = (
"model/lightning_bolts_model/epoch=99-step=89000.ckpt"
)
modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"
# Load the models
train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_64_val_16_plbolts_model_chkpt
)
train_64_val_16_plbolts_model.eval()
#
train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_16_val_1_plbolts_model_chkpt
)
train_16_val_1_plbolts_model.eval()
#
modified_patchgan_model = PatchGanChanged.load_from_checkpoint(modified_patchgan_chkpt)
modified_patchgan_model.eval()
# A second variant: the same Pix2Pix, with a feature-space cosine similarity
# term added to the generator loss. Renamed so it does not shadow the
# OverpoweredPix2Pix class already used above.
class CosineSimPix2Pix(Pix2Pix):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels=in_channels, out_channels=out_channels)
        self._create_feature_extractor()

    def _gen_step(self, real_images, conditioned_images):
        # Pix2Pix has an adversarial and a reconstruction loss.
        # First calculate the adversarial loss:
        fake_images = self.gen(conditioned_images)
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(
            disc_logits, torch.ones_like(disc_logits)
        )

        # Calculate the reconstruction loss:
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        # Calculate the cosine similarity between ResNet features of the
        # real and generated images:
        representations_real = self.feature_extractor(real_images).flatten(1)
        representations_fake = self.feature_extractor(fake_images).flatten(1)
        cosine_sim = self.cosine_similarity(
            representations_real, representations_fake
        ).mean()
        self.log("Gen Cosine Sim Loss", 1 - cosine_sim)

        return (
            adversarial_loss
            + (lambda_recon * recon_loss)
            + (lambda_recon * (1 - cosine_sim))
        )

    def _create_feature_extractor(self):
        # Use a pretrained ResNet-50 trunk (classifier head removed) as a
        # feature extractor for the cosine similarity term.
        backbone = models.resnet50(pretrained=True)
        layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*layers)
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def validation_step(self, batch, batch_idx):
        """Validation step: log losses and the feature cosine similarity."""
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("Valid PatchGAN Loss", disc_loss)

            gan_loss = self._gen_step(real, condition)
            self.log("Valid Generator Loss", gan_loss)

            fake_images = self.gen(condition)
            representations_real = self.feature_extractor(real).flatten(1)
            representations_fake = self.feature_extractor(fake_images).flatten(1)
            cosine_sim = self.cosine_similarity(
                representations_real, representations_fake
            ).mean()
            self.log("Valid Cosine Sim", cosine_sim)

        return {"sketch": condition, "colour": real}

    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        self.feature_extractor.eval()
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
            representations_gen = self.feature_extractor(gen_coloured).flatten(1)
            representations_colour = self.feature_extractor(colour).flatten(1)
            similarity_score = self.cosine_similarity(
                representations_gen, representations_colour
            ).mean()
            grid_image = torchvision.utils.make_grid(
                [sketch[0], colour[0], gen_coloured[0]],
                normalize=True,
            )
            self.logger.experiment.add_image(
                f"Image Grid {self.current_epoch} __ {similarity_score}",
                grid_image,
                self.current_epoch,
            )
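# Minimal sketch (not called by the app; the helper name and dummy 256x256
# batches are ours) of the feature-space cosine similarity used above: embed
# two image batches with the ResNet-50 trunk and average the per-image
# cosine similarities.
def _demo_feature_cosine_similarity(model: CosineSimPix2Pix) -> float:
    imgs_a = torch.randn(2, 3, 256, 256)  # stand-in "real" batch
    imgs_b = torch.randn(2, 3, 256, 256)  # stand-in "generated" batch
    with torch.no_grad():
        feats_a = model.feature_extractor(imgs_a).flatten(1)
        feats_b = model.feature_extractor(imgs_b).flatten(1)
        return model.cosine_similarity(feats_a, feats_b).mean().item()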
cosine_sim_model_chkpt = "model/lightning_bolts_model/cosine_sim_model.ckpt"
cosine_sim_model = CosineSimPix2Pix.load_from_checkpoint(cosine_sim_model_chkpt)
cosine_sim_model.eval()
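# If these checkpoints were saved on GPU and the Space runs on CPU only, an
# explicit map_location may be needed (hedged sketch; the plain calls above
# work when the devices match):
# cosine_sim_model = CosineSimPix2Pix.load_from_checkpoint(
#     cosine_sim_model_chkpt, map_location=torch.device("cpu")
# )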
def predict(img: Image.Image, type_of_model: str):
    """Colour a sketch with the selected model and save the result."""
    image = np.asarray(img)
    # Inference-time transform: resize to 256x256, scale to [-1, 1], convert to tensor.
    inference_transform = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0
            ),
            al_pytorch.ToTensorV2(),
        ]
    )
    inference_img = inference_transform(image=image)["image"].unsqueeze(0)

    # Choose the model
    if type_of_model == "train batch size 16, val batch size 1":
        model = train_16_val_1_plbolts_model
    elif type_of_model == "train batch size 64, val batch size 16":
        model = train_64_val_16_plbolts_model
    elif type_of_model == "cosine similarity":
        model = cosine_sim_model
    else:  # the modified single-score PatchGAN variant
        model = modified_patchgan_model

    with torch.no_grad():
        result = model.gen(inference_img)
    torchvision.utils.save_image(result, "inference_image.png", normalize=True)
    return "inference_image.png"
def predict1(img: Image.Image):
    return predict(img=img, type_of_model="train batch size 16, val batch size 1")


def predict2(img: Image.Image):
    return predict(img=img, type_of_model="train batch size 64, val batch size 16")


def predict3(img: Image.Image):
    return predict(
        img=img,
        type_of_model="train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
    )


def predict4(img: Image.Image):
    return predict(img=img, type_of_model="cosine similarity")
img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]
with gr.Blocks() as demo:
    gr.Markdown(" # Colour your sketches!")
    gr.Markdown(" ## Description:")
    gr.Markdown(" There are 4 Pix2Pix models in this example:")
    gr.Markdown(" 1. Training batch size 16, validation batch size 1")
    gr.Markdown(" 2. Training batch size 64, validation batch size 16")
    gr.Markdown(
        " 3. Modified PatchGAN that outputs a single score instead of a"
        " 16*16 patch grid; training batch size 64, validation batch size 16"
    )
    gr.Markdown(
        " 4. Cosine similarity between ResNet features added to the generator loss"
    )
    with gr.Tabs():
        with gr.TabItem("tr_16_val_1"):
            with gr.Row():
                image_input1 = gr.inputs.Image(type="pil")
                image_output1 = gr.outputs.Image(type="pil")
            colour_1 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input1,
                    outputs=image_output1,
                    fn=predict1,
                )
        with gr.TabItem("tr_64_val_16"):
            with gr.Row():
                image_input2 = gr.inputs.Image(type="pil")
                image_output2 = gr.outputs.Image(type="pil")
            colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input2,
                    outputs=image_output2,
                    fn=predict2,
                )
        with gr.TabItem("Single Value Discriminator"):
            with gr.Row():
                image_input3 = gr.inputs.Image(type="pil")
                image_output3 = gr.outputs.Image(type="pil")
            colour_3 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input3,
                    outputs=image_output3,
                    fn=predict3,
                )
        with gr.TabItem("Cosine similarity loss"):
            with gr.Row():
                image_input4 = gr.inputs.Image(type="pil")
                image_output4 = gr.outputs.Image(type="pil")
            colour_4 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input4,
                    outputs=image_output4,
                    fn=predict4,
                )
    colour_1.click(fn=predict1, inputs=image_input1, outputs=image_output1)
    colour_2.click(fn=predict2, inputs=image_input2, outputs=image_output2)
    colour_3.click(fn=predict3, inputs=image_input3, outputs=image_output3)
    colour_4.click(fn=predict4, inputs=image_input4, outputs=image_output4)
demo.title = "Colour your sketches!"
demo.launch()