# Hack for Spaces: reinstall the pinned dependencies before anything else runs.
import os

os.system("pip uninstall -y gradio")
os.system("pip install -r requirements.txt")

# Real code begins
from typing import List, Union

import albumentations as A
import albumentations.pytorch as al_pytorch
import gradio as gr
import matplotlib
import numpy as np
import torch
import torchvision
import torchvision.models as models
from PIL import Image
from pl_bolts.models.gans import Pix2Pix
from pl_bolts.models.gans.pix2pix.components import PatchGAN
from pytorch_lightning.utilities.types import EPOCH_OUTPUT

matplotlib.use("Agg")


class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix with validation-loss logging and an image grid written per epoch."""

    def validation_step(self, batch, batch_idx):
        """Validation step."""
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", disc_loss)

            gen_loss = self._gen_step(real, condition)
            self.log("val_generator_loss", gen_loss)

        # Note: this class treats the first batch element as the sketch.
        return {"sketch": real, "colour": condition}

    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        with torch.no_grad():
            gen_coloured = self.gen(sketch)

        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f"Image Grid {self.current_epoch}", grid_image, self.current_epoch
        )


class PatchGanChanged(OverpoweredPix2Pix):
    """Variant whose discriminator emits one score instead of a 16x16 patch map."""

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels=in_channels, out_channels=out_channels)
        self.patch_gan = self.get_dense_PatchGAN(self.patch_gan)

    @staticmethod
    def get_dense_PatchGAN(disc: PatchGAN) -> PatchGAN:
        """Append a dense head that collapses the patch map to a single logit."""
        disc.final = torch.nn.Sequential(
            disc.final,
            torch.nn.Flatten(),
            torch.nn.Linear(16 * 16, 1),
        )
        return disc
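
# Shape sanity check for the dense head above (an illustrative sketch, left
# commented out so it does not run on import; assumes 256x256 inputs, as the
# inference transform below uses). The stock pl_bolts PatchGAN maps a 256x256
# image pair to a 16x16 patch map, which is why the Linear layer takes
# 16 * 16 flattened features:
#
#   disc = PatchGanChanged.get_dense_PatchGAN(PatchGAN(6))  # 6 = in + out channels
#   score = disc(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
#   assert score.shape == (1, 1)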

# Load the model checkpoints
# train_64_val_16_patchgan_1val_plbolts_model_chkpt = "model/lightning_bolts_model/modified_path_gan.ckpt"
train_64_val_16_plbolts_model_chkpt = (
    "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
)
train_16_val_1_plbolts_model_chkpt = (
    "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
)
modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/version_0/checkpoints/epoch=199-step=355600.ckpt"
# model_checkpoint_path = "model/pix2pix_lightning_model/gen.pth"

train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_64_val_16_plbolts_model_chkpt
)
train_64_val_16_plbolts_model.eval()

train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_16_val_1_plbolts_model_chkpt
)
train_16_val_1_plbolts_model.eval()

modified_patchgan_model = PatchGanChanged.load_from_checkpoint(modified_patchgan_chkpt)
modified_patchgan_model.eval()


class CosineSimPix2Pix(Pix2Pix):
    """Pix2Pix whose generator loss adds a (1 - cosine similarity) term computed
    on ResNet-50 features of real vs. generated images."""

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels=in_channels, out_channels=out_channels)
        self._create_inception_score()

    def _create_inception_score(self):
        # Feature extractor: a pretrained ResNet-50 with its classifier head removed.
        backbone = models.resnet50(pretrained=True)
        layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*layers)
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def _gen_step(self, real_images, conditioned_images):
        # Pix2Pix has an adversarial and a reconstruction loss.
        # First calculate the adversarial loss.
        fake_images = self.gen(conditioned_images)
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(
            disc_logits, torch.ones_like(disc_logits)
        )

        # Calculate the reconstruction loss.
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        # Calculate cosine similarity between real and generated feature vectors.
        representations_real = self.feature_extractor(real_images).flatten(1)
        representations_fake = self.feature_extractor(fake_images).flatten(1)
        cosine_sim = self.cosine_similarity(
            representations_real, representations_fake
        ).mean()
        self.log("Gen Cosine Sim Loss", 1 - cosine_sim)

        return (
            adversarial_loss
            + lambda_recon * recon_loss
            + lambda_recon * (1 - cosine_sim)
        )

    def validation_step(self, batch, batch_idx):
        """Validation step."""
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("Valid PatchGAN Loss", disc_loss)

            gan_loss = self._gen_step(real, condition)
            self.log("Valid Generator Loss", gan_loss)

            fake_images = self.gen(condition)
            representations_real = self.feature_extractor(real).flatten(1)
            representations_fake = self.feature_extractor(fake_images).flatten(1)
            cosine_sim = self.cosine_similarity(
                representations_real, representations_fake
            ).mean()
            self.log("Valid Cosine Sim", cosine_sim)

        return {"sketch": condition, "colour": real}

    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        self.feature_extractor.eval()
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
            representations_gen = self.feature_extractor(gen_coloured).flatten(1)
            representations_real = self.feature_extractor(colour).flatten(1)
            similarity_score = self.cosine_similarity(
                representations_gen, representations_real
            ).mean()

        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f"Image Grid {self.current_epoch} __ {similarity_score}",
            grid_image,
            self.current_epoch,
        )


cosine_sim_model_chkpt = "model/lightning_bolts_model/cosine_sim_model.ckpt"
cosine_sim_model = CosineSimPix2Pix.load_from_checkpoint(cosine_sim_model_chkpt)
cosine_sim_model.eval()
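
# Illustrative check of the extra generator loss term (a sketch, not used by
# the app): lambda_recon * (1 - cosine_sim) vanishes when the real and
# generated feature vectors coincide and grows as they diverge.
#
#   cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
#   feats = torch.randn(4, 2048)  # ResNet-50 pooled features are 2048-d
#   assert torch.allclose(cos(feats, feats), torch.ones(4))  # identical -> term is 0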

def predict(img: Image.Image, type_of_model: str):
    """Colour a sketch with the selected model and return the output image path."""
    image = np.asarray(img)

    # Inference-time transform: resize to 256x256 and normalize to [-1, 1].
    inference_transform = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0
            ),
            al_pytorch.ToTensorV2(),
        ]
    )
    inference_img = inference_transform(image=image)["image"].unsqueeze(0)

    # Choose the model.
    if type_of_model == "train batch size 16, val batch size 1":
        model = train_16_val_1_plbolts_model
    elif type_of_model == "train batch size 64, val batch size 16":
        model = train_64_val_16_plbolts_model
    elif type_of_model == "cosine similarity":
        model = cosine_sim_model
    else:
        model = modified_patchgan_model

    with torch.no_grad():
        result = model.gen(inference_img)

    torchvision.utils.save_image(result, "inference_image.png", normalize=True)
    return "inference_image.png"


def predict1(img: Image.Image):
    return predict(img=img, type_of_model="train batch size 16, val batch size 1")


def predict2(img: Image.Image):
    return predict(img=img, type_of_model="train batch size 64, val batch size 16")


def predict3(img: Image.Image):
    return predict(
        img=img,
        type_of_model="train batch size 64, val batch size 16, "
        "patch gan has 1 output score instead of 16*16",
    )


def predict4(img: Image.Image):
    return predict(img=img, type_of_model="cosine similarity")


# Note: these two components are not wired into the Blocks UI below.
model_input = gr.Radio(
    [
        "train batch size 16, val batch size 1",
        "train batch size 64, val batch size 16",
        "train batch size 64, val batch size 16, "
        "patch gan has 1 output score instead of 16*16",
    ],
    label="Type of Pix2Pix model to use:",
)
image_input = gr.Image(type="pil")

img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]

with gr.Blocks(title="Colour your sketches!") as demo:
    gr.Markdown("# Colour your sketches!")
    gr.Markdown("## Description:")
    gr.Markdown("There are 4 Pix2Pix models in this example:")
    gr.Markdown("1. Training batch size is 16, validation is 1")
    gr.Markdown("2. Training batch size is 64, validation is 16")
    gr.Markdown(
        "3. PatchGAN is changed to emit 1 value instead of 16*16; "
        "training batch size is 64, validation is 16"
    )
    gr.Markdown(
        "4. Cosine similarity is also added as a generator metric in this experiment."
    )

    with gr.Tabs():
        with gr.TabItem("tr_16_val_1"):
            with gr.Row():
                image_input1 = gr.Image(type="pil")
                image_output1 = gr.Image(type="pil")
            colour_1 = gr.Button("Colour it!")
            gr.Examples(
                examples=img_examples,
                inputs=image_input1,
                outputs=image_output1,
                fn=predict1,
            )
        with gr.TabItem("tr_64_val_16"):
            with gr.Row():
                image_input2 = gr.Image(type="pil")
                image_output2 = gr.Image(type="pil")
            colour_2 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input2,
                    outputs=image_output2,
                    fn=predict2,
                )
        with gr.TabItem("Single Value Discriminator"):
            with gr.Row():
                image_input3 = gr.Image(type="pil")
                image_output3 = gr.Image(type="pil")
            colour_3 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input3,
                    outputs=image_output3,
                    fn=predict3,
                )
        with gr.TabItem("Cosine similarity loss"):
            with gr.Row():
                image_input4 = gr.Image(type="pil")
                image_output4 = gr.Image(type="pil")
            colour_4 = gr.Button("Colour it!")
            with gr.Row():
                gr.Examples(
                    examples=img_examples,
                    inputs=image_input4,
                    outputs=image_output4,
                    fn=predict4,
                )

    colour_1.click(fn=predict1, inputs=image_input1, outputs=image_output1)
    colour_2.click(fn=predict2, inputs=image_input2, outputs=image_output2)
    colour_3.click(fn=predict3, inputs=image_input3, outputs=image_output3)
    colour_4.click(fn=predict4, inputs=image_input4, outputs=image_output4)

demo.launch()
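
# Optional local smoke test (a sketch; run instead of launch() if desired,
# using one of the bundled examples listed in img_examples above):
#   print(predict1(Image.open("examples/thesis_test.png")))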