Refactor image post-processing in gradio_demo_rgb2x.py
rgb2x/gradio_demo_rgb2x.py  CHANGED  (+13 -16)
@@ -1,11 +1,11 @@
+from typing import cast
 import spaces
-import numpy as np
 import os
-from typing import cast
 import gradio as gr
 from PIL import Image
 import torch
 import torchvision
+
 from diffusers import DDIMScheduler
 from load_image import load_exr_image, load_ldr_image
 from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
@@ -36,15 +36,15 @@ def generate(
     num_samples: int,
 ) -> list[Image.Image]:
     generator = torch.Generator(device="cuda").manual_seed(seed)
-
-    if photo.name.endswith(".exr"):
-        photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
+    photo_name = photo.name
+    if photo_name.endswith(".exr"):
+        photo = load_exr_image(photo_name, tonemaping=True, clamp=True).to("cuda")
     elif (
-        photo.name.endswith(".png")
-        or photo.name.endswith(".jpg")
-        or photo.name.endswith(".jpeg")
+        photo_name.endswith(".png")
+        or photo_name.endswith(".jpg")
+        or photo_name.endswith(".jpeg")
     ):
-        photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
+        photo = load_ldr_image(photo_name, from_srgb=True).to("cuda")
 
     # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
     old_height = photo.shape[1]
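After this hunk, the uploaded file's path is read once into `photo_name` before `photo` is rebound to a CUDA tensor; the saved path is what the later hunk appends to the gallery. One thing the branch still does not handle: a file with any other extension falls through both branches, so `photo` is never rebound and the later `photo.shape[1]` fails. A minimal sketch of the same dispatch with an explicit error, assuming only the `load_exr_image`/`load_ldr_image` calls visible in this diff (the `load_photo` helper name is hypothetical, not part of this PR):

```python
import torch

from load_image import load_exr_image, load_ldr_image


def load_photo(photo_name: str, device: str = "cuda") -> torch.Tensor:
    """Hypothetical helper: load an HDR or LDR photo as a tensor on `device`."""
    if photo_name.endswith(".exr"):
        # HDR input: tone-map and clamp, matching the call in this diff
        # (the keyword really is spelled "tonemaping" here).
        return load_exr_image(photo_name, tonemaping=True, clamp=True).to(device)
    if photo_name.endswith((".png", ".jpg", ".jpeg")):
        # LDR input: from_srgb=True, matching the call in this diff.
        return load_ldr_image(photo_name, from_srgb=True).to(device)
    # Fail loudly instead of silently leaving `photo` as a file object.
    raise ValueError(f"Unsupported image format: {photo_name}")
```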
@@ -96,10 +96,7 @@ def generate(
             generated_image = (generated_image, f"Generated {aov_name} {i}")
             return_list.append(generated_image)
 
-    def post_process_image(img, **kwargs):
-        return (img.cpu().numpy().transpose(1, 2, 0), kwargs.get("label", "Image"))
-
-    return_list.append(post_process_image(photo, label="Input Image"))
+    return_list.append((photo_name, "Input Image"))
     return return_list
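The removed nested helper existed only to turn the (C, H, W) CUDA tensor back into an (H, W, C) numpy array for display. Since `gr.Gallery` accepts string file paths in its `(image, caption)` tuples as well as numpy arrays, appending `(photo_name, "Input Image")` skips the GPU-to-CPU copy. One side effect worth flagging in review: the path points at the original upload, so the gallery now shows the uncropped input rather than the center-cropped tensor the model actually consumed. An illustrative comparison (dummy stand-ins for `photo` and `photo_name`, shapes inferred from the removed helper):

```python
import torch

photo = torch.rand(3, 512, 768)  # stand-in for the cropped tensor (on GPU in the app)
photo_name = "rgb2x/example/Castlereagh_corridor_photo.png"

# Before: copy the tensor to the CPU and transpose (C, H, W) -> (H, W, C)
# so the gallery can render it as a numpy image.
entry_before = (photo.cpu().numpy().transpose(1, 2, 0), "Input Image")

# After: hand the gallery the original file path instead; no copy, and the
# pre-crop file is what gets displayed.
entry_after = (photo_name, "Input Image")
```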
@@ -149,9 +146,9 @@ with gr.Blocks() as demo:
         examples = gr.Examples(
             examples=[
                 [
-                    "rgb2x/example/Castlereagh_corridor_photo.png",
-                    0,
-                    50,
+                    "rgb2x/example/Castlereagh_corridor_photo.png",  # Photo
+                    0,  # Seed
+                    50,  # Inference Step
                     1,  # Samples
                 ]
             ],