Andrei Cozma commited on
Commit
50413ea
1 Parent(s): e6c0aeb
.gitignore ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
3
+
4
+ ### Python ###
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
165
+
166
+ ### Python Patch ###
167
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
168
+ poetry.toml
169
+
170
+ # ruff
171
+ .ruff_cache/
172
+
173
+ # LSP config files
174
+ pyrightconfig.json
175
+
176
+ # End of https://www.toptal.com/developers/gitignore/api/python
177
+
178
+ .DS_Store
179
+ .idea
180
+ .vscode
LICENSE.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2023 Andrei Cozma
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
generate_samples.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from scipy.ndimage import gaussian_filter
6
+
7
+
8
+ def create_blank_image(size=256):
9
+ return np.zeros((size, size), dtype=np.uint8)
10
+
11
+
12
+ def add_horizontal_lines(image, spacing=10):
13
+ for i in range(0, image.shape[0], spacing):
14
+ image[i, :] = 255
15
+ return image
16
+
17
+
18
+ def add_vertical_lines(image, spacing=10):
19
+ for i in range(0, image.shape[1], spacing):
20
+ image[:, i] = 255
21
+ return image
22
+
23
+
24
+ def add_diagonal_lines(image, spacing=10):
25
+ for i in range(0, image.shape[0], spacing):
26
+ np.fill_diagonal(image[i:, i:], 255)
27
+ return image
28
+
29
+
30
+ # Add a circle to the image
31
+ def add_circle(image, radius=50):
32
+ center_x, center_y = image.shape[1] // 2, image.shape[0] // 2
33
+ y, x = np.ogrid[
34
+ -center_y : image.shape[0] - center_y, -center_x : image.shape[1] - center_x
35
+ ]
36
+ mask = x * x + y * y <= radius * radius
37
+ image[mask] = 255
38
+ return image
39
+
40
+
41
+ def add_checkerboard(image, square_size=16):
42
+ for i in range(0, image.shape[0], square_size * 2):
43
+ for j in range(0, image.shape[1], square_size * 2):
44
+ image[i : i + square_size, j : j + square_size] = 255
45
+ image[
46
+ i + square_size : i + square_size * 2,
47
+ j + square_size : j + square_size * 2,
48
+ ] = 255
49
+ return image
50
+
51
+
52
+ def add_horizontal_sinusoidal(image, frequency=1 / 20, amplitude=127):
53
+ x = np.linspace(0, 1, image.shape[1])
54
+ y = np.sin(2 * np.pi * frequency * x) * amplitude + 128
55
+ for i in range(image.shape[0]):
56
+ image[i, :] = y
57
+ return image
58
+
59
+
60
+ def add_vertical_sinusoidal(image, frequency=1 / 20, amplitude=127):
61
+ x = np.linspace(0, 1, image.shape[0])
62
+ y = np.sin(2 * np.pi * frequency * x) * amplitude + 128
63
+ for i in range(image.shape[1]):
64
+ image[:, i] = y
65
+ return image
66
+
67
+
68
+ def add_random_noise(image, std_dev=50):
69
+ noise = np.random.normal(0, std_dev, image.shape).astype(np.float32)
70
+ image = image.astype(np.float32) + noise
71
+ np.clip(image, 0, 255, out=image)
72
+ image = image.astype(np.uint8)
73
+ return image
74
+
75
+
76
+ def save_image(image, filename):
77
+ image = Image.fromarray(image)
78
+ image.save(filename)
79
+
80
+
81
+ savedir = "./images/"
82
+
83
+ os.makedirs(savedir, exist_ok=True)
84
+
85
+ save_image(create_blank_image(), savedir + "blank.png")
86
+ save_image(add_horizontal_lines(create_blank_image()), savedir + "horizontal.png")
87
+ save_image(add_vertical_lines(create_blank_image()), savedir + "vertical.png")
88
+ save_image(add_diagonal_lines(create_blank_image()), savedir + "diagonal.png")
89
+ save_image(add_circle(create_blank_image()), savedir + "circle.png")
90
+ save_image(add_checkerboard(create_blank_image()), savedir + "checkerboard.png")
91
+ save_image(
92
+ add_horizontal_sinusoidal(create_blank_image()),
93
+ savedir + "horizontal_sinusoidal.png",
94
+ )
95
+ save_image(
96
+ add_vertical_sinusoidal(create_blank_image()), savedir + "vertical_sinusoidal.png"
97
+ )
98
+ save_image(add_random_noise(create_blank_image()), savedir + "random_noise.png")
images/blank.png ADDED
images/checkerboard.png ADDED
images/circle.png ADDED
images/diagonal.png ADDED
images/horizontal.png ADDED
images/horizontal_sinusoidal.png ADDED
images/random_noise.png ADDED
images/vertical.png ADDED
images/vertical_sinusoidal.png ADDED
main.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Tuple, Union
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image, ImageChops, ImageOps
7
+
8
+
9
+ class ImageInfo:
10
+ def __init__(
11
+ self,
12
+ size: Tuple[int, int],
13
+ channels: int,
14
+ data_type: str,
15
+ min_val: float,
16
+ max_val: float,
17
+ ):
18
+ self.size = size
19
+ self.channels = channels
20
+ self.data_type = data_type
21
+ self.min_val = min_val
22
+ self.max_val = max_val
23
+
24
+ @classmethod
25
+ def from_pil(cls, pil_image: Image.Image) -> "ImageInfo":
26
+ size = (pil_image.width, pil_image.height)
27
+ channels = len(pil_image.getbands())
28
+ data_type = str(pil_image.mode)
29
+ extrema = pil_image.getextrema()
30
+ if channels > 1: # Multi-band image
31
+ min_val = min([band[0] for band in extrema])
32
+ max_val = max([band[1] for band in extrema])
33
+ else: # Single-band image
34
+ min_val, max_val = extrema
35
+ return cls(size, channels, data_type, min_val, max_val)
36
+
37
+ @classmethod
38
+ def from_numpy(cls, np_array: np.ndarray) -> "ImageInfo":
39
+ if len(np_array.shape) > 3:
40
+ raise ValueError(f"Unsupported array shape: {np_array.shape}")
41
+ size = (np_array.shape[1], np_array.shape[0])
42
+ channels = 1 if len(np_array.shape) == 2 else np_array.shape[2]
43
+ data_type = str(np_array.dtype)
44
+ min_val, max_val = np_array.min(), np_array.max()
45
+ return cls(size, channels, data_type, min_val, max_val)
46
+
47
+ @classmethod
48
+ def from_any(cls, image: Union[Image.Image, np.ndarray]) -> "ImageInfo":
49
+ if isinstance(image, np.ndarray):
50
+ return cls.from_numpy(image)
51
+ elif isinstance(image, Image.Image):
52
+ return cls.from_pil(image)
53
+ else:
54
+ raise ValueError(f"Unsupported image type: {type(image)}")
55
+
56
+ def __str__(self) -> str:
57
+ return f"{str(self.size)} {self.channels}C {self.data_type} {round(self.min_val, 2)}min/{round(self.max_val, 2)}max"
58
+
59
+ @property
60
+ def aspect_ratio(self) -> float:
61
+ return self.size[0] / self.size[1]
62
+
63
+
64
+ def nextpow2(n):
65
+ """Find the next power of 2 greater than or equal to `n`."""
66
+ return int(2 ** np.ceil(np.log2(n)))
67
+
68
+
69
+ def pad_image_nextpow2(image):
70
+ print("-" * 80)
71
+ print("pad_image_nextpow2: ")
72
+ print(ImageInfo.from_any(image))
73
+
74
+ if image.ndim == 2:
75
+ image = image[:, :, np.newaxis]
76
+
77
+ assert image.ndim == 3, f"Expected image.ndim == 3. Got {image.ndim}"
78
+
79
+ height, width, channels = image.shape
80
+ height_new = nextpow2(height)
81
+ width_new = nextpow2(width)
82
+
83
+ height_diff = height_new - height
84
+ width_diff = width_new - width
85
+
86
+ image = np.pad(
87
+ image,
88
+ (
89
+ (height_diff // 2, height_diff - height_diff // 2),
90
+ (width_diff // 2, width_diff - width_diff // 2),
91
+ (0, 0),
92
+ ),
93
+ mode="constant",
94
+ # mode="edge",
95
+ # mode="linear_ramp",
96
+ # mode="maximum",
97
+ # mode="mean",
98
+ # mode="median",
99
+ # mode="minimum",
100
+ # mode="reflect",
101
+ # mode="symmetric",
102
+ # mode="wrap",
103
+ # mode="empty",
104
+ )
105
+
106
+ print(ImageInfo.from_any(image))
107
+
108
+ return image
109
+
110
+
111
+ def get_fft(image):
112
+ print("-" * 80)
113
+ print("get_fft: ")
114
+ print("image:", ImageInfo.from_any(image))
115
+
116
+ fft = np.fft.fft2(image, axes=(0, 1, 2))
117
+ fft = np.fft.fftshift(fft)
118
+
119
+ return fft
120
+
121
+
122
+ def get_ifft_image(fft):
123
+ print("-" * 80)
124
+ print("get_ifft_image: ")
125
+
126
+ ifft = np.fft.ifftshift(fft)
127
+ ifft = np.fft.ifft2(ifft, axes=(0, 1, 2))
128
+
129
+ # we only need the real part
130
+ ifft_image = np.real(ifft)
131
+
132
+ # remove padding
133
+ # ifft = ifft[
134
+ # h_diff // 2 : h_diff // 2 + original_shape[0],
135
+ # w_diff // 2 : w_diff // 2 + original_shape[1],
136
+ # ]
137
+
138
+ ifft_image = (ifft_image - np.min(ifft_image)) / (
139
+ np.max(ifft_image) - np.min(ifft_image)
140
+ )
141
+ ifft_image = ifft_image * 255
142
+ ifft_image = ifft_image.astype(np.uint8)
143
+
144
+ return ifft_image
145
+
146
+
147
+ def fft_mag_image(fft):
148
+ print("-" * 80)
149
+ print("fft_mag_image: ")
150
+
151
+ fft_mag = np.abs(fft)
152
+ fft_mag = np.log(fft_mag + 1)
153
+
154
+ # scale 0 to 1
155
+ fft_mag = (fft_mag - np.min(fft_mag)) / (np.max(fft_mag) - np.min(fft_mag) + 1e-6)
156
+ # scale to (0, 255)
157
+ fft_mag = fft_mag * 255
158
+ fft_mag = fft_mag.astype(np.uint8)
159
+ return fft_mag
160
+
161
+
162
+ def fft_phase_image(fft):
163
+ print("-" * 80)
164
+ print("fft_phase_image: ")
165
+
166
+ fft_phase = np.angle(fft)
167
+ fft_phase = fft_phase + np.pi
168
+ fft_phase = fft_phase / (2 * np.pi)
169
+
170
+ # scale 0 to 1
171
+ fft_phase = (fft_phase - np.min(fft_phase)) / (
172
+ np.max(fft_phase) - np.min(fft_phase)
173
+ )
174
+ # scale to (0, 255)
175
+ fft_phase = fft_phase * 255
176
+ fft_phase = fft_phase.astype(np.uint8)
177
+ return fft_phase
178
+
179
+
180
+ def onclick_process_fft(state, inp_image, mask_opacity, inverted_mask, pad):
181
+ print("-" * 80)
182
+ print("onclick_process_fft:")
183
+
184
+ if isinstance(inp_image, dict):
185
+ if "image" not in inp_image:
186
+ raise gr.Error("Please upload or select an image first.")
187
+
188
+ image, mask = inp_image["image"], inp_image["mask"]
189
+ print("image:", ImageInfo.from_any(image))
190
+ print("mask:", ImageInfo.from_any(image))
191
+
192
+ image = Image.fromarray(image)
193
+ mask = Image.fromarray(mask)
194
+
195
+ if not inverted_mask:
196
+ mask = ImageOps.invert(mask)
197
+
198
+ image_final = ImageChops.multiply(image, mask)
199
+ image_final = Image.blend(image, image_final, mask_opacity)
200
+
201
+ image_final = image_final.convert(image.mode)
202
+ image_final = np.array(image_final)
203
+ elif isinstance(inp_image, np.ndarray):
204
+ image_final = inp_image
205
+ else:
206
+ raise gr.Error("Please upload or select an image first.")
207
+
208
+ print("image_final:", ImageInfo.from_any(image_final))
209
+
210
+ if pad:
211
+ image_final = pad_image_nextpow2(image_final)
212
+
213
+ state["inp_image"] = image_final
214
+
215
+ image_mag = fft_mag_image(get_fft(image_final))
216
+ image_phase = fft_phase_image(get_fft(image_final))
217
+
218
+ return (
219
+ [
220
+ (image_final, "Input Image (Final)"),
221
+ (image_mag, "FFT Magnitude (Original)"),
222
+ (image_phase, "FFT Phase (Original)"),
223
+ ],
224
+ image_mag,
225
+ image_phase,
226
+ )
227
+
228
+
229
+ def onclick_process_ifft(state, mag_and_mask, phase_and_mask):
230
+ print("-" * 80)
231
+ print("onclick_process_ifft:")
232
+ if state["inp_image"] is None:
233
+ raise gr.Error("Please process FFT first.")
234
+
235
+ image = state["inp_image"]
236
+ # h_new = nextpow2(original_shape[0])
237
+ # w_new = nextpow2(original_shape[1])
238
+ # h_diff = h_new - original_shape[0]
239
+ # w_diff = w_new - original_shape[1]
240
+
241
+ mask_mag = mag_and_mask["mask"]
242
+ print("mag_mask:", ImageInfo.from_any(mask_mag))
243
+
244
+ mask_phase = phase_and_mask["mask"]
245
+ print("phase_mask:", ImageInfo.from_any(mask_phase))
246
+
247
+ fft = get_fft(state["inp_image"])
248
+ print(f"fft: {fft.shape}")
249
+
250
+ fft_mag = np.where(mask_mag == 255, 0, np.abs(fft))
251
+ fft_phase = np.where(mask_phase == 255, 0, np.angle(fft))
252
+
253
+ fft = fft_mag * np.exp(1j * fft_phase)
254
+
255
+ ifft_image = get_ifft_image(fft)
256
+ image_mag = fft_mag_image(fft)
257
+ image_phase = fft_phase_image(fft)
258
+
259
+ return (
260
+ [
261
+ (image, "Input Image (Final)"),
262
+ (image_mag, "FFT Magnitude (Filtered)"),
263
+ (image_phase, "FFT Phase (Filtered)"),
264
+ ],
265
+ ifft_image,
266
+ )
267
+
268
+
269
+ def get_start_image():
270
+ return (np.ones((512, 512, 3)) * 255).astype(np.uint8)
271
+
272
+
273
+ def update_image_input(state, selection):
274
+ print("-" * 80)
275
+ print("update_image_input:")
276
+ print(f"selection: {selection}")
277
+ if not selection:
278
+ white_image = get_start_image()
279
+ return (
280
+ white_image,
281
+ [white_image],
282
+ None,
283
+ None,
284
+ None,
285
+ )
286
+
287
+ image_path = os.path.join("./images", selection)
288
+ print(f"image_path: {image_path}")
289
+ if not os.path.exists(image_path):
290
+ raise gr.Error(f"Image not found: {image_path}")
291
+
292
+ image = Image.open(image_path)
293
+ image = np.array(image)
294
+ state["inp_image"] = image
295
+ return (
296
+ image,
297
+ [image],
298
+ None,
299
+ None,
300
+ None,
301
+ )
302
+
303
+
304
+ def clear_image_input(state):
305
+ print("-" * 80)
306
+ print("clear_image_input:")
307
+ state["inp_image"] = None
308
+ return (
309
+ None,
310
+ [],
311
+ None,
312
+ None,
313
+ None,
314
+ )
315
+
316
+
317
+ css = """
318
+ .fft_mag > .image-container > button > div:first-child {
319
+ display: none;
320
+ }
321
+ .fft_phase > .image-container > button > div:first-child {
322
+ display: none;
323
+ }
324
+ .ifft_img > .image-container > button > div:first-child {
325
+ display: none;
326
+ }
327
+ """
328
+
329
+ with gr.Blocks(css=css) as demo:
330
+ state = gr.State(
331
+ {
332
+ "inp_image": None,
333
+ },
334
+ )
335
+
336
+ with gr.Row():
337
+ with gr.Column():
338
+ inp_image = gr.Image(
339
+ value=get_start_image(),
340
+ label="Input Image",
341
+ height=512,
342
+ type="numpy",
343
+ interactive=True,
344
+ tool="sketch",
345
+ mask_opacity=1.0,
346
+ elem_classes=["inp_img"],
347
+ )
348
+ files = os.listdir("./images")
349
+ files = sorted(files)
350
+ inp_samples = gr.Dropdown(
351
+ choices=files,
352
+ label="Select Example Image",
353
+ )
354
+
355
+ with gr.Column():
356
+ out_gallery = gr.Gallery(
357
+ label="Input Gallery",
358
+ height=512,
359
+ rows=1,
360
+ columns=3,
361
+ allow_preview=True,
362
+ preview=False,
363
+ selected_index=None,
364
+ )
365
+
366
+ with gr.Row():
367
+ inp_mask_opacity = gr.Slider(
368
+ label="Mask Opacity",
369
+ minimum=0.0,
370
+ maximum=1.0,
371
+ step=0.05,
372
+ value=1.0,
373
+ )
374
+
375
+ inp_invert_mask = gr.Checkbox(
376
+ label="Invert Mask",
377
+ value=False,
378
+ )
379
+
380
+ inp_pad = gr.Checkbox(
381
+ label="Pad NextPow2",
382
+ value=True,
383
+ )
384
+
385
+ btn_fft = gr.Button("Process FFT")
386
+
387
+ out_fft_mag = gr.Image(
388
+ label="FFT Magnitude Spectrum",
389
+ height=512,
390
+ type="numpy",
391
+ interactive=True,
392
+ # source="canvas",
393
+ tool="sketch",
394
+ mask_opacity=1.0,
395
+ elem_classes=["fft_mag"],
396
+ )
397
+ out_fft_phase = gr.Image(
398
+ label="FFT Phase Spectrum",
399
+ height=512,
400
+ type="numpy",
401
+ interactive=True,
402
+ # source="canvas",
403
+ tool="sketch",
404
+ mask_opacity=1.0,
405
+ elem_classes=["fft_phase"],
406
+ )
407
+
408
+ btn_ifft = gr.Button("Process IFFT")
409
+
410
+ out_ifft = gr.Image(
411
+ label="IFFT",
412
+ height=512,
413
+ type="numpy",
414
+ interactive=True,
415
+ show_download_button=True,
416
+ elem_classes=["ifft_img"],
417
+ )
418
+
419
+ inp_image.clear(
420
+ clear_image_input,
421
+ [state],
422
+ [inp_samples, out_gallery, out_fft_mag, out_fft_phase, out_ifft],
423
+ )
424
+
425
+ # Set up event listener for the Dropdown component to update the image input
426
+ inp_samples.change(
427
+ update_image_input,
428
+ [state, inp_samples],
429
+ [inp_image, out_gallery, out_fft_mag, out_fft_phase, out_ifft],
430
+ )
431
+
432
+ # Set up events for fft processing
433
+ btn_fft.click(
434
+ onclick_process_fft,
435
+ [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad],
436
+ [out_gallery, out_fft_mag, out_fft_phase],
437
+ )
438
+
439
+ out_fft_mag.clear(
440
+ onclick_process_fft,
441
+ [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad],
442
+ [out_gallery, out_fft_mag, out_fft_phase],
443
+ )
444
+
445
+ out_fft_phase.clear(
446
+ onclick_process_fft,
447
+ [state, inp_image, inp_mask_opacity, inp_invert_mask, inp_pad],
448
+ [out_gallery, out_fft_mag, out_fft_phase],
449
+ )
450
+
451
+ # inp_image.edit(
452
+ # get_fft_images,
453
+ # [state, inp_image],
454
+ # [out_gallery, out_fft_mag, out_fft_phase],
455
+ # )
456
+
457
+ # Set up events for ifft processing
458
+ btn_ifft.click(
459
+ onclick_process_ifft,
460
+ [state, out_fft_mag, out_fft_phase],
461
+ [out_gallery, out_ifft],
462
+ )
463
+
464
+ # out_fft_mag.edit(
465
+ # get_ifft_image,
466
+ # [state, out_fft_mag, out_fft_phase],
467
+ # [out_ifft],
468
+ # )
469
+
470
+ # out_fft_phase.edit(
471
+ # get_ifft_image,
472
+ # [state, out_fft_mag, out_fft_phase],
473
+ # [out_ifft],
474
+ # )
475
+
476
+ if __name__ == "__main__":
477
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ numpy
3
+ Pillow