Spaces:
Running
Running
File size: 9,349 Bytes
8f8b054 cacf045 2a77201 340b448 8f8b054 2a77201 8f8b054 c7b6b77 8f8b054 c7b6b77 8f8b054 c7b6b77 8f8b054 c7b6b77 8f8b054 c7b6b77 8f8b054 c7b6b77 8f8b054 2587718 2a77201 0176215 7ee077b cacf045 2a77201 d2d9264 0176215 2587718 9185ad4 2587718 d2d9264 9185ad4 8aef216 92aed37 7ee077b 9185ad4 7ee077b d2d9264 7ee077b 2587718 7ee077b d2d9264 4e62ce0 0176215 4e62ce0 2587718 7ee077b 4e62ce0 ca32c10 2a77201 7e7ba0a 0176215 2a77201 d2d9264 8f8b054 cacf045 8f8b054 2a77201 8f8b054 cacf045 8f8b054 cacf045 8f8b054 2a77201 8f8b054 2a77201 8f8b054 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import gradio as gr
import os
from PIL import Image
import numpy as np
import pickle
import io
import sys
import torch
import torch
import subprocess
# Paths to the predefined images folder.
# Anchored to this file's directory rather than the current working directory,
# so they keep resolving even if the working directory changes at runtime
# (the model-loading code in this app calls os.chdir).
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
RAW_PATH = os.path.join(_APP_DIR, "images", "raw")
EMBEDDINGS_PATH = os.path.join(_APP_DIR, "images", "embeddings")
GENERATED_PATH = os.path.join(_APP_DIR, "images", "generated")
# Specific values for percentage and complexity; the UI sliders select these by index.
percentage_values = [10, 30, 50, 70, 100]
complexity_values = [16, 32]
# Custom class to capture print output
class PrintCapture(io.StringIO):
    """In-memory text stream that records everything written to it.

    Used as a temporary ``sys.stdout`` replacement so that console output
    produced while handling a request can be shown in the UI.
    """

    def __init__(self):
        super().__init__()
        # Individual write() payloads, in the order they were written.
        self.output = []

    def write(self, txt):
        """Record *txt* and write it to the underlying buffer.

        Returns the number of characters written, as required by the
        ``io.TextIOBase.write`` contract (the previous override returned None).
        """
        self.output.append(txt)
        return super().write(txt)

    def get_output(self):
        """Return all captured text concatenated into a single string."""
        return ''.join(self.output)
def _load_image(path):
    """Read the image at *path* fully into memory and close the file."""
    # Image.open is lazy and keeps the file handle open until the data is read;
    # copying inside the `with` loads the pixels and lets the handle close.
    with Image.open(path) as img:
        return img.copy()


# Function to load and display predefined images based on user selection
def display_predefined_images(percentage_idx, complexity_idx):
    """Load the precomputed raw-channel and embedding plots for a slider setting.

    Args:
        percentage_idx: Index into ``percentage_values`` (from the percentage slider).
        complexity_idx: Index into ``complexity_values`` (from the complexity slider).

    Returns:
        ``(raw_image, embeddings_image)`` as PIL images, or ``(None, None)``
        if either image file does not exist.
    """
    # Slider values may arrive as floats (e.g. 1.0); list indices must be ints.
    percentage = percentage_values[int(percentage_idx)]
    complexity = complexity_values[int(complexity_idx)]
    filename = f"percentage_{percentage}_complexity_{complexity}.png"
    raw_image_path = os.path.join(RAW_PATH, filename)
    embeddings_image_path = os.path.join(EMBEDDINGS_PATH, filename)
    # Both images are shown side by side; show nothing if either one is missing.
    if not (os.path.exists(raw_image_path) and os.path.exists(embeddings_image_path)):
        return None, None
    return _load_image(raw_image_path), _load_image(embeddings_image_path)
# Build the LWM model defined in the cloned repository (no caller is visible in this file).
def load_custom_model():
    """Instantiate ``LWM`` from the ``lwm_model`` module and switch it to eval mode.

    NOTE(review): this uses the plain ``LWM()`` constructor, whereas
    process_p_file uses ``LWM.from_pretrained``; confirm whether pretrained
    weights are expected here.
    """
    # Imported lazily; presumably because lwm_model.py lives in the cloned
    # repository and is not importable until that repo is on the path.
    from lwm_model import LWM

    instance = LWM()
    instance.eval()  # assumes LWM is a torch nn.Module — TODO confirm
    return instance
import importlib.util
# Function to process the uploaded .p file and perform inference using the custom model
def process_p_file(uploaded_file, percentage_idx, complexity_idx):
    """Run LWM inference on an uploaded pickle file, capturing console output.

    Clones the LWM model repository on first use. It then loads the model and
    the repository's preprocessing/inference helpers, tokenizes the uploaded
    data, and computes channel embeddings plus the raw dataset.

    Args:
        uploaded_file: Gradio file object; only its ``.name`` (a path on disk)
            is read.
        percentage_idx: Slider index; currently unused.
        complexity_idx: Slider index; currently unused.

    Returns:
        A 3-tuple ``(output_emb, output_raw, console_text)``, matching the three
        Gradio outputs wired to this handler. On any failure the first two
        items are ``None`` and the error text is included in ``console_text``.
    """
    import contextlib

    model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
    # Absolute path, resolved before any chdir, so repeated calls find the
    # existing clone instead of looking for ./LWM inside the repo itself.
    model_repo_dir = os.path.abspath("./LWM")
    original_cwd = os.getcwd()
    capture = PrintCapture()
    try:
        # redirect_stdout restores whatever stdout was active before, even on error.
        with contextlib.redirect_stdout(capture):
            # Step 1: Clone the repository if not already done
            if not os.path.exists(model_repo_dir):
                print(f"Cloning model repository from {model_repo_url}...")
                subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
            # Step 2: Verify the repository was cloned
            if not os.path.exists(model_repo_dir):
                print(f"Directory {model_repo_dir} does not exist.")
                return None, None, capture.get_output()
            # The repo's code may resolve files relative to the working directory,
            # so enter it for this call only; `finally` switches back.
            os.chdir(model_repo_dir)
            # Make the repo's modules importable regardless of how the app was launched.
            if model_repo_dir not in sys.path:
                sys.path.insert(0, model_repo_dir)
            print(f"Changed working directory to {os.getcwd()}")
            print(f"Directory content: {os.listdir(os.getcwd())}")  # Debugging: Check repo content

            # Step 3: Dynamically import lwm_model.py using importlib
            lwm_model_path = os.path.join(model_repo_dir, 'lwm_model.py')
            if not os.path.exists(lwm_model_path):
                print(f"Error: lwm_model.py not found at {lwm_model_path}")
                return None, None, capture.get_output()
            spec = importlib.util.spec_from_file_location("lwm_model", lwm_model_path)
            lwm_model = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(lwm_model)

            # Step 4: Load the model from LWM module
            device = 'cpu'
            print(f"Loading the LWM model on {device}...")
            model = lwm_model.LWM.from_pretrained(device=device)

            # Step 5: Import tokenizer and load data
            from input_preprocess import tokenizer
            # SECURITY: pickle.load can execute arbitrary code embedded in the file;
            # this is only safe if uploaded files are trusted.
            with open(uploaded_file.name, 'rb') as f:
                manual_data = pickle.load(f)
            preprocessed_chs = tokenizer(manual_data=manual_data)

            # Step 6: Perform inference
            from inference import lwm_inference, create_raw_dataset
            output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
            output_raw = create_raw_dataset(preprocessed_chs, device)
            print(f"Output Embeddings Shape: {output_emb.shape}")
            print(f"Output Raw Shape: {output_raw.shape}")
            # NOTE(review): these values feed gr.Image outputs; confirm they are
            # image-compatible (numpy array / PIL image) or convert them first.
            return output_emb, output_raw, capture.get_output()
    except Exception as e:  # UI handler boundary: report the failure instead of raising
        # Image outputs get None; the error text goes to the console textbox.
        return None, None, capture.get_output() + f"Error: {e}\n"
    finally:
        os.chdir(original_cwd)
# Function to handle logic based on whether a file is uploaded or not
def los_nlos_classification(file, percentage_idx, complexity_idx):
    """Handle the LoS/NLoS tab: run inference on an upload, else show precomputed images.

    Returns:
        A 3-tuple ``(raw_image, embeddings_image, console_text)`` for the
        tab's three Gradio outputs.
    """
    if file is not None:
        return process_p_file(file, percentage_idx, complexity_idx)
    # display_predefined_images returns a pair; unpack it so the result has
    # exactly three items (two images plus None for the console output).
    raw_image, embeddings_image = display_predefined_images(percentage_idx, complexity_idx)
    return raw_image, embeddings_image, None
# Define the Gradio interface.
# The CSS uses class selectors, so components opt in via `elem_classes`
# (`elem_id` sets an HTML id, which these selectors would never match).
with gr.Blocks(css="""
    .vertical-slider input[type=range] {
        writing-mode: bt-lr; /* IE */
        -webkit-appearance: slider-vertical; /* WebKit */
        width: 8px;
        height: 200px;
    }
    .slider-container {
        display: inline-block;
        margin-right: 50px;
        text-align: center;
    }
""") as demo:
    # Contact Section
    gr.Markdown(
        """
        ## Contact
        <div style="display: flex; align-items: center;">
        <a target="_blank" href="https://www.wi-lab.net"><img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;"></a>
        <a target="_blank" href="mailto:alikhani@asu.edu"><img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail " alt="Email"></a>
        </div>
        """
    )

    # Tabs for Beam Prediction and LoS/NLoS Classification
    with gr.Tab("Beam Prediction Task"):
        gr.Markdown("### Beam Prediction Task")
        with gr.Row():
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Percentage of Data for Training")
                percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_classes=["vertical-slider"])
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Task Complexity")
                complexity_slider_bp = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_classes=["vertical-slider"])
        with gr.Row():
            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)

        # Either slider moving refreshes both images from the current slider pair.
        bp_inputs = [percentage_slider_bp, complexity_slider_bp]
        bp_outputs = [raw_img_bp, embeddings_img_bp]
        for control in bp_inputs:
            control.change(fn=display_predefined_images, inputs=bp_inputs, outputs=bp_outputs)

    with gr.Tab("LoS/NLoS Classification Task"):
        gr.Markdown("### LoS/NLoS Classification Task")
        file_input = gr.File(label="Upload .p File", file_types=[".p"])
        with gr.Row():
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Percentage of Data for Training")
                percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_classes=["vertical-slider"])
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Task Complexity")
                complexity_slider_los = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_classes=["vertical-slider"])
        with gr.Row():
            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
        output_textbox = gr.Textbox(label="Console Output", lines=10)

        # A new upload or either slider change re-runs the tab handler, which
        # returns one value per output: two images plus the console text.
        los_inputs = [file_input, percentage_slider_los, complexity_slider_los]
        los_outputs = [raw_img_los, embeddings_img_los, output_textbox]
        for control in los_inputs:
            control.change(fn=los_nlos_classification, inputs=los_inputs, outputs=los_outputs)

# Launch the app
if __name__ == "__main__":
    demo.launch()
|