import os

import gradio as gr
import numpy as np
import torch
from gradio_imageslider import ImageSlider
from PIL import Image
from torchvision import transforms

from utils.utils import load_embedder_ckpt, load_restore_ckpt

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embedder_model_path = "ckpts/embedder_model.tar"  # Update with actual path to embedder checkpoint
restorer_model_path = "ckpts/onerestore_cdd-11.tar"  # Update with actual path to restorer checkpoint

# Load both models onto `device`. `freeze_model=True` presumably puts them in
# inference-only mode -- see utils.utils for the exact semantics.
embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=embedder_model_path)
restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=restorer_model_path)

# Embedder input: a fixed 224x224 copy of the image as a float tensor in [0, 1]
# (ToTensor maps uint8 pixels to [0, 1]).
transform_resize = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
])


def postprocess_image(tensor):
    """Convert a restorer output tensor into an 8-bit PIL image.

    Args:
        tensor: Output of the restorer with a leading batch dimension of 1,
            laid out as (1, C, H, W). Values are presumably in [0, 1], the same
            range the input is scaled to in ``enhance_image`` -- TODO confirm
            against the restorer.

    Returns:
        A ``PIL.Image`` in (H, W, C) layout. Values outside [0, 1] are clipped.
    """
    image = tensor.squeeze(0).detach().cpu().numpy()
    # Scale [0, 1] -> [0, 255] and clip any overshoot before the uint8 cast.
    # The cast truncates toward zero rather than rounding.
    image = np.clip(image * 255, 0, 255).astype("uint8")
    return Image.fromarray(image.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)


def enhance_image(image, degradation_type=None):
    """Restore a degraded image with OneRestore.

    Args:
        image: Input ``PIL.Image``. Any mode (RGBA, L, P, ...) is converted to
            RGB first, because the pipeline below builds an (H, W, 3) array
            and the models presumably expect 3 channels.
        degradation_type: ``"auto"`` or ``None`` lets the embedder estimate the
            degradation from the image. Any other value is passed as a
            degradation name (e.g. ``"low_haze_rain"``) to the text encoder.

    Returns:
        ``((input_image, restored_image), text)``. The pair feeds the
        before/after slider, and ``text`` is the label returned by the embedder.

    Raises:
        gr.Error: If no image was supplied (e.g. the input was cleared).
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Normalise the mode. Without this, grayscale input breaks the transpose
    # below and RGBA input would yield a 4-channel tensor.
    image = image.convert("RGB")

    # Full-resolution restorer input: (1, 3, H, W) float32 in [0, 1].
    pixels = np.asarray(image).transpose(2, 0, 1) / 255
    input_tensor = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0).to(device)

    # Inference only: no autograd graph for either model.
    with torch.no_grad():
        if degradation_type in (None, "auto"):
            # Estimate the degradation from a resized copy of the image.
            lq_em = transform_resize(image).unsqueeze(0).to(device)
            text_embedding, _, [text] = embedder(lq_em, "image_encoder")
        else:
            # Embed the user-selected degradation label instead.
            text_embedding, _, [text] = embedder([degradation_type], "text_encoder")
        enhanced_tensor = restorer(input_tensor, text_embedding)

    return (image, postprocess_image(enhanced_tensor)), text
# Define the Gradio interface
def inference(image, degradation_type=None):
    """Gradio callback: forward the UI inputs to ``enhance_image``.

    Returns ``((input_image, restored_image), text)``. The pair feeds the
    ImageSlider output and ``text`` feeds the Textbox.
    """
    return enhance_image(image, degradation_type)


# Choices offered in the dropdown; "auto" lets the embedder estimate the type.
DEGRADATION_TYPES = [
    "auto", "low", "haze", "rain", "snow",
    "low_haze", "low_rain", "low_snow", "haze_rain",
    "haze_snow", "low_haze_rain", "low_haze_snow",
]

# Example rows: one [image path, degradation type] pair per row, matching the
# two interface inputs.
examples = [
    ["image/low_haze_rain_00469_01_lq.png", "auto"],
    ["image/low_haze_snow_00337_01_lq.png", "auto"],
]

# Create the Gradio app interface
interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", value="image/low_haze_rain_00469_01_lq.png"),  # Image input
        gr.Dropdown(DEGRADATION_TYPES, label="Degradation Type", value="auto"),  # Manual or auto degradation
    ],
    outputs=[
        ImageSlider(label="Restored Image", type="pil", show_download_button=True),  # Input vs. restored image
        gr.Textbox(label="Degradation Type"),  # Degradation label returned by the embedder
    ],
    title="Image Restoration with OneRestore",
    description="Upload an image and enhance it using OneRestore model. You can choose to let the model automatically estimate the degradation type or set it manually.",
    examples=examples,
)

# Launch the app
if __name__ == "__main__":
    interface.launch()