Sadjad Alikhani committed on
Commit bc8a6bd · verified · 1 Parent(s): 6c2d844

Update app.py

Files changed (1)
  1. app.py +56 -122
app.py CHANGED
@@ -47,15 +47,6 @@ def create_random_image(size=(300, 300)):
     random_image = np.random.rand(*size, 3) * 255
     return Image.fromarray(random_image.astype('uint8'))
 
-# Function to load the pre-trained model from your cloned repository
-def load_custom_model():
-    from lwm_model import LWM  # Assuming the model is defined in lwm_model.py
-    model = LWM()  # Modify this according to your model initialization
-    model.eval()
-    return model
-
-import importlib.util
-
 # Function to dynamically load a Python module from a given file path
 def load_module_from_path(module_name, file_path):
     spec = importlib.util.spec_from_file_location(module_name, file_path)
@@ -68,42 +59,19 @@ def split_dataset(channels, labels, percentage_idx):
     percentage = percentage_values[percentage_idx] / 100
     num_samples = channels.shape[0]
     train_size = int(num_samples * percentage)
-    print(f'Number of Training Samples: {train_size}')
-
     indices = np.arange(num_samples)
     np.random.shuffle(indices)
 
     train_idx, test_idx = indices[:train_size], indices[train_size:]
-
     train_data, test_data = channels[train_idx], channels[test_idx]
     train_labels, test_labels = labels[train_idx], labels[test_idx]
 
     return train_data, test_data, train_labels, test_labels
 
-# Function to calculate Euclidean distance between a point and a centroid
-def euclidean_distance(x, centroid):
-    return np.linalg.norm(x - centroid)
-
-import torch
-
-def classify_based_on_distance(train_data, train_labels, test_data):
-    # Compute the centroids for the two classes
-    centroid_0 = train_data[train_labels == 0].mean(dim=0)  # Use torch.mean
-    centroid_1 = train_data[train_labels == 1].mean(dim=0)  # Use torch.mean
-
-    predictions = []
-    for test_point in test_data:
-        # Compute Euclidean distance between the test point and each centroid
-        dist_0 = euclidean_distance(test_point, centroid_0)
-        dist_1 = euclidean_distance(test_point, centroid_1)
-        predictions.append(0 if dist_0 < dist_1 else 1)
-
-    return torch.tensor(predictions)  # Return predictions as a PyTorch tensor
-
 # Function to generate confusion matrix plot
 def plot_confusion_matrix(y_true, y_pred, title):
     cm = confusion_matrix(y_true, y_pred)
-    plt.figure(figsize=(5, 5))
+    plt.figure(figsize=(4, 4))
     plt.imshow(cm, cmap='Blues')
     plt.title(title)
     plt.xlabel('Predicted')
@@ -116,101 +84,76 @@ def plot_confusion_matrix(y_true, y_pred, title):
     return Image.open(f"{title}.png")
 
 def identical_train_test_split(output_emb, output_raw, labels, percentage):
-    N = output_emb.shape[0]  # Get the total number of samples
-
-    # Generate the indices for shuffling and splitting
-    indices = torch.randperm(N)  # Randomly shuffle the indices
-
-    # Calculate the split index
+    N = output_emb.shape[0]
+    indices = torch.randperm(N)
     split_index = int(N * percentage)
-
-    # Split indices into train and test
-    train_indices = indices[:split_index]  # First 80% for training
-    test_indices = indices[split_index:]  # Remaining 20% for testing
-
-    # Select the same indices from both output_emb and output_raw
+    train_indices = indices[:split_index]
+    test_indices = indices[split_index:]
     train_emb = output_emb[train_indices]
     test_emb = output_emb[test_indices]
-
     train_raw = output_raw[train_indices]
     test_raw = output_raw[test_indices]
-
     train_labels = labels[train_indices]
     test_labels = labels[test_indices]
 
     return train_emb, test_emb, train_raw, test_raw, train_labels, test_labels
 
+# Function to classify test data based on distance to class centroids
+def classify_based_on_distance(train_data, train_labels, test_data):
+    centroid_0 = train_data[train_labels == 0].mean(dim=0)
+    centroid_1 = train_data[train_labels == 1].mean(dim=0)
+
+    predictions = []
+    for test_point in test_data:
+        dist_0 = torch.norm(test_point - centroid_0)
+        dist_1 = torch.norm(test_point - centroid_1)
+        predictions.append(0 if dist_0 < dist_1 else 1)
+
+    return torch.tensor(predictions)
+
 # Store the original working directory when the app starts
 original_dir = os.getcwd()
 
 def process_hdf5_file(uploaded_file, percentage_idx):
     capture = PrintCapture()
-    sys.stdout = capture  # Redirect print statements to capture
+    sys.stdout = capture
 
     try:
         model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
         model_repo_dir = "./LWM"
-
-        # Step 1: Clone the repository if not already done
         if not os.path.exists(model_repo_dir):
-            print(f"Cloning model repository from {model_repo_url}...")
             subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
-
-        # Step 2: Verify the repository was cloned and change the working directory
         repo_work_dir = os.path.join(original_dir, model_repo_dir)
         if os.path.exists(repo_work_dir):
-            os.chdir(repo_work_dir)  # Change the working directory only once
-            print(f"Changed working directory to {os.getcwd()}")
-            print(f"Directory content: {os.listdir(os.getcwd())}")  # Debugging: Check repo content
-        else:
-            print(f"Directory {repo_work_dir} does not exist.")
-            return
-
-        # Step 3: Dynamically load lwm_model.py, input_preprocess.py, and inference.py
+            os.chdir(repo_work_dir)
         lwm_model_path = os.path.join(os.getcwd(), 'lwm_model.py')
         input_preprocess_path = os.path.join(os.getcwd(), 'input_preprocess.py')
         inference_path = os.path.join(os.getcwd(), 'inference.py')
 
-        # Load lwm_model
         lwm_model = load_module_from_path("lwm_model", lwm_model_path)
-
-        # Load input_preprocess
        input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
-
-        # Load inference
         inference = load_module_from_path("inference", inference_path)
 
-        # Step 4: Load the model from lwm_model module
         device = 'cpu'
-        print(f"Loading the LWM model on {device}...")
         model = lwm_model.LWM.from_pretrained(device=device)
 
-        # Step 5: Load the HDF5 file and extract the channels and labels
         with h5py.File(uploaded_file.name, 'r') as f:
-            channels = np.array(f['channels'])  # Assuming 'channels' dataset in the HDF5 file
-            labels = np.array(f['labels'])  # Assuming 'labels' dataset in the HDF5 file
-            print(f"Loaded dataset with {channels.shape[0]} samples.")
+            channels = np.array(f['channels'])
+            labels = np.array(f['labels'])
 
-        # Step 7: Tokenize the data using the tokenizer from input_preprocess
         preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
-
-        # Step 7: Perform inference using the functions from inference.py
         output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
         output_raw = inference.create_raw_dataset(preprocessed_chs, device)
 
-        print(f"Output Embeddings Shape: {output_emb.shape}")
-        print(f"Output Raw Shape: {output_raw.shape}")
-
-        train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(output_emb.view(len(output_emb),-1),
-                                                                                                                              output_raw.view(len(output_raw),-1),
-                                                                                                                              labels,
-                                                                                                                              percentage_idx)
+        train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(
+            output_emb.view(len(output_emb),-1),
+            output_raw.view(len(output_raw),-1),
+            labels,
+            percentage_idx)
 
-        # Step 8: Perform classification using the Euclidean distance for both raw and embeddings
         pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
         pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)
 
-        # Step 9: Generate confusion matrices for both raw and embeddings
         raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
         emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")
 
@@ -220,73 +163,64 @@ def process_hdf5_file(uploaded_file, percentage_idx):
         return str(e), str(e), capture.get_output()
 
     finally:
-        # Always return to the original working directory after processing
         os.chdir(original_dir)
-        sys.stdout = sys.__stdout__  # Reset print statements
+        sys.stdout = sys.__stdout__
 
-# Function to handle logic based on whether a file is uploaded or not
 def los_nlos_classification(file, percentage_idx):
     if file is not None:
         return process_hdf5_file(file, percentage_idx)
     else:
         return display_predefined_images(percentage_idx), None
 
-# Define the Gradio interface
+# Define the Gradio interface with a compact, minimal layout
 with gr.Blocks(css="""
-    .vertical-slider input[type=range] {
-        writing-mode: bt-lr; /* IE */
-        -webkit-appearance: slider-vertical; /* WebKit */
-        width: 8px;
-        height: 200px;
-    }
     .slider-container {
-        display: inline-block;
-        margin-right: 50px;
         text-align: center;
+        margin-bottom: 20px;
+    }
+    .image-row {
+        justify-content: center;
+        margin-top: 10px;
+    }
+    .output-box {
+        max-width: 600px;
+        margin: 0 auto;
     }
 """) as demo:
 
     # Contact Section
-    gr.Markdown(
-        """
-        ## Contact
-        <div style="display: flex; align-items: center;">
-            <a target="_blank" href="https://www.wi-lab.net"><img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;"></a>&nbsp;&nbsp;
-            <a target="_blank" href="mailto:alikhani@asu.edu"><img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail" alt="Email"></a>&nbsp;&nbsp;
+    gr.Markdown("""
+    <div style="text-align: center;">
+        <a target="_blank" href="https://www.wi-lab.net">
+            <img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;">
+        </a>
+        <a target="_blank" href="mailto:alikhani@asu.edu" style="margin-left: 10px;">
+            <img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail" alt="Email">
+        </a>
     </div>
-        """
-    )
+    """)
 
     # Tabs for Beam Prediction and LoS/NLoS Classification
     with gr.Tab("Beam Prediction Task"):
         gr.Markdown("### Beam Prediction Task")
-
         with gr.Row():
             with gr.Column(elem_id="slider-container"):
-                gr.Markdown("Percentage of Data for Training")
-                percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
-
-        with gr.Row():
-            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
-            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
-
+                percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, label="Training Data (%)")
+        with gr.Row(elem_id="image-row"):
+            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
+            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300)
         percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
 
     with gr.Tab("LoS/NLoS Classification Task"):
         gr.Markdown("### LoS/NLoS Classification Task")
-
         file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
-
         with gr.Row():
             with gr.Column(elem_id="slider-container"):
-                gr.Markdown("Percentage of Data for Training")
-                percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
-
-        with gr.Row():
-            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
-            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
-            output_textbox = gr.Textbox(label="Console Output", lines=10)
-
+                percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, label="Training Data (%)")
+        with gr.Row(elem_id="image-row"):
+            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
+            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
+        output_textbox = gr.Textbox(label="Console Output", lines=8, elem_classes="output-box")
         file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
         percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
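A few notes on the patterns this diff touches follow, each with a small self-contained sketch; anything not visible in the diff itself is an assumption. First, the `load_module_from_path` helper that both revisions keep follows the standard `importlib` recipe for loading a module from an explicit file path. A minimal version, with guard clauses added here as an assumption (the diff only shows the `spec_from_file_location` line), might look like:

```python
# Sketch of the importlib recipe behind load_module_from_path.
# The None checks are an assumption added for robustness.
import importlib.util

def load_module_from_path(module_name, file_path):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # Execute the module in its own namespace
    return module
```

Loading the cloned repository's `lwm_model.py`, `input_preprocess.py`, and `inference.py` this way avoids having to put the repository on `sys.path`.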
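The relocated `classify_based_on_distance` is a two-class nearest-centroid classifier; the commit drops the NumPy `euclidean_distance` helper in favor of direct `torch.norm` calls. The explicit loop can also be written without iteration; a hedged sketch with synthetic shapes (not the committed code):

```python
# Vectorized nearest-centroid classification, equivalent in logic to the
# loop in classify_based_on_distance. All data below is synthetic.
import torch

def nearest_centroid_predict(train_data, train_labels, test_data):
    centroids = torch.stack([
        train_data[train_labels == 0].mean(dim=0),  # class-0 centroid
        train_data[train_labels == 1].mean(dim=0),  # class-1 centroid
    ])
    # (num_test, 2) matrix of Euclidean distances; argmin picks the closer centroid
    return torch.cdist(test_data, centroids).argmin(dim=1)

if __name__ == "__main__":
    train = torch.randn(100, 16)
    labels = (torch.rand(100) > 0.5).long()
    test = torch.randn(20, 16)
    print(nearest_centroid_predict(train, labels, test))
```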
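`identical_train_test_split` draws a single `torch.randperm` and indexes both the embedding and raw tensors with it, so each test sample refers to the same physical channel in both representations. A compact equivalent (the seed is an assumption added for reproducibility; the committed code is unseeded):

```python
# Sketch of the paired-split idea: one permutation drives both tensors,
# keeping raw/embedding rows aligned. Seeding is an assumption.
import torch

def paired_split(emb, raw, labels, fraction, seed=0):
    g = torch.Generator().manual_seed(seed)
    idx = torch.randperm(emb.shape[0], generator=g)
    cut = int(emb.shape[0] * fraction)
    tr, te = idx[:cut], idx[cut:]
    return emb[tr], emb[te], raw[tr], raw[te], labels[tr], labels[te]
```

One caveat worth flagging: the call site passes `percentage_idx` (the 0–4 slider value) directly as the `percentage` argument, whereas `split_dataset` first maps the index through `percentage_values[percentage_idx] / 100`. If `percentage` is meant to be a fraction, the index may need the same mapping.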
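The classification path expects an HDF5 file containing two datasets named `channels` and `labels`. A sketch that writes a file in that layout, useful for exercising the tab locally (the shapes are illustrative assumptions, not the real channel dimensions):

```python
# Build a demo .h5 file in the layout process_hdf5_file reads
# ('channels' and 'labels' datasets). Shapes are assumptions.
import h5py
import numpy as np

def write_demo_file(path="demo.h5", n_samples=32):
    with h5py.File(path, "w") as f:
        f.create_dataset("channels",
                         data=np.random.randn(n_samples, 32, 32).astype(np.float32))
        f.create_dataset("labels",
                         data=np.random.randint(0, 2, size=n_samples))
    return path
```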
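The clone-then-`os.chdir` sequence is protected by `finally: os.chdir(original_dir)`, so the app always returns to its starting directory even when inference fails. The same guarantee can be packaged as a context manager; a sketch of that alternative (not the committed approach):

```python
# Context-manager form of the chdir/restore pattern that
# process_hdf5_file implements with try/finally.
import os
from contextlib import contextmanager

@contextmanager
def working_directory(path):
    prev = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(prev)  # restore even if the body raises

# Usage: with working_directory("./LWM"): ...
```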
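Finally, the Blocks layout wires both the file upload and the slider to the same handler, so either event re-runs the classification and refreshes all three outputs. The pattern in isolation (a sketch with component options trimmed; the handler body is a stand-in):

```python
# Event wiring as used in the Blocks layout: one handler, two triggers
# (file change and slider change), shared outputs.
import gradio as gr

def handler(file, idx):
    # Stand-in for los_nlos_classification
    return f"file={getattr(file, 'name', None)}, idx={idx}"

with gr.Blocks() as demo:
    file_input = gr.File(file_types=[".h5"])
    slider = gr.Slider(minimum=0, maximum=4, step=1, value=0)
    out = gr.Textbox()
    file_input.change(fn=handler, inputs=[file_input, slider], outputs=[out])
    slider.change(fn=handler, inputs=[file_input, slider], outputs=[out])

# demo.launch()  # launching is left to the host environment
```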