import pickle
import json
import os

import gradio as gr
from PIL import Image
import numpy as np
import jax
from gradio_dualvision import DualVisionApp
from gradio_dualvision.gradio_patches.radio import Radio
from huggingface_hub import hf_hub_download

from model import build_thera
from super_resolve import process

REPO_ID_EDSR = "prs-eth/thera-edsr-pro"
REPO_ID_RDN = "prs-eth/thera-rdn-pro"
MAX_SIZE = int(os.getenv('THERA_DEMO_CROP', 10_000))

print(f"JAX devices: {jax.devices()}")
print(f"JAX device type: {jax.devices()[0].device_kind}")

# Load the pretrained EDSR-backbone checkpoint and build its model.
model_path = hf_hub_download(repo_id=REPO_ID_EDSR, filename="model.pkl")
with open(model_path, 'rb') as fh:
    check = pickle.load(fh)
params_edsr, backbone, size = check['model'], check['backbone'], check['size']
model_edsr = build_thera(3, backbone, size)

# Load the pretrained RDN-backbone checkpoint and build its model.
model_path = hf_hub_download(repo_id=REPO_ID_RDN, filename="model.pkl")
with open(model_path, 'rb') as fh:
    check = pickle.load(fh)
params_rdn, backbone, size = check['model'], check['backbone'], check['size']
model_rdn = build_thera(3, backbone, size)


class TheraApp(DualVisionApp):
    DEFAULT_SCALE = 3.92
    DEFAULT_DO_ENSEMBLE = False
    DEFAULT_MODEL = 'edsr'

    def make_header(self):
        gr.Markdown(
            """
## Thera: Aliasing-Free Arbitrary-Scale Super-Resolution with Neural Heat Fields


Upload a photo or select an example below to do arbitrary-scale super-resolution in real time!

Note: The model has not been trained on inputs with JPEG compression artifacts, so it may perform poorly on such images.

Also note: Due to the limited viewport size in the browser, the effect is most visible on smaller inputs (e.g., 150×150 px).
For larger inputs, zoom in or download the result and compare locally. We're working on a better solution for visualization.

""" ) def build_user_components(self): with gr.Row(): scale = gr.Slider( label="Scaling factor", minimum=1, maximum=6, step=0.01, value=self.DEFAULT_SCALE, ) model = gr.Radio( [ ("EDSR", 'edsr'), ("RDN", 'rdn'), ], label="Backbone", value=self.DEFAULT_MODEL, ) do_ensemble = gr.Radio( [ ("No", False), ("Yes", True), ], label="Do Ensemble", value=self.DEFAULT_DO_ENSEMBLE, ) return { "scale": scale, "model": model, "do_ensemble": do_ensemble, } def process(self, image_in: Image.Image, **kwargs): scale = kwargs.get("scale", self.DEFAULT_SCALE) do_ensemble = kwargs.get("do_ensemble", self.DEFAULT_DO_ENSEMBLE) model = kwargs.get("model", self.DEFAULT_MODEL) if max(*image_in.size) > MAX_SIZE: gr.Warning(f"The image has been cropped for better visibility, and to enable a smooth experience for all users.") width, height = image_in.size crop_width = min(width, MAX_SIZE) crop_height = min(height, MAX_SIZE) left = (width - crop_width) / 2 top = (height - crop_height) / 2 right = left + crop_width bottom = top + crop_height image_in = image_in.crop((left, top, right, bottom)) source = np.asarray(image_in) / 255. # determine target shape target_shape = ( round(source.shape[0] * scale), round(source.shape[1] * scale), ) if model == 'edsr': m, p = model_edsr, params_edsr elif model == 'rdn': m, p = model_rdn, params_rdn else: raise NotImplementedError('model:', model) out = process(source, m, p, target_shape, do_ensemble=do_ensemble) out = Image.fromarray(np.asarray(out)) nearest = image_in.resize(out.size, Image.NEAREST) out_modalities = { "nearest": nearest, "out": out, } out_settings = { 'scale': scale, 'model': model, 'do_ensemble': do_ensemble, } return out_modalities, out_settings def process_components( self, image_in, modality_selector_left, modality_selector_right, **kwargs ): if image_in is None: raise gr.Error("Input image is required") image_settings = {} if isinstance(image_in, str): image_settings_path = image_in + ".settings.json" if os.path.isfile(image_settings_path): with open(image_settings_path, "r") as f: image_settings = json.load(f) image_in = Image.open(image_in).convert("RGB") else: if not isinstance(image_in, Image.Image): raise gr.Error(f"Input must be a PIL image, got {type(image_in)}") image_in = image_in.convert("RGB") image_settings.update(kwargs) results_dict, results_settings = self.process(image_in, **image_settings) if not isinstance(results_dict, dict): raise gr.Error( f"`process` must return a dict[str, PIL.Image]. Got type: {type(results_dict)}" ) if len(results_dict) == 0: raise gr.Error("`process` did not return any modalities") for k, v in results_dict.items(): if not isinstance(k, str): raise gr.Error( f"Output dict must have string keys. 
Found key of type {type(k)}: {repr(k)}" ) if k == self.key_original_image: raise gr.Error( f"Output dict must not have an '{self.key_original_image}' key; it is reserved for the input" ) if not isinstance(v, Image.Image): raise gr.Error( f"Value for key '{k}' must be a PIL Image, got type {type(v)}" ) if len(results_settings) != len(self.input_keys): raise gr.Error( f"Expected number of settings ({len(self.input_keys)}), returned ({len(results_settings)})" ) if any(k not in results_settings for k in self.input_keys): raise gr.Error(f"Mismatching setgings keys") results_settings = { k: cls(**ctor_args, value=results_settings[k]) for k, cls, ctor_args in zip( self.input_keys, self.input_cls, self.input_kwargs ) } results_dict = { **results_dict, self.key_original_image: image_in, } results_state = [[v, k] for k, v in results_dict.items()] modalities = list(results_dict.keys()) modality_left = ( modality_selector_left if modality_selector_left in modalities else modalities[0] ) modality_right = ( modality_selector_right if modality_selector_right in modalities else modalities[1] ) return [ results_state, # goes to a gr.Gallery [ results_dict[modality_left], results_dict[modality_right], ], # ImageSliderPlus Radio( choices=modalities, value=modality_left, label="Left", key="Left", ), Radio( choices=modalities if self.left_selector_visible else modalities[1:], value=modality_right, label="Right", key="Right", ), *results_settings.values(), ] with TheraApp( title="Thera Arbitrary-Scale Super-Resolution", examples_path="files", examples_per_page=12, squeeze_canvas=True, advanced_settings_can_be_half_width=False, #spaces_zero_gpu_enabled=True, ) as demo: demo.queue( api_open=False, ).launch( server_name="0.0.0.0", server_port=7860, )
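
# Example of the per-image settings sidecar consumed by `process_components`
# above. This is a sketch, not a file from this repo: the actual sidecars under
# `files/` (if any) may use different values. For a gallery image
# `files/example.png`, a sidecar named `files/example.png.settings.json` would
# pre-set the UI controls, using the same keys as `build_user_components`:
#
#     {
#         "scale": 4.0,
#         "model": "rdn",
#         "do_ensemble": true
#     }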
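
# Standalone usage sketch (kept as a comment so nothing executes after
# `launch()` returns): the same pipeline can be driven without Gradio,
# assuming `process` from super_resolve.py behaves as in `TheraApp.process`
# above. The input/output paths and the scale value are hypothetical.
#
#     import numpy as np
#     from PIL import Image
#
#     img = Image.open("input.png").convert("RGB")
#     source = np.asarray(img) / 255.          # HxWx3 floats in [0, 1]
#     scale = 4.0
#     target_shape = (
#         round(source.shape[0] * scale),
#         round(source.shape[1] * scale),
#     )
#     out = process(source, model_edsr, params_edsr, target_shape,
#                   do_ensemble=False)
#     Image.fromarray(np.asarray(out)).save("output.png")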