File size: 10,151 Bytes
8f8b054
 
 
 
cacf045
2a77201
 
340b448
 
 
8f8b054
 
 
 
 
 
 
 
 
 
2a77201
 
 
 
 
 
 
 
 
 
 
 
 
8f8b054
c7b6b77
 
 
 
 
 
 
 
 
 
 
8f8b054
c7b6b77
8f8b054
 
c7b6b77
 
8f8b054
 
 
c7b6b77
 
 
 
 
 
 
8f8b054
 
 
c7b6b77
8f8b054
c7b6b77
8f8b054
2587718
 
 
 
2a77201
0176215
 
7ee077b
 
1012e18
 
 
 
 
 
 
cacf045
 
2a77201
 
d2d9264
0176215
2587718
 
 
9185ad4
2587718
 
 
d2d9264
1012e18
8aef216
 
 
 
 
 
 
92aed37
1012e18
7ee077b
1012e18
 
 
 
 
 
 
9185ad4
 
1012e18
 
 
 
 
 
 
 
 
 
 
 
 
d2d9264
 
7ee077b
2587718
1012e18
4e62ce0
 
0176215
1012e18
2587718
1012e18
 
 
ca32c10
2a77201
 
 
 
7e7ba0a
0176215
2a77201
 
 
d2d9264
8f8b054
1012e18
 
8f8b054
 
 
cacf045
8f8b054
2a77201
8f8b054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cacf045
 
8f8b054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cacf045
8f8b054
 
 
 
 
 
 
 
 
 
 
 
2a77201
8f8b054
2a77201
 
 
8f8b054
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import gradio as gr
import os
from PIL import Image
import numpy as np
import pickle
import io
import sys
import torch
import torch
import subprocess

# Folders holding the predefined images. They are relative paths, so they are
# resolved against the current working directory at the moment they are used.
_IMAGES_DIR = "images"
RAW_PATH = os.path.join(_IMAGES_DIR, "raw")
EMBEDDINGS_PATH = os.path.join(_IMAGES_DIR, "embeddings")
GENERATED_PATH = os.path.join(_IMAGES_DIR, "generated")

# Values selectable in the UI; the sliders yield an index into these lists.
percentage_values = [10, 30, 50, 70, 100]
complexity_values = [16, 32]

# Custom class to capture print output
class PrintCapture(io.StringIO):
    """In-memory text stream meant to stand in for ``sys.stdout`` and collect printed text.

    Written text is stored both in the underlying ``StringIO`` buffer
    (``getvalue()``) and, chunk by chunk, in the public ``output`` list.
    """

    def __init__(self):
        super().__init__()
        # Every string passed to write(), in call order.
        self.output = []

    def write(self, txt):
        """Record *txt* and return the number of characters written.

        ``io.TextIOBase.write`` is specified to return the character count, so
        the value from ``StringIO.write`` is passed through rather than dropped.
        """
        self.output.append(txt)
        return super().write(txt)

    def get_output(self):
        """Return all text written so far, concatenated into one string."""
        return ''.join(self.output)

def _predefined_image_path(folder, percentage, complexity):
    """Build the path of the predefined PNG for one (percentage, complexity) pair in *folder*."""
    return os.path.join(folder, f"percentage_{percentage}_complexity_{complexity}.png")


def _load_image(path):
    """Open *path* with PIL and return an in-memory copy, closing the file handle."""
    # Image.open is lazy and keeps the file open until the data is read;
    # copy() forces the load so the context manager can release the handle
    # instead of leaking one per slider change.
    with Image.open(path) as img:
        return img.copy()


def display_predefined_images(percentage_idx, complexity_idx):
    """Return the predefined (raw, embeddings) images for the selected slider positions.

    Args:
        percentage_idx: slider position, used as an index into ``percentage_values``.
            Any numeric value (e.g. ``2.0``) is accepted and converted with ``int``.
        complexity_idx: slider position, used as an index into ``complexity_values``.

    Returns:
        ``(raw_image, embeddings_image)`` as PIL images, or ``(None, None)`` when
        either of the two PNG files does not exist.

    Raises:
        IndexError: if an index is outside the corresponding value list.
    """
    # Map the slider index to the actual value
    percentage = percentage_values[int(percentage_idx)]
    complexity = complexity_values[int(complexity_idx)]

    raw_image_path = _predefined_image_path(RAW_PATH, percentage, complexity)
    embeddings_image_path = _predefined_image_path(EMBEDDINGS_PATH, percentage, complexity)

    # The two images are shown as a pair; if either is missing, show neither.
    if not (os.path.exists(raw_image_path) and os.path.exists(embeddings_image_path)):
        return None, None

    return _load_image(raw_image_path), _load_image(embeddings_image_path)
    

# Build the LWM model defined in lwm_model.py (assumed importable from sys.path)
def load_custom_model():
    """Create an ``LWM`` instance from ``lwm_model`` and switch it to eval mode.

    ``lwm_model`` is imported inside the function, so it only needs to be
    importable at the time this function is called.

    Returns:
        The ``LWM`` instance, after ``eval()`` has been called on it.
    """
    from lwm_model import LWM

    lwm = LWM()  # default construction; adapt to the model's real initialiser if it needs arguments
    lwm.eval()  # presumably torch ``nn.Module.eval`` (inference mode) — TODO confirm
    return lwm

import importlib.util

# Function to dynamically load a Python module from a given file path
def load_module_from_path(module_name, file_path):
    """Import the Python source file at *file_path* as a module named *module_name*.

    Follows the "importing a source file directly" recipe from the ``importlib``
    documentation: the module is registered in ``sys.modules`` before its code
    runs, so code that looks the module up by name (for example ``dataclasses``
    or another module doing ``import module_name``) finds it. If executing the
    module fails, the ``sys.modules`` entry is restored to what it was before.

    Args:
        module_name: name to give the module (and its ``sys.modules`` key).
        file_path: path to the source file.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no loader can be determined for *file_path* (e.g. its
            suffix is not a recognised Python source/extension suffix).
        Any exception raised while executing the module's code, including
        ``FileNotFoundError`` when the file does not exist.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # Without this check the failure surfaces as an opaque AttributeError.
        raise ImportError(
            f"Cannot load module {module_name!r} from {file_path!r}",
            name=module_name,
            path=file_path,
        )
    module = importlib.util.module_from_spec(spec)

    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Roll back so a half-initialised module is not left registered.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module

def _load_repo_module(module_name, repo_dir):
    """Load ``<repo_dir>/<module_name>.py`` with ``load_module_from_path``.

    Raises:
        FileNotFoundError: if the source file does not exist.
    """
    module_path = os.path.join(repo_dir, f"{module_name}.py")
    if not os.path.exists(module_path):
        raise FileNotFoundError(f"{module_name}.py not found at {module_path}")
    return load_module_from_path(module_name, module_path)


# Function to process the uploaded .p file and perform inference using the custom model
def process_p_file(uploaded_file, percentage_idx, complexity_idx):
    """Run LWM inference on an uploaded pickle file and return the results plus a console log.

    On first use, the LWM repository is cloned into ``./LWM``. The function then
    imports the repository's ``lwm_model``, ``input_preprocess`` and ``inference``
    modules and loads the model on CPU. Finally it tokenizes the unpickled upload
    and computes the channel embeddings and the raw dataset.

    Everything printed while this runs, including any error, is captured and
    returned as the third element so the UI can show it in its console textbox.

    Args:
        uploaded_file: the uploaded ``.p`` file. This is either an object with a
            ``.name`` path attribute or a plain path string.
        percentage_idx: slider position; not used by this function.
        complexity_idx: slider position; not used by this function.

    Returns:
        ``(output_emb, output_raw, log)`` on success, and ``(None, None, log)`` on
        any failure, with the error message included in ``log``.
    """
    model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
    model_repo_dir = "./LWM"

    capture = PrintCapture()
    previous_stdout = sys.stdout
    original_cwd = os.getcwd()
    # NOTE(review): replacing sys.stdout is process-wide. Concurrent requests
    # would interleave their output in each other's capture.
    sys.stdout = capture  # Redirect print statements to capture

    try:
        # Step 1: Clone the repository if not already done.
        # git writes to the real process stdout/stderr, not to the capture.
        if not os.path.exists(model_repo_dir):
            print(f"Cloning model repository from {model_repo_url}...")
            subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)

        # Step 2: Verify the clone and enter it. The working directory is restored
        # in ``finally``. If it stayed changed, the next call would look for
        # ./LWM inside LWM (and clone again there), and the relative image paths
        # used by the other tabs would stop resolving.
        if not os.path.exists(model_repo_dir):
            print(f"Directory {model_repo_dir} does not exist.")
            return None, None, capture.get_output()
        os.chdir(model_repo_dir)
        repo_dir = os.getcwd()
        print(f"Changed working directory to {repo_dir}")
        print(f"Directory content: {os.listdir(repo_dir)}")  # Debugging: Check repo content

        # Step 3: Dynamically load lwm_model.py, input_preprocess.py and inference.py
        lwm_model = _load_repo_module("lwm_model", repo_dir)
        input_preprocess = _load_repo_module("input_preprocess", repo_dir)
        inference = _load_repo_module("inference", repo_dir)

        # Step 4: Load the model from the lwm_model module
        device = 'cpu'
        print(f"Loading the LWM model on {device}...")
        model = lwm_model.LWM.from_pretrained(device=device)

        # Step 5: Unpickle the upload and tokenize it.
        # SECURITY: pickle.load can execute arbitrary code embedded in the file,
        # and this file comes from a user upload. Only use it with trusted files.
        file_path = getattr(uploaded_file, "name", uploaded_file)
        with open(file_path, 'rb') as f:
            manual_data = pickle.load(f)
        preprocessed_chs = input_preprocess.tokenizer(manual_data=manual_data)

        # Step 6: Perform inference using the functions from inference.py
        output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
        output_raw = inference.create_raw_dataset(preprocessed_chs, device)

        print(f"Output Embeddings Shape: {output_emb.shape}")
        print(f"Output Raw Shape: {output_raw.shape}")

        # NOTE(review): these go to gr.Image outputs. If they are tensors, they
        # probably need converting to an image or ndarray first — confirm.
        return output_emb, output_raw, capture.get_output()

    except Exception as e:
        # Report the error in the captured console log. The image outputs get
        # None, because a gr.Image component cannot display an error string.
        print(f"Error: {e}")
        return None, None, capture.get_output()

    finally:
        sys.stdout = previous_stdout  # Restore whatever stdout was before
        os.chdir(original_cwd)



# Function to handle logic based on whether a file is uploaded or not
def los_nlos_classification(file, percentage_idx, complexity_idx):
    """Route the LoS/NLoS tab's inputs to inference or to the predefined images.

    Args:
        file: the uploaded ``.p`` file, or ``None`` when nothing is uploaded.
        percentage_idx: percentage slider position.
        complexity_idx: complexity slider position.

    Returns:
        The flat 3-tuple ``(raw_image, embeddings_image, console_text)`` expected
        by the tab's three output components.
    """
    if file is not None:
        return process_p_file(file, percentage_idx, complexity_idx)
    # Unpack the image pair so Gradio receives exactly one value per output
    # component, rather than a nested ``((raw, emb), None)`` pair.
    raw_image, embeddings_image = display_predefined_images(percentage_idx, complexity_idx)
    return raw_image, embeddings_image, None

def _index_slider(values):
    """Create an interactive slider whose value is an index into *values*.

    Must be called inside a ``gr.Blocks`` context so the slider is rendered there.
    """
    return gr.Slider(
        minimum=0,
        maximum=len(values) - 1,  # derived from the list instead of hard-coded
        step=1,
        value=0,
        interactive=True,
        elem_classes=["vertical-slider"],
    )


# Define the Gradio interface.
# The CSS selectors are *class* selectors, so components opt in through
# ``elem_classes``. With ``elem_id`` the selectors never matched, and the same
# id was reused across several components.
with gr.Blocks(css="""
    .vertical-slider input[type=range] {
        writing-mode: bt-lr; /* IE */
        -webkit-appearance: slider-vertical; /* WebKit */
        width: 8px;
        height: 200px;
    }
    .slider-container {
        display: inline-block;
        margin-right: 50px;
        text-align: center;
    }
""") as demo:

    # Contact Section
    gr.Markdown(
        """
        ## Contact
        <div style="display: flex; align-items: center;">
            <a target="_blank" href="https://www.wi-lab.net"><img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;"></a>&nbsp;&nbsp;
            <a target="_blank" href="mailto:alikhani@asu.edu"><img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail " alt="Email"></a>&nbsp;&nbsp;
        </div>
        """
    )

    # Tabs for Beam Prediction and LoS/NLoS Classification
    with gr.Tab("Beam Prediction Task"):
        gr.Markdown("### Beam Prediction Task")

        with gr.Row():
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Percentage of Data for Training")
                percentage_slider_bp = _index_slider(percentage_values)
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Task Complexity")
                complexity_slider_bp = _index_slider(complexity_values)

        with gr.Row():
            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)

        bp_inputs = [percentage_slider_bp, complexity_slider_bp]
        bp_outputs = [raw_img_bp, embeddings_img_bp]
        for bp_slider in bp_inputs:
            bp_slider.change(fn=display_predefined_images, inputs=bp_inputs, outputs=bp_outputs)
        # Show the default selection on page load instead of blank images.
        demo.load(fn=display_predefined_images, inputs=bp_inputs, outputs=bp_outputs)

    with gr.Tab("LoS/NLoS Classification Task"):
        gr.Markdown("### LoS/NLoS Classification Task")

        file_input = gr.File(label="Upload .p File", file_types=[".p"])

        with gr.Row():
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Percentage of Data for Training")
                percentage_slider_los = _index_slider(percentage_values)
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Task Complexity")
                complexity_slider_los = _index_slider(complexity_values)

        with gr.Row():
            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
            output_textbox = gr.Textbox(label="Console Output", lines=10)

        # Any change to the upload or either slider re-runs the classification.
        los_inputs = [file_input, percentage_slider_los, complexity_slider_los]
        los_outputs = [raw_img_los, embeddings_img_los, output_textbox]
        for los_trigger in los_inputs:
            los_trigger.change(fn=los_nlos_classification, inputs=los_inputs, outputs=los_outputs)

# Launch the app
if __name__ == "__main__":
    demo.launch()