# Spaces:
# Sleeping
# Sleeping
import hashlib
import io
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image

from siren import SIREN
from utils import (
    get_image_coordinates,
    image_to_tensor,
    tensor_to_image,
    downsample_image,
    train_siren,
    compute_psnr,
    compute_mae,
    compute_ssim_simple,
    get_model_cache_path,
    save_model,
    load_model,
)
def _image_digest(image, length=12):
    """Return a short hex SHA-256 fingerprint of a PIL image's mode, size and pixels.

    Used to make model-cache keys content-addressed, so two different images that
    share a name and a size never reuse each other's trained model.
    """
    hasher = hashlib.sha256()
    hasher.update(f"{image.mode}:{image.size[0]}x{image.size[1]}".encode("utf-8"))
    hasher.update(image.tobytes())
    return hasher.hexdigest()[:length]


def _render_loss_plot(losses):
    """Render the training-loss curve as a PIL image.

    Builds an explicit matplotlib ``Figure`` instead of using the pyplot state
    machine: Gradio runs event handlers in worker threads, and pyplot's global
    "current figure" state is not thread-safe. When ``losses`` is empty (the model
    was loaded from cache), a placeholder message is drawn instead of a curve.
    """
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    if len(losses) > 0:
        ax.plot(losses, linewidth=2, color='#2E86AB')
    else:
        ax.text(0.5, 0.5, "Loaded from cache: no training history",
                ha="center", va="center", transform=ax.transAxes)
    ax.set_xlabel('Training Step', fontsize=10)
    ax.set_ylabel('MSE Loss', fontsize=10)
    ax.set_title('Training Loss Curve', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_facecolor('#f8f9fa')

    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=100, facecolor='white')
    buf.seek(0)
    with Image.open(buf) as img:
        # copy() forces a full decode, so the returned image does not depend on buf.
        return img.copy()


def _format_metrics_text(psnr, ssim, mae, training_steps, losses, from_cache):
    """Build the human-readable metrics summary shown in the UI."""
    lines = [
        "",
        "Quality Metrics (vs Ground Truth):",
        f"• PSNR: {psnr:.2f} dB (higher is better, >30 dB is good)",
        f"• SSIM: {ssim:.4f} (closer to 1.0 is better)",
        f"• MAE: {mae:.4f} (lower is better)",
        "",
    ]
    if from_cache:
        lines.append(f"Loaded cached model (trained for {training_steps} steps)")
        lines.append("Final MSE Loss: n/a (no training history for cached models)")
    else:
        lines.append(f"Training completed in {training_steps} steps")
        # Guard against an empty history (e.g. zero training steps).
        if len(losses) > 0:
            lines.append(f"Final MSE Loss: {losses[-1]:.6f}")
        else:
            lines.append("Final MSE Loss: n/a")
    return "\n".join(lines) + "\n"


def super_resolve_image(input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache=True, image_name="uploaded"):
    """Perform super-resolution using SIREN.

    The input image is treated as ground truth. It is downsampled by
    ``scale_factor``, a SIREN is fitted to the low-resolution pixels (or loaded
    from the model cache), and the network is then evaluated on the
    full-resolution coordinate grid to produce the super-resolved image.

    Args:
        input_image: PIL Image (high-res ground truth).
        scale_factor: Upscaling factor (2, 4, or 8).
        training_steps: Number of training steps.
        hidden_features: Number of hidden units.
        hidden_layers: Number of hidden layers.
        use_cache: Whether to load/save trained models from/to the model cache.
        image_name: Human-readable name included in the cache key.

    Returns:
        Tuple ``(downsampled_image, loss_plot, super_resolved_image, gt_image,
        metrics_text)``, in the order the UI outputs expect.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # UI widgets (notably sliders) may deliver floats such as 256.0; layer sizes
    # and step counts must be plain ints.
    scale_factor = int(scale_factor)
    training_steps = int(training_steps)
    hidden_features = int(hidden_features)
    hidden_layers = int(hidden_layers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # The network below always predicts 3 output channels, so normalise
    # RGBA / greyscale / palette inputs to RGB before anything else.
    gt_image = input_image.convert("RGB")
    W_gt, H_gt = gt_image.size

    downsampled_image = downsample_image(gt_image, scale_factor)
    W_low, H_low = downsampled_image.size
    print(f"Ground truth size: {W_gt}x{H_gt}")
    print(f"Downsampled size: {W_low}x{H_low}")
    print(f"Target upscale: {scale_factor}x")

    # Include a content digest in the key: uploads all share the default name
    # "uploaded", so a name+size key would hand back another image's model.
    cache_path = get_model_cache_path(
        f"{image_name}_{_image_digest(gt_image)}_{W_gt}x{H_gt}",
        scale_factor,
        training_steps,
        hidden_features,
        hidden_layers,
    )

    model = SIREN(
        in_features=2,
        hidden_features=hidden_features,
        hidden_layers=hidden_layers,
        out_features=3,
        outermost_linear=True,
        first_omega_0=30,
        hidden_omega_0=30,
    )

    losses = []
    from_cache = False
    if use_cache:
        loaded_model = load_model(model, cache_path)
        if loaded_model is not None:
            model = loaded_model
            from_cache = True
            print("Using cached model!")

    if not from_cache:
        print("Training SIREN model...")
        low_res_pixels = image_to_tensor(downsampled_image)
        low_res_coords = get_image_coordinates(H_low, W_low)
        model, losses = train_siren(
            model=model,
            coords=low_res_coords,
            pixels=low_res_pixels,
            num_steps=training_steps,
            learning_rate=1e-4,
            device=device,
        )
        print("Training complete!")
        if use_cache:
            save_model(model, cache_path)

    # A cached model may have been restored on a different device than the one
    # used for inference; .to() is a no-op when it already matches.
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        high_res_coords = get_image_coordinates(H_gt, W_gt).to(device)
        # Move to CPU once; image conversion and all metrics use the CPU copy.
        super_resolved_pixels = model(high_res_coords).cpu()

    super_resolved_image = tensor_to_image(super_resolved_pixels, H_gt, W_gt)

    gt_pixels = image_to_tensor(gt_image)
    psnr = compute_psnr(super_resolved_pixels, gt_pixels)
    mae = compute_mae(super_resolved_pixels, gt_pixels)
    ssim = compute_ssim_simple(super_resolved_pixels, gt_pixels)
    print("\nQuality Metrics:")
    print(f" PSNR: {psnr:.2f} dB")
    print(f" SSIM: {ssim:.4f}")
    print(f" MAE: {mae:.4f}")

    metrics_text = _format_metrics_text(psnr, ssim, mae, training_steps, losses, from_cache)
    loss_plot = _render_loss_plot(losses)

    # Order: downsampled, loss_plot, super_resolved, gt, metrics (matches UI layout)
    return downsampled_image, loss_plot, super_resolved_image, gt_image, metrics_text
def super_resolve_with_name(input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache):
    """Run ``super_resolve_image``, naming the cache entry after the source file when known.

    PIL stores the path an image was opened from in ``Image.filename`` (an empty
    string for images created in memory); ``name`` is also checked for other
    file-like wrappers. Falls back to "uploaded" when no usable name is found.
    """
    image_name = "uploaded"
    source_path = getattr(input_image, "filename", "") or getattr(input_image, "name", "")
    if isinstance(source_path, str) and source_path:
        # Path.stem drops the directory part and the final extension, so the
        # result is safe to embed in a cache file name.
        image_name = Path(source_path).stem or image_name
    return super_resolve_image(input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache, image_name)


# Create Gradio interface
with gr.Blocks(title="SIREN Super-Resolution") as demo:
    gr.Markdown(
        """
        # 🔥 SIREN Super-Resolution Demo

        Upload a high-resolution image, and watch **SIREN** (Sinusoidal Representation Networks)
        learn to super-resolve it from an artificially downsampled version.

        **How it works:** Your image is downsampled → SIREN learns the low-res → Generates high-res → Compare with original!
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input")
            input_image = gr.Image(
                type="pil",
                label="Upload High-Resolution Image",
                height=300,
            )
            scale_factor = gr.Radio(
                choices=[2, 4, 8],
                value=2,
                label="Downsampling Scale Factor",
                info="Higher scale = harder task",
            )
            training_steps = gr.Dropdown(
                choices=[500, 1000, 1500, 2000, 3000, 4000, 5000],
                value=2000,
                label="Training Epochs/Steps",
                info="More steps = better quality but slower",
            )
            use_cache = gr.Checkbox(
                value=True,
                label="Use Model Cache",
                info="Save/load trained models to avoid retraining",
            )
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                hidden_features = gr.Slider(
                    minimum=128,
                    maximum=512,
                    value=256,
                    step=64,
                    label="Hidden Features",
                    info="Network width",
                )
                hidden_layers = gr.Slider(
                    minimum=2,
                    maximum=6,
                    value=3,
                    step=1,
                    label="Hidden Layers",
                    info="Network depth",
                )
            run_btn = gr.Button("Run Super-Resolution", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("### Results & Comparison")
            with gr.Tabs():
                with gr.Tab("Side-by-Side Comparison"):
                    gr.Markdown("**Low-Resolution Input & Training**")
                    with gr.Row():
                        output_downsampled = gr.Image(
                            label="Downsampled (Input)",
                            type="pil",
                            height=300,
                        )
                        output_loss_plot = gr.Image(
                            label="Training Loss Curve",
                            type="pil",
                            height=300,
                        )
                    gr.Markdown("**High-Resolution Comparison**")
                    with gr.Row():
                        output_super_resolved = gr.Image(
                            label="Super-Resolved (SIREN Prediction)",
                            type="pil",
                            height=300,
                        )
                        output_ground_truth = gr.Image(
                            label="Ground Truth (Original)",
                            type="pil",
                            height=300,
                        )
                with gr.Tab("Quality Metrics"):
                    metrics_display = gr.Textbox(
                        label="Quality Analysis",
                        lines=10,
                        max_lines=15,
                    )

    # Shared wiring for both the examples gallery and the Run button.
    run_inputs = [input_image, scale_factor, training_steps, hidden_features, hidden_layers, use_cache]
    run_outputs = [output_downsampled, output_loss_plot, output_super_resolved, output_ground_truth, metrics_display]

    # Examples
    gr.Markdown("### 📸 Try these examples:")
    gr.Examples(
        examples=[
            ["samples/cat.jpg", 2, 2000, 256, 3, True],
            ["samples/landscape.jpg", 4, 3000, 256, 3, True],
            ["samples/portrait.jpg", 2, 2000, 256, 3, True],
            ["samples/flower.jpg", 4, 3000, 256, 4, True],
        ],
        inputs=run_inputs,
        outputs=run_outputs,
        fn=super_resolve_with_name,
        cache_examples=False,
    )

    gr.Markdown(
        """
        ### About SIREN & Metrics

        **SIREN** uses sine activation functions for representing continuous signals with fine details.

        **Quality Metrics Explained:**

        - **PSNR** (Peak Signal-to-Noise Ratio): Measures reconstruction quality. >30 dB is good, >40 dB is excellent.
        - **SSIM** (Structural Similarity Index): Perceptual quality metric. 1.0 is perfect, >0.9 is very good.
        - **MAE** (Mean Absolute Error): Average pixel difference. Lower is better.

        **Tips for Better Results:**

        - Start with 2x scale for quick testing
        - Use 3000-5000 steps for 4x and 8x scaling
        - Enable model cache to avoid retraining identical settings
        - Higher scale factors need more training steps and network capacity

        **Reference:** [SIREN Paper](https://arxiv.org/abs/2006.09661) |
        [Tutorial](https://github.com/nipunbatra/pml-teaching/blob/master/notebooks/siren.ipynb)
        """
    )

    # Connect the button
    run_btn.click(
        fn=super_resolve_with_name,
        inputs=run_inputs,
        outputs=run_outputs,
    )


if __name__ == "__main__":
    demo.launch()