Sadjad Alikhani commited on
Commit
8030161
·
verified ·
1 Parent(s): 17185bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -101
app.py CHANGED
@@ -6,13 +6,11 @@ import pickle
6
  import io
7
  import sys
8
  import torch
9
- import torch
10
  import subprocess
11
 
12
  # Paths to the predefined images folder
13
  RAW_PATH = os.path.join("images", "raw")
14
  EMBEDDINGS_PATH = os.path.join("images", "embeddings")
15
- GENERATED_PATH = os.path.join("images", "generated")
16
 
17
  # Specific values for percentage and complexity
18
  percentage_values = [10, 30, 50, 70, 100]
@@ -32,68 +30,21 @@ class PrintCapture(io.StringIO):
32
  return ''.join(self.output)
33
 
34
  # Function to load and display predefined images based on user selection
35
- #def display_predefined_images(percentage_idx, complexity_idx):
36
- # percentage = percentage_values[percentage_idx]
37
- # complexity = complexity_values[complexity_idx]
38
- # raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
39
- # embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
40
-
41
- # raw_image = Image.open(raw_image_path)
42
- # embeddings_image = Image.open(embeddings_image_path)
43
-
44
- # return raw_image, embeddings_image
45
-
46
  def display_predefined_images(percentage_idx, complexity_idx):
47
- # Map the slider index to the actual value
48
  percentage = percentage_values[percentage_idx]
49
  complexity = complexity_values[complexity_idx]
50
-
51
- # Generate the paths to the images
52
  raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
53
  embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
54
 
55
- # Check if the images exist
56
- if not os.path.exists(raw_image_path):
57
- return None, None # Or handle the error appropriately
58
- if not os.path.exists(embeddings_image_path):
59
- return None, None # Or handle the error appropriately
60
-
61
- # Load images using PIL
62
  raw_image = Image.open(raw_image_path)
63
  embeddings_image = Image.open(embeddings_image_path)
64
 
65
- # Return the loaded images
66
  return raw_image, embeddings_image
67
-
68
-
69
- # Function to load the pre-trained model from your cloned repository
70
- def load_custom_model():
71
- from lwm_model import LWM # Assuming the model is defined in lwm_model.py
72
- model = LWM() # Modify this according to your model initialization
73
- model.eval()
74
- return model
75
-
76
- import importlib.util
77
-
78
- # Function to dynamically load a Python module from a given file path
79
- def load_module_from_path(module_name, file_path):
80
- spec = importlib.util.spec_from_file_location(module_name, file_path)
81
- module = importlib.util.module_from_spec(spec)
82
- spec.loader.exec_module(module)
83
- return module
84
-
85
- import sys
86
- import os
87
- import subprocess
88
- import pickle
89
- import importlib.util
90
 
91
- # Function to dynamically load a Python module from a given file path
92
- def load_module_from_path(module_name, file_path):
93
- spec = importlib.util.spec_from_file_location(module_name, file_path)
94
- module = importlib.util.module_from_spec(spec)
95
- spec.loader.exec_module(module)
96
- return module
97
 
98
  # Function to process the uploaded .p file and perform inference using the custom model
99
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
@@ -113,57 +64,15 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
113
  if os.path.exists(model_repo_dir):
114
  os.chdir(model_repo_dir)
115
  print(f"Changed working directory to {os.getcwd()}")
116
- print(f"Directory content: {os.listdir(os.getcwd())}") # Debugging: Check repo content
117
  else:
118
  print(f"Directory {model_repo_dir} does not exist.")
119
  return
120
 
121
- # Step 3: Dynamically load lwm_model.py, input_preprocess.py, and inference.py
122
- lwm_model_path = os.path.join(os.getcwd(), 'lwm_model.py')
123
- input_preprocess_path = os.path.join(os.getcwd(), 'input_preprocess.py')
124
- inference_path = os.path.join(os.getcwd(), 'inference.py')
125
-
126
- print(lwm_model_path)
127
- print(input_preprocess_path)
128
- print(inference_path)
129
-
130
- # Load lwm_model
131
- if os.path.exists(lwm_model_path):
132
- lwm_model = load_module_from_path("lwm_model", lwm_model_path)
133
- else:
134
- return f"Error: lwm_model.py not found at {lwm_model_path}"
135
-
136
- # Load input_preprocess
137
- if os.path.exists(input_preprocess_path):
138
- input_preprocess = load_module_from_path("input_preprocess", input_preprocess_path)
139
- else:
140
- return f"Error: input_preprocess.py not found at {input_preprocess_path}"
141
-
142
- # Load inference
143
- if os.path.exists(inference_path):
144
- inference = load_module_from_path("inference", inference_path)
145
- else:
146
- return f"Error: inference.py not found at {inference_path}"
147
-
148
- # Step 4: Load the model from lwm_model module
149
- device = 'cpu'
150
- print(f"Loading the LWM model on {device}...")
151
- model = lwm_model.LWM.from_pretrained(device=device)
152
-
153
- # Step 5: Tokenize the data using the tokenizer from input_preprocess
154
- with open(uploaded_file.name, 'rb') as f:
155
- manual_data = pickle.load(f)
156
-
157
- preprocessed_chs = input_preprocess.tokenizer(manual_data=manual_data)
158
 
159
- # Step 6: Perform inference using the functions from inference.py
160
- output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
161
- output_raw = inference.create_raw_dataset(preprocessed_chs, device)
162
-
163
- print(f"Output Embeddings Shape: {output_emb.shape}")
164
- print(f"Output Raw Shape: {output_raw.shape}")
165
-
166
- return output_emb, output_raw, capture.get_output()
167
 
168
  except Exception as e:
169
  return str(e), str(e), capture.get_output()
@@ -171,13 +80,12 @@ def process_p_file(uploaded_file, percentage_idx, complexity_idx):
171
  finally:
172
  sys.stdout = sys.__stdout__ # Reset print statements
173
 
174
-
175
  # Function to handle logic based on whether a file is uploaded or not
176
  def los_nlos_classification(file, percentage_idx, complexity_idx):
177
  if file is not None:
178
  return process_p_file(file, percentage_idx, complexity_idx)
179
  else:
180
- return display_predefined_images(percentage_idx, complexity_idx), None
181
 
182
  # Define the Gradio interface
183
  with gr.Blocks(css="""
 
6
  import io
7
  import sys
8
  import torch
 
9
  import subprocess
10
 
11
  # Paths to the predefined images folder
12
  RAW_PATH = os.path.join("images", "raw")
13
  EMBEDDINGS_PATH = os.path.join("images", "embeddings")
 
14
 
15
  # Specific values for percentage and complexity
16
  percentage_values = [10, 30, 50, 70, 100]
 
30
  return ''.join(self.output)
31
 
32
  # Function to load and display predefined images based on user selection
 
 
 
 
 
 
 
 
 
 
 
33
  def display_predefined_images(percentage_idx, complexity_idx):
 
34
  percentage = percentage_values[percentage_idx]
35
  complexity = complexity_values[complexity_idx]
 
 
36
  raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
37
  embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
38
 
 
 
 
 
 
 
 
39
  raw_image = Image.open(raw_image_path)
40
  embeddings_image = Image.open(embeddings_image_path)
41
 
 
42
  return raw_image, embeddings_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ # Function to create random images for LoS/NLoS classification results
45
+ def create_random_image(size=(300, 300)):
46
+ random_image = np.random.rand(*size, 3) * 255
47
+ return Image.fromarray(random_image.astype('uint8'))
 
 
48
 
49
  # Function to process the uploaded .p file and perform inference using the custom model
50
  def process_p_file(uploaded_file, percentage_idx, complexity_idx):
 
64
  if os.path.exists(model_repo_dir):
65
  os.chdir(model_repo_dir)
66
  print(f"Changed working directory to {os.getcwd()}")
 
67
  else:
68
  print(f"Directory {model_repo_dir} does not exist.")
69
  return
70
 
71
+ # Simulate processing and generating random images
72
+ raw_image = create_random_image()
73
+ embeddings_image = create_random_image()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ return raw_image, embeddings_image, capture.get_output()
 
 
 
 
 
 
 
76
 
77
  except Exception as e:
78
  return str(e), str(e), capture.get_output()
 
80
  finally:
81
  sys.stdout = sys.__stdout__ # Reset print statements
82
 
 
83
  # Function to handle logic based on whether a file is uploaded or not
84
  def los_nlos_classification(file, percentage_idx, complexity_idx):
85
  if file is not None:
86
  return process_p_file(file, percentage_idx, complexity_idx)
87
  else:
88
+ return create_random_image(), create_random_image(), None
89
 
90
  # Define the Gradio interface
91
  with gr.Blocks(css="""