Sadjad Alikhani commited on
Commit
2a77201
·
verified ·
1 Parent(s): cacf045

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -31
app.py CHANGED
@@ -3,6 +3,8 @@ import os
3
  from PIL import Image
4
  import numpy as np
5
  import pickle
 
 
6
 
7
  # Paths to the predefined images folder
8
  RAW_PATH = os.path.join("images", "raw")
@@ -13,21 +15,29 @@ GENERATED_PATH = os.path.join("images", "generated")
13
  percentage_values = [10, 30, 50, 70, 100]
14
  complexity_values = [16, 32]
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Function to load and display predefined images based on user selection
17
  def display_predefined_images(percentage_idx, complexity_idx):
18
- # Map the slider index to the actual value
19
  percentage = percentage_values[percentage_idx]
20
  complexity = complexity_values[complexity_idx]
21
-
22
- # Generate the paths to the images
23
  raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
24
  embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
25
 
26
- # Load images using PIL
27
  raw_image = Image.open(raw_image_path)
28
  embeddings_image = Image.open(embeddings_image_path)
29
 
30
- # Return the loaded images
31
  return raw_image, embeddings_image
32
 
33
  import torch
@@ -35,16 +45,17 @@ import subprocess
35
 
36
  # Function to load the pre-trained model from your cloned repository
37
  def load_custom_model():
38
- # Assume your model is in the cloned LWM repository
39
  from lwm_model import LWM # Assuming the model is defined in lwm_model.py
40
  model = LWM() # Modify this according to your model initialization
41
- model.eval() # Set the model to evaluation mode
42
  return model
43
 
44
  # Function to process the uploaded .p file and perform inference using the custom model
45
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
 
 
 
46
  try:
47
- # Clone the repository if not already done (for model and tokenizer)
48
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
49
  model_repo_dir = "./LWM"
50
 
@@ -52,49 +63,45 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
52
  print(f"Cloning model repository from {model_repo_url}...")
53
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
54
 
55
- # Change the working directory to the cloned LWM folder
56
  if os.path.exists(model_repo_dir):
57
  os.chdir(model_repo_dir)
58
  print(f"Changed working directory to {os.getcwd()}")
59
  else:
60
  return f"Directory {model_repo_dir} does not exist."
61
 
62
- # Step 1: Load the custom model
63
  from lwm_model import LWM
64
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
65
  print(f"Loading the LWM model on {device}...")
66
  model = LWM.from_pretrained(device=device)
67
 
68
- # Step 2: Import the tokenizer
69
  from input_preprocess import tokenizer
70
 
71
- # Step 3: Load the uploaded .p file that contains the wireless channel matrix
72
  with open(uploaded_file.name, 'rb') as f:
73
  manual_data = pickle.load(f)
74
 
75
- # Step 4: Tokenize the data if needed (or perform any necessary preprocessing)
76
  preprocessed_chs = tokenizer(manual_data=manual_data)
77
 
78
- # Step 5: Perform inference on the channel matrix using the model
79
  from inference import lwm_inference, create_raw_dataset
80
  output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
81
  output_raw = create_raw_dataset(preprocessed_chs, device)
82
- print(output_emb.shape)
83
- print(output_raw.shape)
84
 
85
- return output_emb, output_raw
 
 
 
86
 
87
  except Exception as e:
88
- return str(e), str(e)
 
 
 
89
 
90
  # Function to handle logic based on whether a file is uploaded or not
91
  def los_nlos_classification(file, percentage_idx, complexity_idx):
92
  if file is not None:
93
- # Process the uploaded .p file and generate new images
94
  return process_p_file(file, percentage_idx, complexity_idx)
95
  else:
96
- # Display predefined images if no file is uploaded
97
- return display_predefined_images(percentage_idx, complexity_idx)
98
 
99
  # Define the Gradio interface
100
  with gr.Blocks(css="""
@@ -126,7 +133,6 @@ with gr.Blocks(css="""
126
  with gr.Tab("Beam Prediction Task"):
127
  gr.Markdown("### Beam Prediction Task")
128
 
129
- # Sliders for percentage and complexity
130
  with gr.Row():
131
  with gr.Column(elem_id="slider-container"):
132
  gr.Markdown("Percentage of Data for Training")
@@ -135,22 +141,18 @@ with gr.Blocks(css="""
135
  gr.Markdown("Task Complexity")
136
  complexity_slider_bp = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")
137
 
138
- # Image outputs (display the images side by side and set a smaller size for the images)
139
  with gr.Row():
140
  raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
141
  embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
142
 
143
- # Instant image updates when sliders change
144
  percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
145
  complexity_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
146
 
147
  with gr.Tab("LoS/NLoS Classification Task"):
148
  gr.Markdown("### LoS/NLoS Classification Task")
149
 
150
- # File uploader for uploading .p file
151
  file_input = gr.File(label="Upload .p File", file_types=[".p"])
152
 
153
- # Sliders for percentage and complexity
154
  with gr.Row():
155
  with gr.Column(elem_id="slider-container"):
156
  gr.Markdown("Percentage of Data for Training")
@@ -159,15 +161,14 @@ with gr.Blocks(css="""
159
  gr.Markdown("Task Complexity")
160
  complexity_slider_los = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")
161
 
162
- # Image outputs (display the images side by side and set a smaller size for the images)
163
  with gr.Row():
164
  raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
165
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
 
166
 
167
- # Instant image updates based on file upload or slider changes
168
- file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los])
169
- percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los])
170
- complexity_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los])
171
 
172
  # Launch the app
173
  if __name__ == "__main__":
 
3
  from PIL import Image
4
  import numpy as np
5
  import pickle
6
+ import io
7
+ import sys
8
 
9
  # Paths to the predefined images folder
10
  RAW_PATH = os.path.join("images", "raw")
 
15
  percentage_values = [10, 30, 50, 70, 100]
16
  complexity_values = [16, 32]
17
 
18
+ # Custom class to capture print output
19
+ class PrintCapture(io.StringIO):
20
+ def __init__(self):
21
+ super().__init__()
22
+ self.output = []
23
+
24
+ def write(self, txt):
25
+ self.output.append(txt)
26
+ super().write(txt)
27
+
28
+ def get_output(self):
29
+ return ''.join(self.output)
30
+
31
  # Function to load and display predefined images based on user selection
32
  def display_predefined_images(percentage_idx, complexity_idx):
 
33
  percentage = percentage_values[percentage_idx]
34
  complexity = complexity_values[complexity_idx]
 
 
35
  raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
36
  embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
37
 
 
38
  raw_image = Image.open(raw_image_path)
39
  embeddings_image = Image.open(embeddings_image_path)
40
 
 
41
  return raw_image, embeddings_image
42
 
43
  import torch
 
45
 
46
  # Function to load the pre-trained model from your cloned repository
47
  def load_custom_model():
 
48
  from lwm_model import LWM # Assuming the model is defined in lwm_model.py
49
  model = LWM() # Modify this according to your model initialization
50
+ model.eval()
51
  return model
52
 
53
  # Function to process the uploaded .p file and perform inference using the custom model
54
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
55
+ capture = PrintCapture()
56
+ sys.stdout = capture # Redirect print statements to capture
57
+
58
  try:
 
59
  model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
60
  model_repo_dir = "./LWM"
61
 
 
63
  print(f"Cloning model repository from {model_repo_url}...")
64
  subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
65
 
 
66
  if os.path.exists(model_repo_dir):
67
  os.chdir(model_repo_dir)
68
  print(f"Changed working directory to {os.getcwd()}")
69
  else:
70
  return f"Directory {model_repo_dir} does not exist."
71
 
 
72
  from lwm_model import LWM
73
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
74
  print(f"Loading the LWM model on {device}...")
75
  model = LWM.from_pretrained(device=device)
76
 
 
77
  from input_preprocess import tokenizer
78
 
 
79
  with open(uploaded_file.name, 'rb') as f:
80
  manual_data = pickle.load(f)
81
 
 
82
  preprocessed_chs = tokenizer(manual_data=manual_data)
83
 
 
84
  from inference import lwm_inference, create_raw_dataset
85
  output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
86
  output_raw = create_raw_dataset(preprocessed_chs, device)
 
 
87
 
88
+ print(f"Output Embeddings Shape: {output_emb.shape}")
89
+ print(f"Output Raw Shape: {output_raw.shape}")
90
+
91
+ return output_emb, output_raw, capture.get_output()
92
 
93
  except Exception as e:
94
+ return str(e), str(e), capture.get_output()
95
+
96
+ finally:
97
+ sys.stdout = sys.__stdout__ # Reset print statements
98
 
99
  # Function to handle logic based on whether a file is uploaded or not
100
  def los_nlos_classification(file, percentage_idx, complexity_idx):
101
  if file is not None:
 
102
  return process_p_file(file, percentage_idx, complexity_idx)
103
  else:
104
+ return display_predefined_images(percentage_idx, complexity_idx), None
 
105
 
106
  # Define the Gradio interface
107
  with gr.Blocks(css="""
 
133
  with gr.Tab("Beam Prediction Task"):
134
  gr.Markdown("### Beam Prediction Task")
135
 
 
136
  with gr.Row():
137
  with gr.Column(elem_id="slider-container"):
138
  gr.Markdown("Percentage of Data for Training")
 
141
  gr.Markdown("Task Complexity")
142
  complexity_slider_bp = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")
143
 
 
144
  with gr.Row():
145
  raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
146
  embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
147
 
 
148
  percentage_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
149
  complexity_slider_bp.change(fn=display_predefined_images, inputs=[percentage_slider_bp, complexity_slider_bp], outputs=[raw_img_bp, embeddings_img_bp])
150
 
151
  with gr.Tab("LoS/NLoS Classification Task"):
152
  gr.Markdown("### LoS/NLoS Classification Task")
153
 
 
154
  file_input = gr.File(label="Upload .p File", file_types=[".p"])
155
 
 
156
  with gr.Row():
157
  with gr.Column(elem_id="slider-container"):
158
  gr.Markdown("Percentage of Data for Training")
 
161
  gr.Markdown("Task Complexity")
162
  complexity_slider_los = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")
163
 
 
164
  with gr.Row():
165
  raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
166
  embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
167
+ output_textbox = gr.Textbox(label="Console Output", lines=10)
168
 
169
+ file_input.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
170
+ percentage_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
171
+ complexity_slider_los.change(fn=los_nlos_classification, inputs=[file_input, percentage_slider_los, complexity_slider_los], outputs=[raw_img_los, embeddings_img_los, output_textbox])
 
172
 
173
  # Launch the app
174
  if __name__ == "__main__":