Sadjad Alikhani committed on
Commit
6e72601
·
verified ·
1 Parent(s): 469d918

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -115
app.py CHANGED
@@ -15,8 +15,9 @@ import matplotlib.pyplot as plt
15
  RAW_PATH = os.path.join("images", "raw")
16
  EMBEDDINGS_PATH = os.path.join("images", "embeddings")
17
 
18
- # Specific values for percentage of data for training
19
  percentage_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
 
20
 
21
  # Custom class to capture print output
22
  class PrintCapture(io.StringIO):
@@ -32,30 +33,18 @@ class PrintCapture(io.StringIO):
32
  return ''.join(self.output)
33
 
34
  # Function to load and display predefined images based on user selection
35
- def display_predefined_images(percentage_idx):
36
  percentage = percentage_values[percentage_idx]
37
- raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_16.png") # Assume complexity 16 for simplicity
38
- embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_16.png")
 
 
39
 
40
  raw_image = Image.open(raw_image_path)
41
  embeddings_image = Image.open(embeddings_image_path)
42
 
43
  return raw_image, embeddings_image
44
 
45
- # Function to create random images for LoS/NLoS classification results
46
- def create_random_image(size=(300, 300)):
47
- random_image = np.random.rand(*size, 3) * 255
48
- return Image.fromarray(random_image.astype('uint8'))
49
-
50
- # Function to load the pre-trained model from your cloned repository
51
- def load_custom_model():
52
- from lwm_model import LWM # Assuming the model is defined in lwm_model.py
53
- model = LWM() # Modify this according to your model initialization
54
- model.eval()
55
- return model
56
-
57
- import importlib.util
58
-
59
  # Function to dynamically load a Python module from a given file path
60
  def load_module_from_path(module_name, file_path):
61
  spec = importlib.util.spec_from_file_location(module_name, file_path)
@@ -80,25 +69,18 @@ def split_dataset(channels, labels, percentage_idx):
80
 
81
  return train_data, test_data, train_labels, test_labels
82
 
83
- # Function to calculate Euclidean distance between a point and a centroid
84
- def euclidean_distance(x, centroid):
85
- return np.linalg.norm(x - centroid)
86
-
87
- import torch
88
-
89
  def classify_based_on_distance(train_data, train_labels, test_data):
90
- # Compute the centroids for the two classes
91
- centroid_0 = train_data[train_labels == 0].mean(dim=0) # Use torch.mean
92
- centroid_1 = train_data[train_labels == 1].mean(dim=0) # Use torch.mean
93
 
94
  predictions = []
95
  for test_point in test_data:
96
- # Compute Euclidean distance between the test point and each centroid
97
- dist_0 = euclidean_distance(test_point, centroid_0)
98
- dist_1 = euclidean_distance(test_point, centroid_1)
99
  predictions.append(0 if dist_0 < dist_1 else 1)
100
 
101
- return torch.tensor(predictions) # Return predictions as a PyTorch tensor
102
 
103
  # Function to generate confusion matrix plot
104
  def plot_confusion_matrix(y_true, y_pred, title):
@@ -115,34 +97,10 @@ def plot_confusion_matrix(y_true, y_pred, title):
115
  plt.savefig(f"{title}.png")
116
  return Image.open(f"{title}.png")
117
 
118
- def identical_train_test_split(output_emb, output_raw, labels, percentage):
119
- N = output_emb.shape[0] # Get the total number of samples
120
-
121
- # Generate the indices for shuffling and splitting
122
- indices = torch.randperm(N) # Randomly shuffle the indices
123
-
124
- # Calculate the split index
125
- split_index = int(N * percentage)
126
-
127
- # Split indices into train and test
128
- train_indices = indices[:split_index] # First 80% for training
129
- test_indices = indices[split_index:] # Remaining 20% for testing
130
-
131
- # Select the same indices from both output_emb and output_raw
132
- train_emb = output_emb[train_indices]
133
- test_emb = output_emb[test_indices]
134
-
135
- train_raw = output_raw[train_indices]
136
- test_raw = output_raw[test_indices]
137
-
138
- train_labels = labels[train_indices]
139
- test_labels = labels[test_indices]
140
-
141
- return train_emb, test_emb, train_raw, test_raw, train_labels, test_labels
142
-
143
  # Store the original working directory when the app starts
144
  original_dir = os.getcwd()
145
 
 
146
  def process_hdf5_file(uploaded_file, percentage_idx):
147
  capture = PrintCapture()
148
  sys.stdout = capture # Redirect print statements to capture
@@ -153,76 +111,45 @@ def process_hdf5_file(uploaded_file, percentage_idx):
153
 
154
  # Step 1: Clone the repository if not already done
155
  if not os.path.exists(model_repo_dir):
156
- print(f"Cloning model repository from {model_repo_url}...")
157
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
158
 
159
- # Step 2: Verify the repository was cloned and change the working directory
160
  repo_work_dir = os.path.join(original_dir, model_repo_dir)
161
  if os.path.exists(repo_work_dir):
162
- os.chdir(repo_work_dir) # Change the working directory only once
163
- print(f"Changed working directory to {os.getcwd()}")
164
- print(f"Directory content: {os.listdir(os.getcwd())}") # Debugging: Check repo content
165
  else:
166
  print(f"Directory {repo_work_dir} does not exist.")
167
  return
168
-
169
- # Step 3: Dynamically load lwm_model.py, input_preprocess.py, and inference.py
170
  lwm_model_path = os.path.join(os.getcwd(), 'lwm_model.py')
171
  input_preprocess_path = os.path.join(os.getcwd(), 'input_preprocess.py')
172
  inference_path = os.path.join(os.getcwd(), 'inference.py')
173
 
174
- # Load lwm_model
175
  lwm_model = load_module_from_path("lwm_model", lwm_model_path)
176
-
177
- # Load input_preprocess
178
  input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
179
-
180
- # Load inference
181
  inference = load_module_from_path("inference", inference_path)
182
 
183
- # Step 4: Load the model from lwm_model module
184
  device = 'cpu'
185
- print(f"Loading the LWM model on {device}...")
186
  model = lwm_model.LWM.from_pretrained(device=device)
187
 
188
- # Step 5: Load the HDF5 file and extract the channels and labels
189
  with h5py.File(uploaded_file.name, 'r') as f:
190
- channels = np.array(f['channels']) # Assuming 'channels' dataset in the HDF5 file
191
- labels = np.array(f['labels']) # Assuming 'labels' dataset in the HDF5 file
192
- print(f"Loaded dataset with {channels.shape[0]} samples.")
193
 
194
- # Step 7: Tokenize the data using the tokenizer from input_preprocess
195
  preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
196
 
197
- # Step 7: Perform inference using the functions from inference.py
198
  output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
199
  output_raw = inference.create_raw_dataset(preprocessed_chs, device)
200
 
201
- print(f"Output Embeddings Shape: {output_emb.shape}")
202
- print(f"Output Raw Shape: {output_raw.shape}")
203
-
204
- train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(output_emb.view(len(output_emb),-1),
205
- output_raw.view(len(output_raw),-1),
206
- labels,
207
- percentage_idx)
208
-
209
- # Step 8: Perform classification using the Euclidean distance for both raw and embeddings
210
- pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
211
- pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)
212
-
213
- # Step 9: Generate confusion matrices for both raw and embeddings
214
- raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
215
- emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")
216
-
217
- return raw_cm_image, emb_cm_image, capture.get_output()
218
 
219
  except Exception as e:
220
  return str(e), str(e), capture.get_output()
221
 
222
  finally:
223
- # Always return to the original working directory after processing
224
  os.chdir(original_dir)
225
- sys.stdout = sys.__stdout__ # Reset print statements
226
 
227
  # Function to handle logic based on whether a file is uploaded or not
228
  def los_nlos_classification(file, percentage_idx):
@@ -231,18 +158,19 @@ def los_nlos_classification(file, percentage_idx):
231
  else:
232
  return display_predefined_images(percentage_idx), None
233
 
234
- # Define the Gradio interface
235
  with gr.Blocks(css="""
236
- .vertical-slider input[type=range] {
237
- writing-mode: bt-lr; /* IE */
238
- -webkit-appearance: slider-vertical; /* WebKit */
239
- width: 8px;
240
- height: 200px;
241
- }
242
  .slider-container {
243
- display: inline-block;
244
- margin-right: 50px;
245
  text-align: center;
 
 
 
 
 
 
 
 
 
246
  }
247
  """) as demo:
248
 
@@ -265,13 +193,17 @@ with gr.Blocks(css="""
265
  with gr.Row():
266
  with gr.Column(elem_id="slider-container"):
267
  gr.Markdown("Percentage of Data for Training")
268
- percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
269
-
270
- with gr.Row():
271
- raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
272
- embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
273
 
274
- percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
 
 
 
 
 
275
 
276
  with gr.Tab("LoS/NLoS Classification Task"):
277
  gr.Markdown("### LoS/NLoS Classification Task")
@@ -281,12 +213,12 @@ with gr.Blocks(css="""
281
  with gr.Row():
282
  with gr.Column(elem_id="slider-container"):
283
  gr.Markdown("Percentage of Data for Training")
284
- percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
285
 
286
- with gr.Row():
287
- raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
288
- embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
289
- output_textbox = gr.Textbox(label="Console Output", lines=10)
290
 
291
  file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
292
  percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
 
15
  RAW_PATH = os.path.join("images", "raw")
16
  EMBEDDINGS_PATH = os.path.join("images", "embeddings")
17
 
18
+ # Specific values for percentage of data for training and task complexity
19
  percentage_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
20
+ complexity_values = [16, 32, 64, 128, 256] # Task complexity values
21
 
22
  # Custom class to capture print output
23
  class PrintCapture(io.StringIO):
 
33
  return ''.join(self.output)
34
 
35
  # Function to load and display predefined images based on user selection
36
def display_predefined_images(percentage_idx, complexity_idx=0):
    """Load the pre-rendered raw/embedding images for a slider selection.

    Args:
        percentage_idx: Index into ``percentage_values`` (training-data
            percentage).
        complexity_idx: Index into ``complexity_values`` (task complexity).
            Defaults to 0, i.e. complexity 16, so callers that pass only
            ``percentage_idx`` (as ``los_nlos_classification`` does) keep
            working and get the complexity that was hard-coded before this
            parameter existed.

    Returns:
        A ``(raw_image, embeddings_image)`` tuple of PIL images.

    Raises:
        IndexError: If an index is outside its value list.
        FileNotFoundError: If the expected PNG does not exist.
    """
    # Slider values may arrive as floats (e.g. 3.0) depending on the Gradio
    # version; int() makes them valid list indices either way.
    percentage = percentage_values[int(percentage_idx)]
    complexity = complexity_values[int(complexity_idx)]

    # Both directories use the same file name for a given selection.
    file_name = f"percentage_{percentage}_complexity_{complexity}.png"

    loaded = []
    for directory in (RAW_PATH, EMBEDDINGS_PATH):
        image = Image.open(os.path.join(directory, file_name))
        # Image.open is lazy and keeps the file open; load() reads the pixel
        # data now so the underlying file handle is released immediately.
        image.load()
        loaded.append(image)

    raw_image, embeddings_image = loaded
    return raw_image, embeddings_image
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  # Function to dynamically load a Python module from a given file path
49
  def load_module_from_path(module_name, file_path):
50
  spec = importlib.util.spec_from_file_location(module_name, file_path)
 
69
 
70
  return train_data, test_data, train_labels, test_labels
71
 
72
+ # Function to classify based on distance to class centroids
 
 
 
 
 
73
  def classify_based_on_distance(train_data, train_labels, test_data):
74
+ centroid_0 = train_data[train_labels == 0].mean(dim=0)
75
+ centroid_1 = train_data[train_labels == 1].mean(dim=0)
 
76
 
77
  predictions = []
78
  for test_point in test_data:
79
+ dist_0 = torch.norm(test_point - centroid_0)
80
+ dist_1 = torch.norm(test_point - centroid_1)
 
81
  predictions.append(0 if dist_0 < dist_1 else 1)
82
 
83
+ return torch.tensor(predictions)
84
 
85
  # Function to generate confusion matrix plot
86
  def plot_confusion_matrix(y_true, y_pred, title):
 
97
  plt.savefig(f"{title}.png")
98
  return Image.open(f"{title}.png")
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # Store the original working directory when the app starts
101
  original_dir = os.getcwd()
102
 
103
+ # Function to process the uploaded HDF5 file for LoS/NLoS classification
104
  def process_hdf5_file(uploaded_file, percentage_idx):
105
  capture = PrintCapture()
106
  sys.stdout = capture # Redirect print statements to capture
 
111
 
112
  # Step 1: Clone the repository if not already done
113
  if not os.path.exists(model_repo_dir):
 
114
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
115
 
116
+ # Step 2: Change working directory
117
  repo_work_dir = os.path.join(original_dir, model_repo_dir)
118
  if os.path.exists(repo_work_dir):
119
+ os.chdir(repo_work_dir)
 
 
120
  else:
121
  print(f"Directory {repo_work_dir} does not exist.")
122
  return
123
+
124
+ # Dynamically load the necessary modules
125
  lwm_model_path = os.path.join(os.getcwd(), 'lwm_model.py')
126
  input_preprocess_path = os.path.join(os.getcwd(), 'input_preprocess.py')
127
  inference_path = os.path.join(os.getcwd(), 'inference.py')
128
 
 
129
  lwm_model = load_module_from_path("lwm_model", lwm_model_path)
 
 
130
  input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
 
 
131
  inference = load_module_from_path("inference", inference_path)
132
 
 
133
  device = 'cpu'
 
134
  model = lwm_model.LWM.from_pretrained(device=device)
135
 
 
136
  with h5py.File(uploaded_file.name, 'r') as f:
137
+ channels = np.array(f['channels'])
138
+ labels = np.array(f['labels'])
 
139
 
 
140
  preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
141
 
 
142
  output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
143
  output_raw = inference.create_raw_dataset(preprocessed_chs, device)
144
 
145
+ return output_emb, output_raw, labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  except Exception as e:
148
  return str(e), str(e), capture.get_output()
149
 
150
  finally:
 
151
  os.chdir(original_dir)
152
+ sys.stdout = sys.__stdout__
153
 
154
  # Function to handle logic based on whether a file is uploaded or not
155
  def los_nlos_classification(file, percentage_idx):
 
158
  else:
159
  return display_predefined_images(percentage_idx), None
160
 
161
+ # Define the Gradio interface with thinner sliders
162
  with gr.Blocks(css="""
 
 
 
 
 
 
163
  .slider-container {
 
 
164
  text-align: center;
165
+ margin-bottom: 20px;
166
+ }
167
+ .image-row {
168
+ justify-content: center;
169
+ margin-top: 10px;
170
+ }
171
+ input[type=range] {
172
+ width: 180px;
173
+ height: 8px;
174
  }
175
  """) as demo:
176
 
 
193
  with gr.Row():
194
  with gr.Column(elem_id="slider-container"):
195
  gr.Markdown("Percentage of Data for Training")
196
+ percentage_slider_bp = gr.Slider(minimum=0, maximum=9, step=1, value=0, label="Training Data (%)", interactive=True)
197
+ with gr.Column(elem_id="slider-container"):
198
+ gr.Markdown("Task Complexity")
199
+ complexity_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, label="Task Complexity", interactive=True)
 
200
 
201
+ with gr.Row(elem_id="image-row"):
202
+ raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
203
+ embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300)
204
+
205
+ percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
206
+ complexity_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
207
 
208
  with gr.Tab("LoS/NLoS Classification Task"):
209
  gr.Markdown("### LoS/NLoS Classification Task")
 
213
  with gr.Row():
214
  with gr.Column(elem_id="slider-container"):
215
  gr.Markdown("Percentage of Data for Training")
216
+ percentage_slider_los = gr.Slider(minimum=0, maximum=9, step=1, value=0, label="Training Data (%)", interactive=True)
217
 
218
+ with gr.Row(elem_id="image-row"):
219
+ raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
220
+ embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
221
+ output_textbox = gr.Textbox(label="Console Output", lines=8, elem_classes="output-box")
222
 
223
  file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
224
  percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])