Anirudh Bhalekar commited on
Commit
b7ab39c
·
1 Parent(s): 86aacc4
Files changed (3) hide show
  1. __pycache__/inference.cpython-311.pyc +0 -0
  2. app.py +49 -5
  3. inference.py +34 -7
__pycache__/inference.cpython-311.pyc CHANGED
Binary files a/__pycache__/inference.cpython-311.pyc and b/__pycache__/inference.cpython-311.pyc differ
 
app.py CHANGED
@@ -4,9 +4,41 @@ import numpy as np
4
  import requests
5
  from PIL import Image
6
  import torch
7
- from inference import predict, random_sample, overlay_images
8
 
9
  def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  with gr.Blocks() as demo:
11
  # Button to select task
12
 
@@ -40,11 +72,23 @@ def main():
40
  overlay_button = gr.Button("Overlay Result", elem_id="overlay-button")
41
  overlay_button.click(fn=overlay_images, inputs=[seismic_image, prediction_image], outputs=[overlay_image])
42
  gr.Markdown("### Post Processing")
43
- post_process = gr.Radio(choices=['Remove', 'Thresholding', 'Closing', 'Opening', 'Canny Edge', 'Gaussian Smoothing', 'Hysteresis'],
44
- value='None', elem_id="post-processing")
45
-
 
 
 
 
 
 
 
 
 
 
 
 
46
  gr.Button("Download Processed Image", elem_id="download-processed-button")
47
-
48
  demo.launch()
49
 
50
 
 
4
  import requests
5
  from PIL import Image
6
  import torch
7
+ from inference import predict, random_sample, overlay_images, post_process
8
 
9
  def main():
10
+
11
+ pp_options = [
12
+ "None",
13
+ "Thresholding",
14
+ "Closing",
15
+ "Opening",
16
+ "Canny Edge",
17
+ "Gaussian Smoothing",
18
+ "Hysteresis"
19
+ ]
20
+
21
+ def update_slider(post_process):
22
+ visibility = [0 for option in pp_options]
23
+ if post_process ==pp_options[0]:
24
+ # None
25
+ pass
26
+ else:
27
+ # Retrieve index of post_process
28
+ assert post_process in pp_options
29
+ index = pp_options.index(post_process)
30
+ visibility[index] = 1
31
+
32
+ ret_updates = []
33
+ for vis in visibility:
34
+ if vis == 1:
35
+ ret_updates.append(gr.update(visible=True))
36
+ else:
37
+ ret_updates.append(gr.update(visible=False))
38
+
39
+ return ret_updates
40
+
41
+
42
  with gr.Blocks() as demo:
43
  # Button to select task
44
 
 
72
  overlay_button = gr.Button("Overlay Result", elem_id="overlay-button")
73
  overlay_button.click(fn=overlay_images, inputs=[seismic_image, prediction_image], outputs=[overlay_image])
74
  gr.Markdown("### Post Processing")
75
+ with gr.Row():
76
+ post_process = gr.Radio(choices=pp_options,
77
+ value='None', elem_id="post-processing", label="Post Processing Method")
78
+ slider_none = gr.Slider(minimum=0, maximum=255, value=128, label="None Value", visible=False)
79
+ slider_thresh = gr.Slider(minimum=0, maximum=255, value=128, label="Threshold Value", visible=False)
80
+ slider_close = gr.Slider(minimum=0, maximum=64, value=32, label="Closing Value", visible=False)
81
+ slider_open = gr.Slider(minimum=0, maximum=64, value=32, label="Opening Value", visible=False)
82
+ slider_canny = gr.Slider(minimum=0, maximum=255, value=128, label="Canny Edge Value", visible=False)
83
+ slider_gauss = gr.Slider(minimum=0, maximum=255, value=128, label="Sigma", visible=False)
84
+ slider_hyst = gr.Slider(minimum=0, maximum=255, value=128, label="Hysteresis Min Value", visible=False)
85
+ post_process.change(
86
+ fn=update_slider,
87
+ inputs=[post_process],
88
+ outputs=[slider_none, slider_thresh, slider_close, slider_open, slider_canny, slider_gauss, slider_hyst]
89
+ )
90
  gr.Button("Download Processed Image", elem_id="download-processed-button")
91
+
92
  demo.launch()
93
 
94
 
inference.py CHANGED
@@ -9,7 +9,7 @@ from huggingface_hub import hf_hub_download
9
  from PIL import Image
10
  import numpy as np
11
  from matplotlib import cm
12
-
13
 
14
  HFACE_FAULTS = "checkpoint-24.pth"
15
  HFACE_FACIES = "checkpoint-49.pth"
@@ -105,9 +105,36 @@ def overlay_images(seismic_image: Image, prediction_image: Image, alpha = 0.5) -
105
  return overlay_image
106
 
107
 
108
- def post_process(processed_prediction_image: Image, prediction_image: Image, method: str = 'None') -> Image:
109
- pass
110
-
111
-
112
- def apply_thresholding(prediction_image: Image) -> Image:
113
- return prediction_image.point(lambda p: p > 128 and 255)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from PIL import Image
10
  import numpy as np
11
  from matplotlib import cm
12
+ from PIL import ImageFilter
13
 
14
  HFACE_FAULTS = "checkpoint-24.pth"
15
  HFACE_FACIES = "checkpoint-49.pth"
 
105
  return overlay_image
106
 
107
 
108
+ def post_process(processed_prediction_image: Image, prediction_image: Image, method: str = 'None', value = None) -> Image:
109
+ if method == 'None':
110
+ return processed_prediction_image
111
+ elif method == 'Thresholding':
112
+ return apply_thresholding(processed_prediction_image)
113
+ elif method == 'Closing':
114
+ return apply_closing(processed_prediction_image, value)
115
+ elif method == 'Opening':
116
+ return apply_opening(processed_prediction_image, value)
117
+ elif method == 'Canny Edge':
118
+ return apply_canny_edge(processed_prediction_image, value)
119
+ elif method == 'Gaussian Smoothing':
120
+ return apply_gaussian_smoothing(processed_prediction_image, value)
121
+ elif method == 'Hysteresis':
122
+ return apply_hysteresis(processed_prediction_image, value)
123
+ else:
124
+ raise ValueError(f"Unknown post-processing method: {method}")
125
+
126
+
127
+ def apply_thresholding(image: Image, value: int) -> Image:
128
+ return image.point(lambda p: p > value and 255)
129
+ def apply_closing(image: Image, value: int) -> Image:
130
+ # Apply closing (dilation followed by erosion)
131
+ return image.filter(ImageFilter.MaxFilter(size=value)).filter(ImageFilter.MinFilter(size=value))
132
+ def apply_opening(image: Image, value: int) -> Image:
133
+ # Apply opening (erosion followed by dilation)
134
+ return image.filter(ImageFilter.MinFilter(size=value)).filter(ImageFilter.MaxFilter(size=value))
135
+ def apply_canny_edge(image: Image, value: int) -> Image:
136
+ return image.filter(ImageFilter.FIND_EDGES)
137
+ def apply_gaussian_smoothing(image: Image, value: int) -> Image:
138
+ return image.filter(ImageFilter.GaussianBlur(radius=value))
139
+ def apply_hysteresis(image: Image, value: int) -> Image:
140
+ return image.point(lambda p: p > value and 255) # Simple thresholding for hysteresis