danifei commited on
Commit
a490245
·
1 Parent(s): 663974d

changing the appearance of the app

Browse files
app.py CHANGED
@@ -3,9 +3,43 @@ from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
  import torch.nn.functional as F
 
 
6
 
7
  from archs import create_model, resume_model
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  PATH_MODEL = './DeMoE.pt'
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model_opt = {
@@ -37,19 +71,36 @@ def pad_tensor(tensor, multiple = 16):
37
 
38
  return tensor
39
 
40
- def process_img(image, task = 'auto'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  tensor = pil_to_tensor(image).unsqueeze(0).to(device)
42
  _, _, H, W = tensor.shape
43
 
44
  tensor = pad_tensor(tensor)
45
 
46
  with torch.no_grad():
47
- output = model(tensor, task)
48
 
49
  output = torch.clamp(output, 0., 1.)
50
  output = output[:,:, :H, :W].squeeze(0)
51
  return tensor_to_pil(output)
52
 
 
 
 
53
  title = 'DeMoE 🌪️​'
54
  description = ''' >**Abstract**: Image deblurring, removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types, thus, these solutions lack generalization. This limitation in current methods implies requiring multiple models to cover several blur types, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.
55
 
@@ -66,11 +117,11 @@ Available code at [github](https://github.com/cidautai/DeMoE). More information
66
  <br>
67
  '''
68
 
69
- examples = [['examples/1POA1811.png'],
70
- ['examples/12_blur.png'],
71
- ['examples/0031.png'],
72
- ['examples/000143.png'],
73
- ['examples/blur_4.png']]
74
 
75
  css = """
76
  .image-frame img, .image-container img {
@@ -80,19 +131,82 @@ css = """
80
  }
81
  """
82
 
83
- demo = gr.Interface(
84
- fn = process_img,
85
- inputs = [
86
- gr.Image(type = 'pil', label = 'input')
87
- ],
88
- outputs = [gr.Image(type='pil', label = 'output')],
89
- title = title,
90
- description = description,
91
- examples = examples,
92
- cache_examples=False,
93
- css = css
94
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  )
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  if __name__ == '__main__':
98
- demo.launch(show_error = True, ssr_mode=False)
 
3
  import torch
4
  import torchvision.transforms as transforms
5
  import torch.nn.functional as F
6
+ import os
7
+ import glob
8
 
9
  from archs import create_model, resume_model
10
 
11
+ # -------- Detect folders & images (assets/<folder>) --------
12
+ IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
13
+
14
+ def list_subfolders(base="assets"):
15
+ """Return a sorted list of immediate subfolders inside base."""
16
+ if not os.path.isdir(base):
17
+ return []
18
+ subs = [d for d in sorted(os.listdir(base)) if os.path.isdir(os.path.join(base, d))]
19
+ return subs
20
+
21
+ def list_images(folder):
22
+ """Return full paths of images inside assets/<folder>."""
23
+ paths = sorted(glob.glob(os.path.join("assets", folder, "*")))
24
+ return [p for p in paths if p.lower().endswith(IMG_EXTS)]
25
+
26
+ # -------- Folder/Gallery interactions --------
27
+ def update_gallery(folder):
28
+ """Given a folder name, return the gallery items (list of image paths) and store the same list in state."""
29
+ files = list_images(folder)
30
+ return gr.update(value=files, visible=True), files
31
+
32
+ def load_from_gallery(evt: gr.SelectData, current_files):
33
+ """On gallery click, load the clicked image path into the input image."""
34
+ idx = evt.index
35
+ if not current_files or idx is None or idx >= len(current_files):
36
+ return gr.update()
37
+ path = current_files[idx]
38
+ return Image.open(path)
39
+
40
+
41
+ # Model
42
+
43
  PATH_MODEL = './DeMoE.pt'
44
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
  model_opt = {
 
71
 
72
  return tensor
73
 
74
+ TASK_LABELS = ["Auto", "Defocus", "Low-Light", "Global-Motion", "Synth-Global-Motion", "Local-Motion"]
75
+
76
+ # Map pretty label -> internal task code used by the model
77
+ LABEL_TO_TASK = {
78
+ "Auto": "auto",
79
+ "Low-Light": "low_light",
80
+ "Global-Motion": "global_motion",
81
+ "Defocus": "defocus",
82
+ "Synth-Global-Motion": "synth_global_motion",
83
+ "Local-Motion": "local_motion",
84
+ }
85
+
86
+ def process_img(image, task_label = 'auto'):
87
+ """Main inference: converts PIL -> tensor, pads, runs the model with selected task, clamps, crops, returns PIL."""
88
+ task_label = LABEL_TO_TASK.get(task_label, 'auto')
89
  tensor = pil_to_tensor(image).unsqueeze(0).to(device)
90
  _, _, H, W = tensor.shape
91
 
92
  tensor = pad_tensor(tensor)
93
 
94
  with torch.no_grad():
95
+ output = model(tensor, task_label)
96
 
97
  output = torch.clamp(output, 0., 1.)
98
  output = output[:,:, :H, :W].squeeze(0)
99
  return tensor_to_pil(output)
100
 
101
+
102
+
103
+
104
  title = 'DeMoE 🌪️​'
105
  description = ''' >**Abstract**: Image deblurring, removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types, thus, these solutions lack generalization. This limitation in current methods implies requiring multiple models to cover several blur types, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.
106
 
 
117
  <br>
118
  '''
119
 
120
+ # examples = [['examples/1POA1811.png'],
121
+ # ['examples/12_blur.png'],
122
+ # ['examples/0031.png'],
123
+ # ['examples/000143.png'],
124
+ # ['examples/blur_4.png']]
125
 
126
  css = """
127
  .image-frame img, .image-container img {
 
131
  }
132
  """
133
 
134
+
135
+ # Example lists per folder under ./assets (kept simple, no helpers)
136
+ exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
137
+
138
+ def list_basenames(folder):
139
+ """Return [[basename, task_label], ...] for gr.Examples using examples_dir."""
140
+ paths = sorted(glob.glob(f"examples/{folder}/*"))
141
+ basenames = [os.path.basename(p) for p in paths if p.lower().endswith(exts)]
142
+ # Default task per folder (tweak as you like)
143
+ default_task = "auto"
144
+ return [[name, default_task] for name in basenames]
145
+
146
+
147
+ examples_low_light = list_basenames("low_light")
148
+ examples_global_motion = list_basenames("global_motion")
149
+ examples_synth_global_motion = list_basenames("synth_global_motion")
150
+ examples_local_motion = list_basenames("local_motion")
151
+ examples_defocus = list_basenames("defocus")
152
+
153
+ # -----------------------------
154
+ # Gradio Blocks layout
155
+ # -----------------------------
156
+ with gr.Blocks(css=css, title=title) as demo:
157
+ gr.Markdown(f"# {title}\n\n{description}")
158
+
159
+ with gr.Row():
160
+ # Input image and the task selector (Radio)
161
+ inp_img = gr.Image(type='pil', label='input')
162
+ # Output image and action button
163
+ out_img = gr.Image(type='pil', label='output')
164
+ task_selector = gr.Radio(
165
+ choices=TASK_LABELS,
166
+ value="auto",
167
+ label="Blur type"
168
  )
169
 
170
+ btn = gr.Button("Restore", variant="primary")
171
+
172
+ # Connect the button to the inference function
173
+ btn.click(
174
+ fn=process_img,
175
+ inputs=[inp_img, task_selector],
176
+ outputs=[out_img]
177
+ )
178
+
179
+ # Examples grouped by folder (each item loads image + task automatically)
180
+ gr.Markdown("## Examples")
181
+ with gr.Row():
182
+ # List folders found in ./assets
183
+ folders = list_subfolders("examples")
184
+ folder_radio = gr.Radio(choices=folders, label="Examples Folders", interactive=True)
185
+
186
+ gallery = gr.Gallery(
187
+ label="Images from the selected folder",
188
+ visible=False,
189
+ allow_preview=True,
190
+ columns=6,
191
+ height=320,
192
+ )
193
+
194
+ # State holds the current file list shown in the gallery (to resolve clicks)
195
+ current_files_state = gr.State([])
196
+
197
+ # When changing folder -> update gallery and state
198
+ folder_radio.change(
199
+ fn=update_gallery,
200
+ inputs=folder_radio,
201
+ outputs=[gallery, current_files_state]
202
+ )
203
+
204
+ # When clicking a thumbnail -> load it into the input image
205
+ gallery.select(
206
+ fn=load_from_gallery,
207
+ inputs=[current_files_state],
208
+ outputs=inp_img
209
+ )
210
+
211
  if __name__ == '__main__':
212
+ demo.launch(show_error = True, server_name="0.0.0.0", server_port=7864, ssr_mode=False)
demoe_gradio.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ import torch.nn.functional as F
7
+
8
+ from archs import create_model, resume_model
9
+
10
+
11
+ # -------- Detect folders & images (assets/<folder>) --------
12
+ IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
13
+
14
+ def list_subfolders(base="assets"):
15
+ """Return a sorted list of immediate subfolders inside base."""
16
+ if not os.path.isdir(base):
17
+ return []
18
+ subs = [d for d in sorted(os.listdir(base)) if os.path.isdir(os.path.join(base, d))]
19
+ return subs
20
+
21
+ def list_images(folder):
22
+ """Return full paths of images inside assets/<folder>."""
23
+ paths = sorted(glob.glob(os.path.join("assets", folder, "*")))
24
+ return [p for p in paths if p.lower().endswith(IMG_EXTS)]
25
+
26
+
27
+ # -------- Folder/Gallery interactions --------
28
+ def update_gallery(folder):
29
+ """Given a folder name, return the gallery items (list of image paths) and store the same list in state."""
30
+ files = list_images(folder)
31
+ return gr.update(value=files, visible=True), files # (gallery_update, state_list)
32
+
33
+ def load_from_gallery(evt: gr.SelectData, current_files):
34
+ """On gallery click, load the clicked image path into the input image."""
35
+ idx = evt.index
36
+ if not current_files or idx is None or idx >= len(current_files):
37
+ return gr.update()
38
+ path = current_files[idx]
39
+ return Image.open(path)
40
+ # -----------------------------
41
+ # Model
42
+ # -----------------------------
43
+ PATH_MODEL = './DeMoE.pt'
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ model_opt = {
46
+ 'name': 'DeMoE',
47
+ 'img_channels': 3,
48
+ 'width': 32,
49
+ 'middle_blk_num': 2,
50
+ 'enc_blk_nums': [2, 2, 2, 2],
51
+ 'dec_blk_nums': [2, 2, 2, 2],
52
+ 'num_experts': 5,
53
+ 'k_used': 1
54
+ }
55
+
56
+ pil_to_tensor = transforms.ToTensor()
57
+ tensor_to_pil = transforms.ToPILImage()
58
+
59
+ # Create and load model weights
60
+ model = create_model(model_opt, device)
61
+ _ = torch.load(PATH_MODEL, map_location=device, weights_only=False) # keep compatibility with different checkpoints
62
+ model = resume_model(model, PATH_MODEL, device)
63
+
64
+ def pad_tensor(tensor, multiple=16):
65
+ """Pad tensor so that H and W are multiples of `multiple` (default 16)."""
66
+ _, _, H, W = tensor.shape
67
+ pad_h = (multiple - H % multiple) % multiple
68
+ pad_w = (multiple - W % multiple) % multiple
69
+ tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=0)
70
+ return tensor
71
+
72
+ # -----------------------------
73
+ # UI / Inference
74
+ # -----------------------------
75
+ title = 'DeMoE 🌪️​'
76
+ description = ''' >**Abstract**: Image deblurring, removing blurring artifacts from images, is a fundamental task in computational photography and low-level computer vision. Existing approaches focus on specialized solutions tailored to particular blur types, thus, these solutions lack generalization. This limitation in current methods implies requiring multiple models to cover several blur types, which is not practical in many real scenarios. In this paper, we introduce the first all-in-one deblurring method capable of efficiently restoring images affected by diverse blur degradations, including global motion, local motion, blur in low-light conditions, and defocus blur. We propose a mixture-of-experts (MoE) decoding module, which dynamically routes image features based on the recognized blur degradation, enabling precise and efficient restoration in an end-to-end manner. Our unified approach not only achieves performance comparable to dedicated task-specific models, but also shows promising generalization to unseen blur scenarios, particularly when leveraging appropriate expert selection.
77
+ [Daniel Feijoo](https://github.com/danifei), Paula Garrido-Mellado, Jaesung Rim, Álvaro García, Marcos V. Conde
78
+ [Fundación Cidaut](https://cidaut.ai/)
79
+ Available code at [github](https://github.com/cidautai/DeMoE). More information on the [Arxiv paper](https://arxiv.org/pdf/2508.06228).
80
+ > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
81
+ **This demo expects an image with some Low-Light degradations.**
82
+ <br>
83
+ '''
84
+
85
+ # Visible tasks in the UI
86
+ TASK_LABELS = ["Deblur", "Low-light", "movement", "defocus", "all"]
87
+
88
+ # Map pretty label -> internal task code used by the model
89
+ LABEL_TO_TASK = {
90
+ "Deblur": "global", # change to what your model expects for general deblurring
91
+ "Low-light": "lowlight",
92
+ "movement": "local", # if your model supports local motion blur
93
+ "defocus": "defocus", # if your model supports defocus blur
94
+ "all": "all", # if your model supports all types at once
95
+ }
96
+
97
+ css = """
98
+ .image-frame img, .image-container img {
99
+ width: auto;
100
+ height: auto;
101
+ max-width: none;
102
+ }
103
+ """
104
+
105
+ # Example lists per folder under ./assets (kept simple, no helpers)
106
+ exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
107
+
108
+ def list_basenames(folder):
109
+ """Return [[basename, task_label], ...] for gr.Examples using examples_dir."""
110
+ paths = sorted(glob.glob(f"assets/{folder}/*"))
111
+ basenames = [os.path.basename(p) for p in paths if p.lower().endswith(exts)]
112
+ # Default task per folder (tweak as you like)
113
+ default_task = "Low-light" if folder == "lowlight" else "Deblur"
114
+ return [[name, default_task] for name in basenames]
115
+
116
+ examples_agentir = list_basenames("AgentIR")
117
+ examples_allweather = list_basenames("allweather")
118
+ examples_amac = list_basenames("amac_examples")
119
+ examples_deblur = list_basenames("deblur")
120
+ examples_gestures = list_basenames("gestures")
121
+ examples_lowlight = list_basenames("lowlight")
122
+ examples_monolith = list_basenames("monolith")
123
+ examples_superres = list_basenames("superres")
124
+
125
+ def process_img(image, task_label='auto'):
126
+ """Main inference: converts PIL -> tensor, pads, runs the model with selected task, clamps, crops, returns PIL."""
127
+ task = LABEL_TO_TASK.get(task_label, 'auto') # default to lowlight if something unexpected arrives
128
+ tensor = pil_to_tensor(image).unsqueeze(0).to(device)
129
+ _, _, H, W = tensor.shape
130
+ tensor = pad_tensor(tensor)
131
+
132
+ with torch.no_grad():
133
+ output = model(tensor, task)
134
+
135
+ output = torch.clamp(output, 0., 1.)
136
+ output = output[:, :, :H, :W].squeeze(0)
137
+ return tensor_to_pil(output)
138
+
139
+ # -----------------------------
140
+ # Gradio Blocks layout
141
+ # -----------------------------
142
+ with gr.Blocks(css=css, title=title) as demo:
143
+ gr.Markdown(f"# {title}\n\n{description}")
144
+
145
+ with gr.Row():
146
+ # Input image and the task selector (Radio)
147
+ inp_img = gr.Image(type='pil', label='input')
148
+ # Output image and action button
149
+ out_img = gr.Image(type='pil', label='output')
150
+ task_selector = gr.Radio(
151
+ choices=TASK_LABELS,
152
+ value="auto",
153
+ label="Tipo de blur a corregir"
154
+ )
155
+
156
+ btn = gr.Button("Corregir", variant="primary")
157
+
158
+ # Connect the button to the inference function
159
+ btn.click(
160
+ fn=process_img,
161
+ inputs=[inp_img, task_selector],
162
+ outputs=[out_img]
163
+ )
164
+
165
+ # Examples grouped by folder (each item loads image + task automatically)
166
+ gr.Markdown("## Ejemplos (assets)")
167
+ with gr.Row():
168
+ # List folders found in ./assets
169
+ folders = list_subfolders("assets")
170
+ folder_radio = gr.Radio(choices=folders, label="Carpetas en assets", interactive=True)
171
+
172
+ gallery = gr.Gallery(
173
+ label="Imágenes de la carpeta seleccionada",
174
+ visible=False,
175
+ allow_preview=True,
176
+ columns=6,
177
+ height=320,
178
+ )
179
+
180
+ # State holds the current file list shown in the gallery (to resolve clicks)
181
+ current_files_state = gr.State([])
182
+
183
+ # When changing folder -> update gallery and state
184
+ folder_radio.change(
185
+ fn=update_gallery,
186
+ inputs=folder_radio,
187
+ outputs=[gallery, current_files_state]
188
+ )
189
+
190
+ # When clicking a thumbnail -> load it into the input image
191
+ gallery.select(
192
+ fn=load_from_gallery,
193
+ inputs=[current_files_state],
194
+ outputs=inp_img
195
+ )
196
+
197
+
198
+ if __name__ == '__main__':
199
+ # Explicit host/port and no SSR are friendly to Spaces
200
+ demo.launch(show_error=True, server_name="0.0.0.0", server_port=7864, ssr_mode=False)
examples/{1P0A1811.png → defocus/1P0A1811.png} RENAMED
File without changes
examples/{blur_4.png → global_motion/blur_4.png} RENAMED
File without changes
examples/{12_blur.png → local_motion/12_blur.png} RENAMED
File without changes
examples/{0031.png → low_light/0031.png} RENAMED
File without changes
examples/{000143.png → synth_global_motion/000143.png} RENAMED
File without changes