Sadjad Alikhani committed
Commit 469d918 · verified · 1 Parent(s): bc8a6bd

Update app.py

Files changed (1)
  1. app.py +115 -48
app.py CHANGED
@@ -16,7 +16,7 @@ RAW_PATH = os.path.join("images", "raw")
 EMBEDDINGS_PATH = os.path.join("images", "embeddings")
 
 # Specific values for percentage of data for training
-percentage_values = [10, 30, 50, 70, 100]
+percentage_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
 
 # Custom class to capture print output
 class PrintCapture(io.StringIO):
@@ -47,6 +47,15 @@ def create_random_image(size=(300, 300)):
     random_image = np.random.rand(*size, 3) * 255
     return Image.fromarray(random_image.astype('uint8'))
 
+# Function to load the pre-trained model from your cloned repository
+def load_custom_model():
+    from lwm_model import LWM  # Assuming the model is defined in lwm_model.py
+    model = LWM()  # Modify this according to your model initialization
+    model.eval()
+    return model
+
+import importlib.util
+
 # Function to dynamically load a Python module from a given file path
 def load_module_from_path(module_name, file_path):
     spec = importlib.util.spec_from_file_location(module_name, file_path)
@@ -59,19 +68,42 @@ def split_dataset(channels, labels, percentage_idx):
     percentage = percentage_values[percentage_idx] / 100
     num_samples = channels.shape[0]
     train_size = int(num_samples * percentage)
+    print(f'Number of Training Samples: {train_size}')
+
     indices = np.arange(num_samples)
     np.random.shuffle(indices)
 
     train_idx, test_idx = indices[:train_size], indices[train_size:]
+
     train_data, test_data = channels[train_idx], channels[test_idx]
     train_labels, test_labels = labels[train_idx], labels[test_idx]
 
     return train_data, test_data, train_labels, test_labels
 
+# Function to calculate Euclidean distance between a point and a centroid
+def euclidean_distance(x, centroid):
+    return np.linalg.norm(x - centroid)
+
+import torch
+
+def classify_based_on_distance(train_data, train_labels, test_data):
+    # Compute the centroids for the two classes
+    centroid_0 = train_data[train_labels == 0].mean(dim=0)  # Use torch.mean
+    centroid_1 = train_data[train_labels == 1].mean(dim=0)  # Use torch.mean
+
+    predictions = []
+    for test_point in test_data:
+        # Compute the Euclidean distance between the test point and each centroid
+        dist_0 = euclidean_distance(test_point, centroid_0)
+        dist_1 = euclidean_distance(test_point, centroid_1)
+        predictions.append(0 if dist_0 < dist_1 else 1)
+
+    return torch.tensor(predictions)  # Return predictions as a PyTorch tensor
+
 # Function to generate confusion matrix plot
 def plot_confusion_matrix(y_true, y_pred, title):
     cm = confusion_matrix(y_true, y_pred)
-    plt.figure(figsize=(4, 4))
+    plt.figure(figsize=(5, 5))
     plt.imshow(cm, cmap='Blues')
     plt.title(title)
     plt.xlabel('Predicted')
@@ -84,76 +116,101 @@ def plot_confusion_matrix(y_true, y_pred, title):
     return Image.open(f"{title}.png")
 
 def identical_train_test_split(output_emb, output_raw, labels, percentage):
-    N = output_emb.shape[0]
-    indices = torch.randperm(N)
+    N = output_emb.shape[0]  # Get the total number of samples
+
+    # Generate the indices for shuffling and splitting
+    indices = torch.randperm(N)  # Randomly shuffle the indices
+
+    # Calculate the split index
     split_index = int(N * percentage)
-    train_indices = indices[:split_index]
-    test_indices = indices[split_index:]
+
+    # Split indices into train and test
+    train_indices = indices[:split_index]  # First `percentage` fraction for training
+    test_indices = indices[split_index:]  # Remaining fraction for testing
+
+    # Select the same indices from both output_emb and output_raw
     train_emb = output_emb[train_indices]
     test_emb = output_emb[test_indices]
+
     train_raw = output_raw[train_indices]
     test_raw = output_raw[test_indices]
+
     train_labels = labels[train_indices]
     test_labels = labels[test_indices]
 
     return train_emb, test_emb, train_raw, test_raw, train_labels, test_labels
 
-# Function to classify test data based on distance to class centroids
-def classify_based_on_distance(train_data, train_labels, test_data):
-    centroid_0 = train_data[train_labels == 0].mean(dim=0)
-    centroid_1 = train_data[train_labels == 1].mean(dim=0)
-
-    predictions = []
-    for test_point in test_data:
-        dist_0 = torch.norm(test_point - centroid_0)
-        dist_1 = torch.norm(test_point - centroid_1)
-        predictions.append(0 if dist_0 < dist_1 else 1)
-
-    return torch.tensor(predictions)
-
 # Store the original working directory when the app starts
 original_dir = os.getcwd()
 
 def process_hdf5_file(uploaded_file, percentage_idx):
     capture = PrintCapture()
-    sys.stdout = capture
+    sys.stdout = capture  # Redirect print statements to capture
 
     try:
         model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
         model_repo_dir = "./LWM"
+
+        # Step 1: Clone the repository if not already done
         if not os.path.exists(model_repo_dir):
+            print(f"Cloning model repository from {model_repo_url}...")
             subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
+
+        # Step 2: Verify the repository was cloned and change the working directory
         repo_work_dir = os.path.join(original_dir, model_repo_dir)
         if os.path.exists(repo_work_dir):
-            os.chdir(repo_work_dir)
+            os.chdir(repo_work_dir)  # Change the working directory only once
+            print(f"Changed working directory to {os.getcwd()}")
+            print(f"Directory content: {os.listdir(os.getcwd())}")  # Debugging: check repo content
+        else:
+            print(f"Directory {repo_work_dir} does not exist.")
+            return
+
+        # Step 3: Dynamically load lwm_model.py, input_preprocess.py, and inference.py
         lwm_model_path = os.path.join(os.getcwd(), 'lwm_model.py')
         input_preprocess_path = os.path.join(os.getcwd(), 'input_preprocess.py')
         inference_path = os.path.join(os.getcwd(), 'inference.py')
 
+        # Load lwm_model
         lwm_model = load_module_from_path("lwm_model", lwm_model_path)
+
+        # Load input_preprocess
         input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
+
+        # Load inference
        inference = load_module_from_path("inference", inference_path)
 
+        # Step 4: Load the model from the lwm_model module
         device = 'cpu'
+        print(f"Loading the LWM model on {device}...")
         model = lwm_model.LWM.from_pretrained(device=device)
 
+        # Step 5: Load the HDF5 file and extract the channels and labels
         with h5py.File(uploaded_file.name, 'r') as f:
-            channels = np.array(f['channels'])
-            labels = np.array(f['labels'])
+            channels = np.array(f['channels'])  # Assuming a 'channels' dataset in the HDF5 file
+            labels = np.array(f['labels'])  # Assuming a 'labels' dataset in the HDF5 file
+        print(f"Loaded dataset with {channels.shape[0]} samples.")
 
+        # Step 6: Tokenize the data using the tokenizer from input_preprocess
         preprocessed_chs = input_preprocess.tokenizer(manual_data=channels)
+
+        # Step 7: Perform inference using the functions from inference.py
         output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
         output_raw = inference.create_raw_dataset(preprocessed_chs, device)
 
-        train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(
-            output_emb.view(len(output_emb),-1),
-            output_raw.view(len(output_raw),-1),
-            labels,
-            percentage_idx)
+        print(f"Output Embeddings Shape: {output_emb.shape}")
+        print(f"Output Raw Shape: {output_raw.shape}")
+
+        train_data_emb, test_data_emb, train_data_raw, test_data_raw, train_labels, test_labels = identical_train_test_split(output_emb.view(len(output_emb),-1),
+                                                                                                                              output_raw.view(len(output_raw),-1),
+                                                                                                                              labels,
+                                                                                                                              percentage_idx)
 
+        # Step 8: Perform classification using the Euclidean distance for both raw and embeddings
         pred_raw = classify_based_on_distance(train_data_raw, train_labels, test_data_raw)
         pred_emb = classify_based_on_distance(train_data_emb, train_labels, test_data_emb)
 
+        # Step 9: Generate confusion matrices for both raw and embeddings
         raw_cm_image = plot_confusion_matrix(test_labels, pred_raw, title="Confusion Matrix (Raw Channels)")
         emb_cm_image = plot_confusion_matrix(test_labels, pred_emb, title="Confusion Matrix (Embeddings)")
 
@@ -163,28 +220,29 @@ def process_hdf5_file(uploaded_file, percentage_idx):
         return str(e), str(e), capture.get_output()
 
     finally:
+        # Always return to the original working directory after processing
         os.chdir(original_dir)
-        sys.stdout = sys.__stdout__
+        sys.stdout = sys.__stdout__  # Reset print statements
 
+# Function to handle logic based on whether a file is uploaded or not
 def los_nlos_classification(file, percentage_idx):
     if file is not None:
         return process_hdf5_file(file, percentage_idx)
     else:
         return display_predefined_images(percentage_idx), None
 
-# Define the Gradio interface with a compact, minimal layout
+# Define the Gradio interface
 with gr.Blocks(css="""
+.vertical-slider input[type=range] {
+    writing-mode: bt-lr; /* IE */
+    -webkit-appearance: slider-vertical; /* WebKit */
+    width: 8px;
+    height: 200px;
+}
 .slider-container {
+    display: inline-block;
+    margin-right: 50px;
     text-align: center;
-    margin-bottom: 20px;
-}
-.image-row {
-    justify-content: center;
-    margin-top: 10px;
-}
-.output-box {
-    max-width: 600px;
-    margin: 0 auto;
 }
 """) as demo:
 
@@ -203,24 +261,33 @@ with gr.Blocks(css="""
     # Tabs for Beam Prediction and LoS/NLoS Classification
     with gr.Tab("Beam Prediction Task"):
         gr.Markdown("### Beam Prediction Task")
+
         with gr.Row():
             with gr.Column(elem_id="slider-container"):
-                percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, label="Training Data (%)")
-        with gr.Row(elem_id="image-row"):
-            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
-            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300)
+                gr.Markdown("Percentage of Data for Training")
+                percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
+
+        with gr.Row():
+            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
+            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
+
         percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
 
     with gr.Tab("LoS/NLoS Classification Task"):
         gr.Markdown("### LoS/NLoS Classification Task")
+
         file_input = gr.File(label="Upload HDF5 Dataset", file_types=[".h5"])
+
         with gr.Row():
             with gr.Column(elem_id="slider-container"):
-                percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, label="Training Data (%)")
-        with gr.Row(elem_id="image-row"):
-            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300)
-            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300)
-        output_textbox = gr.Textbox(label="Console Output", lines=8, elem_classes="output-box")
+                gr.Markdown("Percentage of Data for Training")
+                percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
+
+        with gr.Row():
+            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
+            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
+        output_textbox = gr.Textbox(label="Console Output", lines=10)
+
         file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
         percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
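
Below is a minimal standalone sketch (not part of the commit) of the nearest-centroid classification that this update refactors onto the new euclidean_distance helper. The two functions are copied from the diff above; the synthetic blobs, dimensions, and sample counts are illustrative assumptions only.

# Standalone sketch: exercise the commit's nearest-centroid classifier on toy data.
import numpy as np
import torch

def euclidean_distance(x, centroid):
    # np.linalg.norm also accepts CPU torch tensors (converted via __array__)
    return np.linalg.norm(x - centroid)

def classify_based_on_distance(train_data, train_labels, test_data):
    # Class centroids: per-dimension mean of each class's training points
    centroid_0 = train_data[train_labels == 0].mean(dim=0)
    centroid_1 = train_data[train_labels == 1].mean(dim=0)

    predictions = []
    for test_point in test_data:
        # Assign each test point to the nearer centroid
        dist_0 = euclidean_distance(test_point, centroid_0)
        dist_1 = euclidean_distance(test_point, centroid_1)
        predictions.append(0 if dist_0 < dist_1 else 1)

    return torch.tensor(predictions)

# Two well-separated Gaussian blobs stand in for raw channels / LWM embeddings (toy data)
train_data = torch.cat([torch.randn(50, 8) - 3.0, torch.randn(50, 8) + 3.0])
train_labels = torch.cat([torch.zeros(50), torch.ones(50)])
test_data = torch.cat([torch.randn(5, 8) - 3.0, torch.randn(5, 8) + 3.0])

print(classify_based_on_distance(train_data, train_labels, test_data))
# Expect roughly tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) for blobs this far apart.

Because identical_train_test_split draws a single torch.randperm and indexes the embeddings, the raw channels, and the labels with the same train/test indices, the two confusion matrices produced by process_hdf5_file are evaluated on exactly the same test samples.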