Anirudh Bhalekar commited on
Commit
a81555c
·
1 Parent(s): 1fc8c87
.env DELETED
@@ -1 +0,0 @@
1
- PYTHONPATH="C:\Users\abhalekar\AppData\Local\Microsoft\WindowsApps\python3.11.exe"
 
 
__pycache__/app.cpython-311.pyc CHANGED
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
 
__pycache__/inference.cpython-311.pyc CHANGED
Binary files a/__pycache__/inference.cpython-311.pyc and b/__pycache__/inference.cpython-311.pyc differ
 
app.py CHANGED
@@ -4,20 +4,24 @@ import numpy as np
4
  import requests
5
  from PIL import Image
6
  import torch
7
- from inference import predict, random_sample
8
 
9
  def main():
10
  with gr.Blocks() as demo:
11
  # Button to select task
12
 
13
  seismic_data = gr.State()
 
 
 
14
  gr.Markdown("## SFM Inference Demo")
 
15
  with gr.Row():
16
  task = gr.Radio(choices=['Fault', 'Facies'], label="Select Task", value='Fault')
17
  gr.Markdown("### Upload your seismic data or sample from dataset")
18
 
19
  with gr.Row():
20
- seismic_image = gr.Image(label="Upload Seismic Data")
21
  prediction_image = gr.Image(label="Prediction Result")
22
 
23
  with gr.Row():
@@ -26,6 +30,26 @@ def main():
26
 
27
  with gr.Row():
28
  predict_button = gr.Button("Run Inference", elem_id="predict-button")
29
- predict_button.click(fn=predict, inputs=[seismic_data, task], outputs=[prediction_image])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- demo.launch()
 
 
4
  import requests
5
  from PIL import Image
6
  import torch
7
+ from inference import predict, random_sample, overlay_images
8
 
9
  def main():
10
  with gr.Blocks() as demo:
11
  # Button to select task
12
 
13
  seismic_data = gr.State()
14
+ prediction_data = gr.State()
15
+ processed_prediction_data = gr.State()
16
+
17
  gr.Markdown("## SFM Inference Demo")
18
+ gr.Markdown("### Select a task and run inference on seismic data")
19
  with gr.Row():
20
  task = gr.Radio(choices=['Fault', 'Facies'], label="Select Task", value='Fault')
21
  gr.Markdown("### Upload your seismic data or sample from dataset")
22
 
23
  with gr.Row():
24
+ seismic_image = gr.Image(label="Seismic Data")
25
  prediction_image = gr.Image(label="Prediction Result")
26
 
27
  with gr.Row():
 
30
 
31
  with gr.Row():
32
  predict_button = gr.Button("Run Inference", elem_id="predict-button")
33
+ predict_button.click(fn=predict, inputs=[seismic_data, task], outputs=[prediction_image, prediction_data])
34
+ processed_prediction_data = prediction_data
35
+
36
+ with gr.Row():
37
+ overlay_image = gr.Image(label="Overlay Result")
38
+ with gr.Column():
39
+ gr.Markdown("### Overlay Seismic Data with Prediction Result")
40
+ overlay_button = gr.Button("Overlay Result", elem_id="overlay-button")
41
+ overlay_button.click(fn=overlay_images, inputs=[seismic_image, prediction_image], outputs=[overlay_image])
42
+ gr.Markdown("### Post Processing")
43
+ post_process = gr.Radio(choices=['Remove', 'Thresholding', 'Closing', 'Opening', 'Canny Edge', 'Gaussian Smoothing', 'Hysteresis'],
44
+ value='None', elem_id="post-processing")
45
+
46
+ post_process.change(fn=show_slider, inputs=[post_process], outputs=[])
47
+
48
+
49
+ gr.Button("Download Processed Image", elem_id="download-processed-button")
50
+
51
+ demo.launch()
52
+
53
 
54
+ if __name__ == "__main__":
55
+ main()
inference.py CHANGED
@@ -67,11 +67,11 @@ def predict(seismic: torch.Tensor, task='Fault', model_type='vit_large_patch16',
67
  output = output.argmax(dim=0)
68
  output = output.detach().cpu().numpy()
69
 
70
- print("Model output shape:", output.shape)
71
- output = output/ output.max() # Normalize output to [0, 1] range
72
  # output is numpy 2d array - convert to pil RGB image
73
- output = Image.fromarray((output * 255).astype(np.uint8)).convert("RGB")
74
- return output
 
75
 
76
 
77
  def random_sample(task = 'Fault', data_path = None, batch_size=1, num_workers=0):
@@ -91,4 +91,44 @@ def random_sample(task = 'Fault', data_path = None, batch_size=1, num_workers=0)
91
  seis_image = (seis_image - seis_image.min()) / (seis_image.max() - seis_image.min()) # Normalize to [0, 1] range
92
  seis_image = Image.fromarray(np.uint8(cm.seismic(seis_image) * 255)) # Convert to PIL Image
93
 
94
- return seis_image, seis
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  output = output.argmax(dim=0)
68
  output = output.detach().cpu().numpy()
69
 
70
+ output_image = output/ output.max() # Normalize output to [0, 1] range
 
71
  # output is numpy 2d array - convert to pil RGB image
72
+ output_image = Image.fromarray((output_image * 255).astype(np.uint8)).convert("RGB")
73
+
74
+ return output_image, output
75
 
76
 
77
  def random_sample(task = 'Fault', data_path = None, batch_size=1, num_workers=0):
 
91
  seis_image = (seis_image - seis_image.min()) / (seis_image.max() - seis_image.min()) # Normalize to [0, 1] range
92
  seis_image = Image.fromarray(np.uint8(cm.seismic(seis_image) * 255)) # Convert to PIL Image
93
 
94
+ return seis_image, seis
95
+
96
+
97
+ def overlay_images(seismic_image: Image, prediction_image: Image, alpha = 0.5) -> Image:
98
+ # Create an overlay of the predicted facies/faults on the original seismic image
99
+ prediction_image = Image.fromarray(np.array(prediction_image).astype(np.uint8)).convert("RGBA")
100
+ seismic_image = Image.fromarray(np.array(seismic_image).astype(np.uint8)).convert("RGBA")
101
+
102
+ prediction_image.putalpha(int(255 * alpha)) # Set alpha for overlay
103
+
104
+ overlay_image = Image.alpha_composite(seismic_image, prediction_image)
105
+ return overlay_image
106
+
107
+
108
+ def post_process(processed_prediction_image: Image, prediction_image: Image, method: str = 'None') -> Image:
109
+ if method == 'Remove':
110
+ # Remove the predicted regions from the original image
111
+ output = prediction_image.copy()
112
+ elif method == 'Thresholding':
113
+ # Apply thresholding to the predicted image
114
+ output = prediction_image.point(lambda p: p > 128 and 255)
115
+ elif method == 'Closing':
116
+ # Apply closing operation to the predicted image
117
+ output = prediction_image.filter(ImageFilter.MaxFilter(size=5))
118
+ elif method == 'Opening':
119
+ # Apply opening operation to the predicted image
120
+ output = prediction_image.filter(ImageFilter.MinFilter(size=5))
121
+ elif method == 'Canny Edge':
122
+ # Apply Canny edge detection to the predicted image
123
+ output = prediction_image.filter(ImageFilter.FIND_EDGES)
124
+ elif method == 'Gaussian Smoothing':
125
+ # Apply Gaussian smoothing to the predicted image
126
+ output = prediction_image.filter(ImageFilter.GaussianBlur(radius=2))
127
+ elif method == 'Hysteresis':
128
+ # Apply hysteresis thresholding to the predicted image
129
+ output = prediction_image.point(lambda p: p > 128 and 255)
130
+ return output
131
+
132
+
133
+ def apply_thresholding(prediction_image: Image) -> Image:
134
+ return prediction_image.point(lambda p: p > 128 and 255)
test.ipynb CHANGED
@@ -2,39 +2,10 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "aaadf81b",
7
  "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stderr",
11
- "output_type": "stream",
12
- "text": [
13
- "C:\\Users\\abhalekar\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python311\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
- " from .autonotebook import tqdm as notebook_tqdm\n"
15
- ]
16
- },
17
- {
18
- "name": "stdout",
19
- "output_type": "stream",
20
- "text": [
21
- "* Running on local URL: http://127.0.0.1:7860\n",
22
- "* To create a public link, set `share=True` in `launch()`.\n"
23
- ]
24
- },
25
- {
26
- "data": {
27
- "text/html": [
28
- "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
29
- ],
30
- "text/plain": [
31
- "<IPython.core.display.HTML object>"
32
- ]
33
- },
34
- "metadata": {},
35
- "output_type": "display_data"
36
- }
37
- ],
38
  "source": [
39
  "import app\n",
40
  "\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "aaadf81b",
7
  "metadata": {},
8
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  "source": [
10
  "import app\n",
11
  "\n",