Spaces:
Runtime error
Runtime error
| import os | |
| import glob | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from tqdm.notebook import tqdm | |
| import matplotlib.pyplot as plt | |
| from skimage.color import rgb2lab, lab2rgb | |
| import torch | |
| from torch import nn, optim | |
| from torchvision import transforms | |
| from torchvision.utils import make_grid | |
| from torch.utils.data import Dataset, DataLoader | |
# Compute device: CUDA when PyTorch reports it is available, otherwise the CPU.
# NOTE(review): not referenced by any function visible here — confirm it is used elsewhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): `requests` is not used by any function visible here.
import requests
import gdown
# Side length in pixels that input images are resized to (SIZE x SIZE)
# before the RGB -> L*a*b* conversion in create_lab_tensors.
SIZE = 256
def download_from_drive(url, output):
    """Download a file (e.g. model weights) from Google Drive using gdown.

    Args:
        url: Google Drive URL passed straight to ``gdown.download``.
        output: destination path for the downloaded file.

    Returns:
        True if the download completed, False otherwise. Failures are
        reported on stdout instead of being raised, so callers can decide
        how to fall back.
    """
    try:
        result = gdown.download(url, output, quiet=False)
    except Exception as err:
        # Catch Exception rather than using a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit still propagate to the caller.
        print("Error Occured in Downloading model from Gdrive")
        print(f"  reason: {err!r}")
        return False
    # gdown.download returns the output path on success; some gdown versions
    # report a failed download (e.g. access denied) by returning None instead
    # of raising, which previously was reported here as success.
    if result is None:
        print("Error Occured in Downloading model from Gdrive")
        return False
    return True
class AverageMeter:
    """Track a running weighted mean of a scalar, such as a per-batch loss.

    Attributes:
        count: total weight accumulated so far.
        avg: ``sum / count`` as of the most recent update.
        sum: weighted sum of every value seen so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated weight, running sum and average."""
        self.count = 0.0
        self.avg = 0.0
        self.sum = 0.0

    def update(self, val, count=1):
        """Record ``val`` with weight ``count`` and refresh the average.

        Args:
            val: the new observation.
            count: its weight (for example the number of samples it covers).
        """
        self.count = self.count + count
        self.sum = self.sum + val * count
        self.avg = self.sum / self.count
def create_loss_meters():
    """Build a fresh, independent AverageMeter for every GAN loss term.

    Returns:
        dict mapping each loss name to its own AverageMeter, ordered as
        discriminator fake / real / total, then generator GAN / L1 / total.
        update_losses uses these keys as attribute names to read from the model.
    """
    loss_names = (
        "loss_D_fake",
        "loss_D_real",
        "loss_D",
        "loss_G_GAN",
        "loss_G_L1",
        "loss_G",
    )
    return {name: AverageMeter() for name in loss_names}
def update_losses(model, loss_meter_dict, count):
    """Feed the model's current loss values into their running meters.

    For each ``name -> meter`` entry, reads ``model.<name>``, converts it with
    ``.item()`` (so presumably a scalar tensor) and records it on the meter
    with weight ``count``.

    Args:
        model: object exposing one attribute per loss name.
        loss_meter_dict: mapping of loss name to a meter with an ``update`` method.
        count: weight passed to every meter update (typically the batch size).
    """
    for name in loss_meter_dict:
        value = getattr(model, name).item()
        loss_meter_dict[name].update(value, count=count)
def lab_to_rgb(L, ab):
    """Convert a batch of normalized Lab tensors back into RGB images.

    Undoes the normalization used for model inputs: L is mapped from [-1, 1]
    back to [0, 100] via ``(L + 1) * 50`` and ab is rescaled by 110. The two
    are concatenated along dim 1 and moved channel-last before conversion.

    Args:
        L: lightness tensor, presumably (N, 1, H, W) — TODO confirm.
        ab: chroma tensor, presumably (N, 2, H, W) — TODO confirm.

    Returns:
        NumPy array of shape (N, H, W, 3) holding the per-image output of
        ``skimage.color.lab2rgb``.
    """
    lab_batch = torch.cat([(L + 1.0) * 50.0, ab * 110.0], dim=1)
    # channels-first tensor -> channels-last NumPy array for skimage
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)
def visualize(model, data, save=True, n_images=5):
    """Plot grayscale inputs, predicted colorizations and ground truth.

    Runs one forward pass of ``model`` on ``data`` with the generator in eval
    mode and no gradients, then draws a 3-row grid with one column per sample:
    the L channel (row 1), RGB rebuilt from L + predicted ab (row 2), and RGB
    rebuilt from L + real ab (row 3).

    Args:
        model: colorization model exposing ``net_G``, ``setup_input`` and
            ``forward``; after ``forward`` it must provide ``L``, ``ab`` and
            ``fake_color``.
        data: batch passed unchanged to ``model.setup_input``.
        save: when True, also write the figure to
            ``colorization_<timestamp>.png`` in the working directory.
        n_images: maximum number of samples (columns) to draw. Capped at the
            batch size, so batches smaller than 5 no longer raise IndexError.
    """
    model.net_G.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        # Restore training mode even if the forward pass raises, so a failed
        # visualization never leaves the generator stuck in eval mode.
        model.net_G.train()

    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    n_cols = min(n_images, len(fake_imgs))
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_cols):
        ax = fig.add_subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap="gray")
        ax.axis("off")
        ax = fig.add_subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = fig.add_subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    if save:
        # Save before show(): depending on the backend, show() may block or
        # close the figure, so saving first guarantees the file is written.
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()
def log_results(loss_meter_dict):
    """Print every meter's running average as ``<name>: <avg>`` (5 decimals).

    Args:
        loss_meter_dict: mapping of loss name to an object with an ``avg`` attribute.
    """
    lines = (f"{name}: {meter.avg:.5f}" for name, meter in loss_meter_dict.items())
    for line in lines:
        print(line)
def create_lab_tensors(image, augment=False):
    """Convert one image into the normalized L and ab tensors used by the model.

    The image is converted to RGB, resized to SIZE x SIZE with bicubic
    interpolation, converted to L*a*b* with ``skimage.color.rgb2lab`` and
    split into a lightness tensor and a chroma tensor.

    Args:
        image: a path to an image file (``str`` or ``os.PathLike``), a NumPy
            array accepted by ``PIL.Image.fromarray``, or an object with a
            PIL-style ``convert`` method (e.g. a ``PIL.Image.Image``).
        augment: when True, apply a random horizontal flip. The flip used to
            be applied unconditionally, which made single-image inference
            non-deterministic and could return a mirrored colorization; it is
            now opt-in.

    Returns:
        dict with keys
          - ``"L"``: float tensor of shape (1, 1, SIZE, SIZE) holding
            L* / 50 - 1, i.e. values in [-1, 1].
          - ``"ab"``: float tensor of shape (2, SIZE, SIZE) holding a*, b*
            divided by 110 (roughly [-1, 1]).
        NOTE(review): only "L" carries a leading batch dimension; anything
        that concatenates L and ab along dim 1 needs ``ab.unsqueeze(0)``.
        Left unchanged to keep the returned shapes backward compatible.
    """
    if isinstance(image, (str, os.PathLike)):
        # The context manager releases the file handle promptly; convert()
        # returns a fully loaded copy that stays valid after the file closes.
        with Image.open(image) as source:
            img = source.convert("RGB")
    else:
        pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        img = pil_img.convert("RGB")

    steps = [
        transforms.Resize(
            (SIZE, SIZE), interpolation=transforms.InterpolationMode.BICUBIC
        ),
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    img = transforms.Compose(steps)(img)

    img_lab = rgb2lab(np.array(img)).astype("float32")  # RGB -> L*a*b*
    # ToTensor turns the HxWxC float32 array into a CxHxW tensor; float input
    # is not rescaled, so the Lab values are preserved.
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50.0 - 1.0  # L* in [0, 100] -> [-1, 1]
    L = L.unsqueeze(0)  # add batch dimension -> (1, 1, SIZE, SIZE)
    ab = img_lab[[1, 2], ...] / 110.0  # a*, b* -> roughly [-1, 1]
    return {"L": L, "ab": ab}
def predict_and_visualize_single_image(model, data, save=True):
    """Colorize with ``model`` and show the grayscale input next to the result.

    Switches ``model.net_G`` to eval mode and leaves it there. Runs a forward
    pass without gradients, then plots the first sample's L channel beside
    the RGB image rebuilt from L plus the predicted ab channels.

    Args:
        model: colorization model exposing ``net_G``, ``setup_input`` and
            ``forward``; after ``forward`` it must provide ``L`` and
            ``fake_color``.
        data: input passed unchanged to ``model.setup_input``.
        save: when True, also write the figure to
            ``colorization_<timestamp>.png`` after it has been shown.
    """
    generator = model.net_G
    generator.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()

    predicted_ab = model.fake_color.detach()
    gray = model.L
    colored = lab_to_rgb(gray, predicted_ab)

    fig, (gray_ax, color_ax) = plt.subplots(1, 2, figsize=(8, 4))
    panels = (
        (gray_ax, lambda: gray[0][0].cpu(), {"cmap": "gray"}, "Grey Image"),
        (color_ax, lambda: colored[0], {}, "Colored Image"),
    )
    for ax, picture, imshow_kwargs, title in panels:
        ax.imshow(picture(), **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
def predict_color(model, image, save=False):
    """Colorize a single image with ``model`` and display the result.

    Converts ``image`` into normalized L/ab tensors with create_lab_tensors,
    then runs and plots the prediction with
    predict_and_visualize_single_image. Nothing is returned.

    Args:
        model: PyTorch grayscale-to-color model.
        image: path to an image file, or an in-memory image (NumPy array or
            PIL image).
        save: when True, the resulting figure is also written to disk.
    """
    predict_and_visualize_single_image(model, create_lab_tensors(image), save)
def load_model_with_cpu(model_class, file_path):
    """Instantiate ``model_class`` and restore its weights from ``file_path``.

    Every tensor in the checkpoint is mapped onto the CPU while loading, so
    weights saved on a GPU can be restored on a CPU-only machine.

    Args:
        model_class: zero-argument callable (typically an ``nn.Module``
            subclass) that builds the model.
        file_path: path to a checkpoint containing a ``state_dict``.

    Returns:
        The freshly built model with the loaded weights.
    """
    model = model_class()
    # NOTE(security): depending on the PyTorch version, torch.load may fully
    # unpickle the file; only load checkpoints from sources you trust.
    cpu_device = torch.device("cpu")
    state_dict = torch.load(file_path, map_location=cpu_device)
    model.load_state_dict(state_dict)
    return model
def load_model_with_gpu(model_class, file_path):
    """Instantiate ``model_class`` and restore its weights from ``file_path``.

    Unlike load_model_with_cpu, no ``map_location`` is given, so tensors are
    deserialized onto the device they were saved from.

    NOTE(review): this function does not move the model to a GPU itself.
    ``load_state_dict`` copies the values into whatever device the freshly
    built model's parameters are on. Confirm that model_class handles
    device placement.

    Args:
        model_class: zero-argument callable (typically an ``nn.Module``
            subclass) that builds the model.
        file_path: path to a checkpoint containing a ``state_dict``.

    Returns:
        The freshly built model with the loaded weights.
    """
    model = model_class()
    # NOTE(security): depending on the PyTorch version, torch.load may fully
    # unpickle the file; only load checkpoints from sources you trust.
    checkpoint = torch.load(file_path)
    model.load_state_dict(checkpoint)
    return model