import gradio as gr
import torch
from torchvision import transforms
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import os
from typing import Tuple


def add_reflection_padding(image: Image.Image, padding: int) -> Tuple[Image.Image, Tuple[int, int]]:
    """
    Adds reflection padding to the input PIL image.

    Parameters:
        image (PIL.Image.Image): The input image.
        padding (int): The number of pixels for reflection padding.

    Returns:
        PIL.Image.Image: The padded image.
        tuple: Original image size (width, height).
    """
    original_size = image.size
    width, height = image.size
    padded_image = Image.new('RGB', (width + 2 * padding, height + 2 * padding))
    padded_image.paste(image, (padding, padding))

    # Reflect each edge strip outward. The four corner squares are left
    # black; they are discarded by the center crop after inference.
    left_padding = image.crop((0, 0, padding, height))
    left_padding = ImageOps.mirror(left_padding)
    padded_image.paste(left_padding, (0, padding))

    right_padding = image.crop((width - padding, 0, width, height))
    right_padding = ImageOps.mirror(right_padding)
    padded_image.paste(right_padding, (width + padding, padding))

    top_padding = image.crop((0, 0, width, padding))
    top_padding = ImageOps.flip(top_padding)
    padded_image.paste(top_padding, (padding, 0))

    bottom_padding = image.crop((0, height - padding, width, height))
    bottom_padding = ImageOps.flip(bottom_padding)
    padded_image.paste(bottom_padding, (padding, height + padding))

    return padded_image, original_size


def center_crop(image: Image.Image, output_size: Tuple[int, int]) -> Image.Image:
    """
    Center-crops the input image to the specified output size.

    Parameters:
        image (PIL.Image.Image): The input image.
        output_size (tuple): Desired output size (width, height).

    Returns:
        PIL.Image.Image: Center-cropped image.
    """
    original_width, original_height = image.size
    left = (original_width - output_size[0]) // 2
    top = (original_height - output_size[1]) // 2
    right = left + output_size[0]
    bottom = top + output_size[1]
    return image.crop((left, top, right, bottom))


# Load the TorchScript model on CPU
model = torch.jit.load('model.pth', map_location=torch.device('cpu'))
model.eval()


def inference(input_image: Image.Image, amount: float = 1.0, sharpness_factor: float = 1.5) -> Image.Image:
    # Preserve the ICC profile so the output keeps the input's color space
    icc_profile = input_image.info.get('icc_profile')

    # Pad with reflected edges so the CNN does not produce border artifacts
    input_image, org_size = add_reflection_padding(input_image, 4)
    input_image = np.clip(np.array(input_image) / 255.0, 0.0, 1.0).astype('float32')
    input_image = transforms.ToTensor()(input_image)

    with torch.no_grad():
        output_image = model(input_image.unsqueeze(0), amount)

    output_image = output_image.squeeze(0).cpu().permute(1, 2, 0).numpy()
    # Clip before the uint8 cast to avoid wrap-around on out-of-range values
    output_image = Image.fromarray(np.uint8(np.clip(output_image, 0.0, 1.0) * 255.0))
    output_image.info['icc_profile'] = icc_profile

    # Remove the reflection padding, then optionally sharpen
    output_image = center_crop(output_image, org_size)
    enhancer = ImageEnhance.Sharpness(output_image)
    return enhancer.enhance(sharpness_factor)


def read_readme():
    with open('README.md', 'r', encoding='utf-8') as file:
        lines = file.readlines()

    # Keep only the content after the second "---" separator,
    # i.e. skip the configuration front matter at the top of the README
    readme_lines = []
    skip = 0
    for line in lines:
        if line.startswith("---") and skip < 2:
            skip += 1
        elif skip >= 2:
            readme_lines.append(line)
    return ''.join(readme_lines)


example_list = [["examples/" + example] for example in os.listdir("examples")]
readme_content = read_readme()

# Define the Gradio interface
iface = gr.Interface(
    theme='gradio/soft',
    title="Dedither CNN App",
    description="Dedither images bit-reduced to 8, 4 or even 1 bit - check article below examples for more!",
    article=readme_content,
    fn=inference,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(minimum=0.0, maximum=1.0, value=1.0, label="Dedithering Amount",
                  info="Adjust how much of the predicted dithering should be removed"),
        gr.Slider(minimum=1.0, maximum=2.0, value=1.5, label="Sharpness Factor",
                  info="The output image may look slightly blurred - 1.0 leaves the image unchanged, >1.0 sharpens it")
    ],
    outputs=gr.Image(type="pil"),
    examples=example_list
)

# Launch the Gradio interface
iface.launch()
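
# A minimal sketch of running the pipeline without the web UI, e.g. for batch
# processing. 'examples/sample.png' and 'dedithered.png' are hypothetical
# paths, not files shipped with this repo - adjust them to your setup:
#
#     img = Image.open('examples/sample.png').convert('RGB')
#     out = inference(img, amount=1.0, sharpness_factor=1.5)
#     out.save('dedithered.png', icc_profile=out.info.get('icc_profile'))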