Sadjad Alikhani committed on
Commit
ef6f553
·
verified ·
1 Parent(s): c061e6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -170
app.py CHANGED
@@ -50,72 +50,50 @@ def display_predefined_images(percentage_idx):
50
 
51
  return raw_image, embeddings_image
52
 
53
- # Updated los_nlos_classification to handle missing outputs properly
54
- def los_nlos_classification(file, percentage_idx):
55
- if file is not None:
56
- raw_cm_image, emb_cm_image, console_output = process_hdf5_file(file, percentage_idx)
57
- return raw_cm_image, emb_cm_image, console_output
58
- else:
59
- raw_image, embeddings_image = display_predefined_images(percentage_idx)
60
- return raw_image, embeddings_image, "No file uploaded. Displaying predefined images."
61
-
62
  # Function to create random images for LoS/NLoS classification results
63
  def create_random_image(size=(300, 300)):
64
  random_image = np.random.rand(*size, 3) * 255
65
  return Image.fromarray(random_image.astype('uint8'))
66
 
67
- # Function to load the pre-trained model from your cloned repository
68
- def load_custom_model():
69
- from lwm_model import LWM # Assuming the model is defined in lwm_model.py
70
- model = LWM() # Modify this according to your model initialization
71
- model.eval()
72
- return model
73
-
74
- import importlib.util
75
-
76
- # Function to dynamically load a Python module from a given file path
77
- def load_module_from_path(module_name, file_path):
78
- spec = importlib.util.spec_from_file_location(module_name, file_path)
79
- module = importlib.util.module_from_spec(spec)
80
- spec.loader.exec_module(module)
81
- return module
82
-
83
  # Function to split dataset into training and test sets based on user selection
84
- def split_dataset(channels, labels, percentage_idx):
85
- percentage = percentage_values[percentage_idx] / 100
86
- num_samples = channels.shape[0]
87
- train_size = int(num_samples * percentage)
88
- print(f'Number of Training Samples: {train_size}')
89
 
90
- indices = np.arange(num_samples)
91
- np.random.shuffle(indices)
92
 
93
- train_idx, test_idx = indices[:train_size], indices[train_size:]
 
 
94
 
95
- train_data, test_data = channels[train_idx], channels[test_idx]
96
- train_labels, test_labels = labels[train_idx], labels[test_idx]
 
97
 
98
- return train_data, test_data, train_labels, test_labels
 
 
 
 
 
99
 
100
- # Function to calculate Euclidean distance between a point and a centroid
101
- def euclidean_distance(x, centroid):
102
- return np.linalg.norm(x - centroid)
103
 
104
- import torch
105
 
 
106
  def classify_based_on_distance(train_data, train_labels, test_data):
107
- # Compute the centroids for the two classes
108
- centroid_0 = train_data[train_labels == 0].mean(dim=0) # Use torch.mean
109
- centroid_1 = train_data[train_labels == 1].mean(dim=0) # Use torch.mean
110
 
111
  predictions = []
112
  for test_point in test_data:
113
- # Compute Euclidean distance between the test point and each centroid
114
- dist_0 = euclidean_distance(test_point, centroid_0)
115
- dist_1 = euclidean_distance(test_point, centroid_1)
116
  predictions.append(0 if dist_0 < dist_1 else 1)
117
 
118
- return torch.tensor(predictions) # Return predictions as a PyTorch tensor
119
 
120
  # Function to generate confusion matrix plot
121
  def plot_confusion_matrix(y_true, y_pred, title):
@@ -132,121 +110,75 @@ def plot_confusion_matrix(y_true, y_pred, title):
132
  plt.savefig(f"{title}.png")
133
  return Image.open(f"{title}.png")
134
 
135
- def identical_train_test_split(output_emb, output_raw, labels, percentage_idx):
136
- N = output_emb.shape[0] # Get the total number of samples
137
-
138
- # Generate the indices for shuffling and splitting
139
- indices = torch.randperm(N) # Randomly shuffle the indices
140
-
141
- # Calculate the split index
142
- split_index = int(N * percentage_values[percentage_idx])
143
- print(f'Training Size: {split_index}')
144
-
145
- # Split indices into train and test
146
- train_indices = indices[:split_index] # First 80% for training
147
- test_indices = indices[split_index:] # Remaining 20% for testing
148
-
149
- # Select the same indices from both output_emb and output_raw
150
- train_emb = output_emb[train_indices]
151
- test_emb = output_emb[test_indices]
152
-
153
- train_raw = output_raw[train_indices]
154
- test_raw = output_raw[test_indices]
155
-
156
- train_labels = labels[train_indices]
157
- test_labels = labels[test_indices]
158
-
159
- return train_emb, test_emb, train_raw, test_raw, train_labels, test_labels
160
-
161
- # Store the original working directory when the app starts
162
- original_dir = os.getcwd()
163
-
164
- def process_hdf5_file(uploaded_file, percentage_idx):
165
  capture = PrintCapture()
166
  sys.stdout = capture # Redirect print statements to capture
167
-
168
  try:
 
 
 
 
 
 
 
169
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
170
  model_repo_dir = "./LWM"
171
 
172
- # Step 1: Clone the repository if not already done
173
  if not os.path.exists(model_repo_dir):
174
  print(f"Cloning model repository from {model_repo_url}...")
175
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
176
 
177
- # Step 2: Verify the repository was cloned and change the working directory
178
- repo_work_dir = os.path.join(original_dir, model_repo_dir)
179
- if os.path.exists(repo_work_dir):
180
- os.chdir(repo_work_dir) # Change the working directory only once
181
- print(f"Changed working directory to {os.getcwd()}")
182
- print(f"Directory content: {os.listdir(os.getcwd())}") # Debugging: Check repo content
183
- else:
184
- print(f"Directory {repo_work_dir} does not exist.")
185
- return
186
-
187
- # Step 3: Dynamically load lwm_model.py, input_preprocess.py, and inference.py
188
- lwm_model_path = os.path.join(os.getcwd(), 'lwm_model.py')
189
- input_preprocess_path = os.path.join(os.getcwd(), 'input_preprocess.py')
190
- inference_path = os.path.join(os.getcwd(), 'inference.py')
191
-
192
- # Load lwm_model
193
- lwm_model = load_module_from_path("lwm_model", lwm_model_path)
194
 
195
- # Load input_preprocess
 
196
  input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
197
-
198
- # Load inference
199
  inference = load_module_from_path("inference", inference_path)
200
 
201
- # Step 4: Load the model from lwm_model module
202
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
203
- print(f"Loading the LWM model on {device}...")
204
  model = lwm_model.LWM.from_pretrained(device=device).to(torch.float32)
205
 
206
- # Step 5: Load the HDF5 file and extract the channels and labels
207
- with h5py.File(uploaded_file.name, 'r') as f:
208
- channels = np.array(f['channels']) # Assuming 'channels' dataset in the HDF5 file
209
- labels = np.array(f['labels']) # Assuming 'labels' dataset in the HDF5 file
210
- print(f"Loaded dataset with {channels.shape[0]} samples.")
211
-
212
- # Step 7: Tokenize the data using the tokenizer from input_preprocess
213
  preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
214
- print(preprocessed_chs[0][0][1])
215
-
216
- # Step 7: Perform inference using the functions from inference.py
217
  output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
218
- #print(f'output_emb:{output_emb[10][0]}')
219
  output_raw = inference.create_raw_dataset(preprocessed_chs, device)
220
- #print(f'output_raw:{output_raw[10][0]}')
221
 
222
  print(f"Output Embeddings Shape: {output_emb.shape}")
223
  print(f"Output Raw Shape: {output_raw.shape}")
224
 
225
- train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(output_emb.view(len(output_emb),-1),
226
- output_raw.view(len(output_raw),-1),
227
- labels,
228
- percentage_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
- # Step 8: Perform classification using the Euclidean distance for both raw and embeddings
231
- print(f'train_data_emb: {train_data_emb.shape}')
232
- print(f'train_labels: {train_labels.shape}')
233
- print(f'test_data_emb: {test_data_emb.shape}')
234
  pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
235
  pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)
236
- print(f'pred_emb: {pred_emb}')
237
- # Step 9: Generate confusion matrices for both raw and embeddings
238
  raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
239
  emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")
240
 
241
- return raw_cm_image, emb_cm_image, capture.get_output()
242
-
243
- except Exception as e:
244
- return str(e), str(e), capture.get_output()
245
-
246
- finally:
247
- # Always return to the original working directory after processing
248
- os.chdir(original_dir)
249
- sys.stdout = sys.__stdout__ # Reset print statements
250
 
251
  # Define the Gradio interface
252
  with gr.Blocks(css="""
@@ -262,56 +194,37 @@ with gr.Blocks(css="""
262
  text-align: center;
263
  }
264
  """) as demo:
265
-
266
- # Contact Section
267
- gr.Markdown("""
268
- <div style="text-align: center;">
269
- <a target="_blank" href="https://www.wi-lab.net">
270
- <img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;">
271
- </a>
272
- <a target="_blank" href="mailto:alikhani@asu.edu" style="margin-left: 10px;">
273
- <img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail" alt="Email">
274
- </a>
275
- </div>
276
- """)
277
-
278
- # Tabs for Beam Prediction and LoS/NLoS Classification
279
- with gr.Tab("Beam Prediction Task"):
280
- gr.Markdown("### Beam Prediction Task")
281
-
282
- with gr.Row():
283
- with gr.Column(elem_id="slider-container"):
284
- gr.Markdown("Percentage of Data for Training")
285
- percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
286
-
287
- with gr.Row():
288
- raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
289
- embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
290
-
291
- percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
292
 
 
293
  with gr.Tab("LoS/NLoS Classification Task"):
294
  gr.Markdown("### LoS/NLoS Classification Task")
295
-
296
  file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
297
 
298
  with gr.Row():
299
- with gr.Column(elem_id="slider-container"):
300
- gr.Markdown("Percentage of Data for Training")
301
- #percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
302
- percentage_dropdown_los = gr.Dropdown(choices=[str(v) for v in percentage_values*10],
303
- value=10,
304
- label="Percentage of Data for Training",
305
- interactive=True)
306
 
307
  with gr.Row():
308
  raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
309
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
310
  output_textbox = gr.Textbox(label="Console Output", lines=10)
311
 
312
- file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_dropdown_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
313
- percentage_dropdown_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_dropdown_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
 
 
 
 
 
 
 
 
314
 
315
  # Launch the app
316
  if __name__ == "__main__":
317
- demo.launch()
 
50
 
51
  return raw_image, embeddings_image
52
 
 
 
 
 
 
 
 
 
 
53
  # Function to create random images for LoS/NLoS classification results
54
  def create_random_image(size=(300, 300)):
55
  random_image = np.random.rand(*size, 3) * 255
56
  return Image.fromarray(random_image.astype('uint8'))
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # Function to split dataset into training and test sets based on user selection
59
+ def identical_train_test_split(output_emb, output_raw, labels, percentage_idx):
60
+ N = output_emb.shape[0] # Get the total number of samples
 
 
 
61
 
62
+ # Generate the indices for shuffling and splitting
63
+ indices = torch.randperm(N) # Randomly shuffle the indices
64
 
65
+ # Calculate the split index
66
+ split_index = int(N * percentage_values[percentage_idx] / 10) # Convert percentage index to percentage value
67
+ print(f'Training Size: {split_index}')
68
 
69
+ # Split indices into train and test
70
+ train_indices = indices[:split_index]
71
+ test_indices = indices[split_index:]
72
 
73
+ # Select the same indices from both output_emb and output_raw
74
+ train_emb = output_emb[train_indices]
75
+ test_emb = output_emb[test_indices]
76
+
77
+ train_raw = output_raw[train_indices]
78
+ test_raw = output_raw[test_indices]
79
 
80
+ train_labels = labels[train_indices]
81
+ test_labels = labels[test_indices]
 
82
 
83
+ return train_emb, test_emb, train_raw, test_raw, train_labels, test_labels
84
 
85
+ # Function to classify test points by Euclidean distance to the nearest class centroid
86
  def classify_based_on_distance(train_data, train_labels, test_data):
87
+ centroid_0 = train_data[train_labels == 0].mean(dim=0)
88
+ centroid_1 = train_data[train_labels == 1].mean(dim=0)
 
89
 
90
  predictions = []
91
  for test_point in test_data:
92
+ dist_0 = torch.norm(test_point - centroid_0)
93
+ dist_1 = torch.norm(test_point - centroid_1)
 
94
  predictions.append(0 if dist_0 < dist_1 else 1)
95
 
96
+ return torch.tensor(predictions)
97
 
98
  # Function to generate confusion matrix plot
99
  def plot_confusion_matrix(y_true, y_pred, title):
 
110
  plt.savefig(f"{title}.png")
111
  return Image.open(f"{title}.png")
112
 
113
+ # Function to run inference on the uploaded file and return the results (to be stored in Gradio state)
114
+ def run_inference(uploaded_file):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  capture = PrintCapture()
116
  sys.stdout = capture # Redirect print statements to capture
117
+
118
  try:
119
+ # Load the HDF5 file and extract channels and labels
120
+ with h5py.File(uploaded_file.name, 'r') as f:
121
+ channels = np.array(f['channels']) # Assuming 'channels' dataset in the HDF5 file
122
+ labels = np.array(f['labels']) # Assuming 'labels' dataset in the HDF5 file
123
+ print(f"Loaded dataset with {channels.shape[0]} samples.")
124
+
125
+ # Run the tokenization and model inference
126
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
127
  model_repo_dir = "./LWM"
128
 
 
129
  if not os.path.exists(model_repo_dir):
130
  print(f"Cloning model repository from {model_repo_url}...")
131
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
132
 
133
+ # Load the model
134
+ lwm_model_path = os.path.join(model_repo_dir, 'lwm_model.py')
135
+ input_preprocess_path = os.path.join(model_repo_dir, 'input_preprocess.py')
136
+ inference_path = os.path.join(model_repo_dir, 'inference.py')
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # Load dynamically
139
+ lwm_model = load_module_from_path("lwm_model", lwm_model_path)
140
  input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
 
 
141
  inference = load_module_from_path("inference", inference_path)
142
 
 
143
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
144
+ print(f"Loading LWM model on {device}...")
145
  model = lwm_model.LWM.from_pretrained(device=device).to(torch.float32)
146
 
147
+ # Preprocess and inference
 
 
 
 
 
 
148
  preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
 
 
 
149
  output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
 
150
  output_raw = inference.create_raw_dataset(preprocessed_chs, device)
 
151
 
152
  print(f"Output Embeddings Shape: {output_emb.shape}")
153
  print(f"Output Raw Shape: {output_raw.shape}")
154
 
155
+ return output_emb, output_raw, labels, capture.get_output()
156
+
157
+ except Exception as e:
158
+ return None, None, None, str(e)
159
+
160
+ finally:
161
+ sys.stdout = sys.__stdout__ # Reset print statements
162
+
163
+ # Function to handle classification after inference (using Gradio state)
164
+ def los_nlos_classification(output_emb, output_raw, labels, percentage_idx):
165
+ if output_emb is not None and output_raw is not None:
166
+ train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(
167
+ output_emb.view(len(output_emb), -1),
168
+ output_raw.view(len(output_raw), -1),
169
+ labels,
170
+ percentage_idx
171
+ )
172
 
 
 
 
 
173
  pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
174
  pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)
175
+
 
176
  raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
177
  emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")
178
 
179
+ return raw_cm_image, emb_cm_image, "Classification successful"
180
+
181
+ return create_random_image(), create_random_image(), "No valid inference outputs"
 
 
 
 
 
 
182
 
183
  # Define the Gradio interface
184
  with gr.Blocks(css="""
 
194
  text-align: center;
195
  }
196
  """) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
+ # Tab for LoS/NLoS Classification
199
  with gr.Tab("LoS/NLoS Classification Task"):
200
  gr.Markdown("### LoS/NLoS Classification Task")
201
+
202
  file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
203
 
204
  with gr.Row():
205
+ percentage_dropdown_los = gr.Dropdown(
206
+ choices=[str(v) for v in percentage_values * 10],
207
+ value=10,
208
+ label="Percentage of Data for Training",
209
+ interactive=True
210
+ )
 
211
 
212
  with gr.Row():
213
  raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
214
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
215
  output_textbox = gr.Textbox(label="Console Output", lines=10)
216
 
217
+ # Process file upload to run inference
218
+ inference_output = gr.State()
219
+ file_input.upload(run_inference, inputs=file_input, outputs=inference_output)
220
+
221
+ # Handle dropdown change for classification
222
+ percentage_dropdown_los.change(
223
+ fn=los_nlos_classification,
224
+ inputs=[inference_output['output_emb'], inference_output['output_raw'], inference_output['labels'], percentage_dropdown_los],
225
+ outputs=[raw_img_los, embeddings_img_los, output_textbox]
226
+ )
227
 
228
  # Launch the app
229
  if __name__ == "__main__":
230
+ demo.launch()