Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| # Hack for spaces | |
| import os | |
| os.system("pip uninstall -y gradio") | |
| os.system("pip install -r requirements.txt") | |
| # Real code begins | |
| from typing import Union, List | |
| import gradio as gr | |
| import matplotlib | |
| import torch | |
| from pytorch_lightning.utilities.types import EPOCH_OUTPUT | |
| matplotlib.use("Agg") | |
| import numpy as np | |
| from PIL import Image | |
| import albumentations as A | |
| import albumentations.pytorch as al_pytorch | |
| import torchvision | |
| from pl_bolts.models.gans import Pix2Pix | |
| from pl_bolts.models.gans.pix2pix.components import PatchGAN | |
| import torchvision.models as models | |
| """ Class """ | |
class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix with extra validation logging.

    Every validation batch logs the PatchGAN and generator losses. At the end
    of each validation epoch, a grid of (sketch, target colour image,
    generated image) taken from the first validation batch is written to the
    experiment logger.
    """

    def validation_step(self, batch, batch_idx):
        """Log validation losses and return the batch for epoch-end logging.

        ``batch`` is unpacked as ``(real, condition)``. ``condition`` is the
        generator input (the sketch) and ``real`` is the target (the coloured
        image).
        """
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("val_PatchGAN_loss", disc_loss)
            gen_loss = self._gen_step(real, condition)
            self.log("val_generator_loss", gen_loss)
        # ``condition`` is what the generator consumes, so it is the sketch.
        # Returning it under "colour" made the epoch-end grid run the
        # generator on the coloured target instead.
        return {"sketch": condition, "colour": real}

    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        """Log an image grid built from the first validation batch."""
        if not outputs:
            # No validation batches ran, so there is nothing to visualise.
            return
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
        # Only the first sample of the batch is shown: input, target, output.
        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        self.logger.experiment.add_image(
            f"Image Grid {str(self.current_epoch)}", grid_image, self.current_epoch
        )
class PatchGanChanged(OverpoweredPix2Pix):
    """Pix2Pix variant whose discriminator returns a single score per sample.

    After the normal construction, the PatchGAN's ``final`` layer is extended
    with a flatten and a ``Linear(16 * 16, 1)`` layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels=in_channels, out_channels=out_channels)
        self.patch_gan = self.get_dense_PatchGAN(self.patch_gan)

    @staticmethod
    def get_dense_PatchGAN(disc: PatchGAN) -> PatchGAN:
        """Append a flatten + linear head to ``disc.final`` and return ``disc``.

        ``disc`` is modified in place. ``Linear(16 * 16, 1)`` requires the
        original final layer to produce exactly 256 values per sample
        (presumably a 1x16x16 patch map -- confirm for the input size used).

        This is a ``@staticmethod`` because it takes no ``self``. As a plain
        method, the call ``self.get_dense_PatchGAN(self.patch_gan)`` would pass
        two positional arguments and raise ``TypeError``.
        """
        disc.final = torch.nn.Sequential(
            disc.final,
            torch.nn.Flatten(),
            torch.nn.Linear(16 * 16, 1),
        )
        return disc
| """ Load the model """ | |
# Checkpoint files for the three models built from the classes defined above.
train_64_val_16_plbolts_model_chkpt = (
    "model/lightning_bolts_model/epoch=99-step=44600.ckpt"
)
train_16_val_1_plbolts_model_chkpt = (
    "model/lightning_bolts_model/epoch=99-step=89000.ckpt"
)
modified_patchgan_chkpt = "model/lightning_bolts_model/modified_patchgan.ckpt"

# Restore each model and put it into inference mode. ``Module.eval()`` returns
# the module itself, so loading and switching mode can be chained.
train_64_val_16_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_64_val_16_plbolts_model_chkpt
).eval()
train_16_val_1_plbolts_model = OverpoweredPix2Pix.load_from_checkpoint(
    train_16_val_1_plbolts_model_chkpt
).eval()
modified_patchgan_model = PatchGanChanged.load_from_checkpoint(
    modified_patchgan_chkpt
).eval()
# Create new class
# NOTE(review): this rebinds the module-level name ``OverpoweredPix2Pix`` and
# shadows the class of the same name defined earlier in this file. Models
# loaded before this point keep the earlier class; code after it (e.g. the
# cosine-similarity checkpoint load) gets this one.
class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix whose generator loss adds a feature-similarity term.

    A ResNet-50 backbone (classification head removed) embeds real and
    generated images. The batch-mean cosine similarity of the embeddings adds
    ``lambda_recon * (1 - similarity)`` to the generator loss and is also
    logged during validation.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels=in_channels, out_channels=out_channels)
        self._create_inception_score()

    def _feature_similarity(self, images_a, images_b):
        """Return the batch-mean cosine similarity of the two batches' features."""
        features_a = self.feature_extractor(images_a).flatten(1)
        features_b = self.feature_extractor(images_b).flatten(1)
        return self.cosine_similarity(features_a, features_b).mean()

    def _gen_step(self, real_images, conditioned_images):
        """Return the generator loss for one batch.

        loss = adversarial + lambda_recon * reconstruction
               + lambda_recon * (1 - cosine similarity)
        """
        fake_images = self.gen(conditioned_images)
        # Adversarial term: the generator wants the PatchGAN to answer "real".
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(
            disc_logits, torch.ones_like(disc_logits)
        )
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon
        cosine_loss = 1 - self._feature_similarity(real_images, fake_images)
        # Log a detached tensor. The previous `.cpu().numpy()` conversion forced
        # a device-to-host copy on every step just for logging.
        self.log("Gen Cosine Sim Loss ", cosine_loss.detach())
        return (
            adversarial_loss
            + lambda_recon * recon_loss
            + lambda_recon * cosine_loss
        )

    def _create_inception_score(self):
        """Create ``self.feature_extractor`` and ``self.cosine_similarity``.

        Despite the name, no Inception score is computed. The feature
        extractor is an ImageNet-pretrained ResNet-50 without its last child
        module (the ``fc`` classifier). The similarity is computed row-wise
        (``dim=1``).
        """
        # NOTE(review): pretrained=True fetches ImageNet weights (downloading if
        # not cached) whenever the class is constructed -- presumably also
        # inside load_from_checkpoint, where they are then overwritten.
        backbone = models.resnet50(pretrained=True)
        layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*layers)
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def validation_step(self, batch, batch_idx):
        """Log validation losses and feature similarity for one batch.

        ``batch`` is ``(real, condition)``: ``condition`` is the generator
        input (the sketch) and ``real`` is the target colour image.
        """
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("Valid PatchGAN Loss", disc_loss)
            gan_loss = self._gen_step(real, condition)
            self.log("Valid Generator Loss", gan_loss)
            fake_images = self.gen(condition)
            cosine_sim = self._feature_similarity(real, fake_images)
            self.log("Valid Cosine Sim", cosine_sim)
        return {"sketch": condition, "colour": real}

    def validation_epoch_end(
        self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
    ) -> None:
        """Log an image grid from the first validation batch, tagged with the
        feature similarity between generated and target images."""
        if not outputs:
            # No validation batches ran, so there is nothing to visualise.
            return
        sketch = outputs[0]["sketch"]
        colour = outputs[0]["colour"]
        self.feature_extractor.eval()
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
            # Compare generated images with the real colour targets.
            similarity_score = self._feature_similarity(gen_coloured, colour)
        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        # Use the plain float so the tag does not contain a tensor repr.
        self.logger.experiment.add_image(
            f"Image Grid {self.current_epoch} __ {similarity_score.item()} ",
            grid_image,
            self.current_epoch,
        )
# Model trained with the extra cosine-similarity generator loss. It is loaded
# through whatever ``OverpoweredPix2Pix`` is bound to at this point in the module.
cosine_sim_model_chkpt = "model/lightning_bolts_model/cosine_sim_model.ckpt"
cosine_sim_model = OverpoweredPix2Pix.load_from_checkpoint(
    cosine_sim_model_chkpt
).eval()
def predict(img: Image.Image, type_of_model: str) -> str:
    """Colour a sketch with the selected Pix2Pix generator.

    Args:
        img: Input sketch as a PIL image. Images in another mode (e.g. RGBA or
            greyscale PNGs) are converted to RGB first, because the
            normalisation below uses three per-channel statistics.
        type_of_model: Model label. "train batch size 16, val batch size 1",
            "train batch size 64, val batch size 16" and "cosine similarity"
            select those models; any other value selects the modified-PatchGAN
            model.

    Returns:
        Path of a freshly written PNG file containing the generated image.

    Raises:
        ValueError: If ``img`` is None (nothing was uploaded).
    """
    import tempfile

    if img is None:
        raise ValueError("Please upload a sketch before colouring it.")
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    image = np.asarray(img)

    # Resize to the training resolution and scale pixels to [-1, 1].
    inference_transform = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.Normalize(
                mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0
            ),
            al_pytorch.ToTensorV2(),
        ]
    )
    inference_img = inference_transform(image=image)["image"].unsqueeze(0)

    models_by_label = {
        "train batch size 16, val batch size 1": train_16_val_1_plbolts_model,
        "train batch size 64, val batch size 16": train_64_val_16_plbolts_model,
        "cosine similarity": cosine_sim_model,
    }
    model = models_by_label.get(type_of_model, modified_patchgan_model)

    with torch.no_grad():
        result = model.gen(inference_img)

    # Write to a unique file per request. A fixed filename lets concurrent
    # requests overwrite each other's output before it is served.
    # NOTE(review): these temp files are not cleaned up by this function.
    fd, output_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    torchvision.utils.save_image(result, output_path, normalize=True)
    return output_path
def predict1(img: Image):
    """Colour ``img`` with the model trained at batch size 16 (validation 1)."""
    label = "train batch size 16, val batch size 1"
    return predict(img, label)


def predict2(img: Image):
    """Colour ``img`` with the model trained at batch size 64 (validation 16)."""
    label = "train batch size 64, val batch size 16"
    return predict(img, label)


def predict3(img: Image):
    """Colour ``img`` with the single-score PatchGAN model.

    ``predict`` has no explicit branch for this label; it selects the
    modified-PatchGAN model through its fallback.
    """
    label = "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16"
    return predict(img, label)


def predict4(img: Image):
    """Colour ``img`` with the model trained with the cosine-similarity loss."""
    label = "cosine similarity"
    return predict(img, label)
# Model labels understood by ``predict``. "cosine similarity" was missing even
# though ``predict`` handles it. The third label has no explicit branch in
# ``predict`` and reaches the modified-PatchGAN model through its fallback.
# NOTE(review): ``model_input`` and ``image_input`` are not referenced by the
# Blocks UI below; they are kept in case other code imports them.
model_input = gr.inputs.Radio(
    [
        "train batch size 16, val batch size 1",
        "train batch size 64, val batch size 16",
        "train batch size 64, val batch size 16, patch gan has 1 output score instead of 16*16",
        "cosine similarity",
    ],
    label="Type of Pix2Pix model to use : ",
)
image_input = gr.inputs.Image(type="pil")
# Example sketches, shown in every model tab of the UI.
img_examples = [
    "examples/thesis_test.png",
    "examples/thesis_test2.png",
    "examples/thesis1.png",
    "examples/thesis4.png",
    "examples/thesis5.png",
    "examples/thesis6.png",
]
def _build_colour_tab(tab_label, predict_fn):
    """Render one model tab inside the currently open ``gr.Blocks`` context.

    The tab contains an input/output image pair, a "Colour it!" button and an
    examples gallery that is given ``predict_fn`` as its ``fn``. Returns
    ``(image_input, image_output, colour_button)`` so the caller can wire the
    button's click event.
    """
    with gr.TabItem(tab_label):
        with gr.Row():
            tab_input = gr.inputs.Image(type="pil")
            tab_output = gr.outputs.Image(type="pil")
        colour_button = gr.Button("Colour it!")
        with gr.Row():
            gr.Examples(
                examples=img_examples,
                inputs=tab_input,
                outputs=tab_output,
                fn=predict_fn,
            )
    return tab_input, tab_output, colour_button


with gr.Blocks() as demo:
    gr.Markdown(" # Colour your sketches!")
    gr.Markdown(" ## Description :")
    gr.Markdown(" There are 4 Pix2Pix models in this example:")
    gr.Markdown(" 1. Training batch size is 16, validation is 1")
    gr.Markdown(" 2. Training batch size is 64, validation is 16")
    # The two literals are concatenated, so the separator needs its own space.
    gr.Markdown(
        " 3. PatchGAN is changed, 1 value only instead of 16*16; "
        "training batch size is 64, validation is 16"
    )
    gr.Markdown(
        " 4. Cosine similarity is also added to the generator loss in this experiment. "
    )
    with gr.Tabs():
        image_input1, image_output1, colour_1 = _build_colour_tab(
            "tr_16_val_1", predict1
        )
        # Validation batch size for this model is 16 (label previously said 14).
        image_input2, image_output2, colour_2 = _build_colour_tab(
            "tr_64_val_16", predict2
        )
        image_input3, image_output3, colour_3 = _build_colour_tab(
            "Single Value Discriminator", predict3
        )
        image_input4, image_output4, colour_4 = _build_colour_tab(
            "Cosine similarity loss", predict4
        )
    # Wire the buttons after all tabs exist, in tab order, as before.
    colour_1.click(fn=predict1, inputs=image_input1, outputs=image_output1)
    colour_2.click(fn=predict2, inputs=image_input2, outputs=image_output2)
    colour_3.click(fn=predict3, inputs=image_input3, outputs=image_output3)
    colour_4.click(fn=predict4, inputs=image_input4, outputs=image_output4)

demo.title = "Colour your sketches!"
demo.launch()