Sadjad Alikhani
Update app.py
340b448 verified
raw
history blame
8.98 kB
import gradio as gr
import os
from PIL import Image
import numpy as np
import pickle
import io
import sys
import torch
import torch
import subprocess
# Folders holding the pre-rendered images shown in the UI, one per kind.
RAW_PATH, EMBEDDINGS_PATH, GENERATED_PATH = (
    os.path.join("images", subfolder)
    for subfolder in ("raw", "embeddings", "generated")
)

# Values selected by the slider positions: the percentage of data used for
# training and the task complexity (see the slider labels in the UI).
percentage_values = [10, 30, 50, 70, 100]
complexity_values = [16, 32]
# Custom class to capture print output
class PrintCapture(io.StringIO):
    """In-memory text stream that records everything written to it.

    Meant to stand in for ``sys.stdout`` so that ``print`` output can be
    collected and shown in the UI.
    """

    def __init__(self):
        super().__init__()
        # Individual chunks, in the order they were written.
        self.output = []

    def write(self, txt):
        """Write *txt* to the buffer, record it, and return the number of characters written.

        Returning the count keeps the ``TextIOBase.write`` contract, which
        callers of ``sys.stdout.write`` may rely on.
        """
        # Write to the buffer first so a rejected (non-str) argument raises
        # without leaving a stray entry in ``self.output``.
        written = super().write(txt)
        self.output.append(txt)
        return written

    def get_output(self):
        """Return everything written so far as a single string."""
        return ''.join(self.output)
# Function to load and display predefined images based on user selection
def display_predefined_images(percentage_idx, complexity_idx):
    """Load the pre-rendered raw-channel and embedding images for a slider setting.

    Args:
        percentage_idx: Slider position indexing ``percentage_values``.
            Converted to ``int`` so a numeric value such as ``2.0`` also works.
        complexity_idx: Slider position indexing ``complexity_values``,
            converted to ``int`` for the same reason.

    Returns:
        ``(raw_image, embeddings_image)`` as fully loaded PIL images.

    Raises:
        IndexError: If an index is outside its value list.
        FileNotFoundError: If the expected PNG file does not exist.
    """
    percentage = percentage_values[int(percentage_idx)]
    complexity = complexity_values[int(complexity_idx)]
    file_name = f"percentage_{percentage}_complexity_{complexity}.png"

    def _load(folder):
        # Image.open is lazy and keeps the file open; load the pixels now
        # inside a context manager so the handle is closed immediately.
        with Image.open(os.path.join(folder, file_name)) as img:
            return img.copy()

    return _load(RAW_PATH), _load(EMBEDDINGS_PATH)
# Function to load the pre-trained model from your cloned repository
def load_custom_model():
    """Construct an ``LWM`` model, switch it to eval mode, and return it.

    Requires the ``lwm_model`` module to be importable (e.g. the cloned LWM
    repository on ``sys.path``).
    NOTE(review): not called anywhere in the visible source; ``process_p_file``
    builds its model with ``LWM.from_pretrained`` instead.
    """
    from lwm_model import LWM  # Assuming the model is defined in lwm_model.py

    lwm = LWM()  # Modify this according to your model initialization
    lwm.eval()
    return lwm
# Function to process the uploaded .p file and perform inference using the custom model
def process_p_file(uploaded_file, percentage_idx, complexity_idx):
    """Run LWM inference on an uploaded pickled channel file.

    If needed, this first clones the LWM model repository into ``./LWM``
    (relative to the cwd at call time). It then imports the repository's
    ``lwm_model``, ``input_preprocess`` and ``inference`` modules, tokenizes
    the uploaded data, and computes embeddings. Everything printed during the
    run is captured and returned as console text for the UI.

    The working directory is temporarily switched into the cloned
    repository, since its code may rely on repo-relative paths. It is always
    restored afterwards, so later calls and relative paths used elsewhere in
    the app (e.g. ``images/...``) keep working.

    Args:
        uploaded_file: The uploaded ``.p`` file, either a path string or an
            object whose ``.name`` attribute is the path.
        percentage_idx: Slider index for the training-data percentage
            (currently unused here).
        complexity_idx: Slider index for the task complexity (currently
            unused here).

    Returns:
        ``(output_emb, output_raw, console_text)``. On any failure the first
        two items are ``None`` and the error message is part of
        ``console_text``, so the result always has three items.
    """
    import contextlib

    capture = PrintCapture()
    original_cwd = os.getcwd()
    model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
    # Absolute path, resolved once: the cwd is changed below, and a relative
    # "./LWM" would otherwise point at ./LWM/LWM (and trigger a new clone).
    model_repo_dir = os.path.abspath("./LWM")

    def _failure():
        # Keep the 3-output shape expected by the UI: clear both images and
        # show the captured log in the console textbox.
        return None, None, capture.get_output()

    with contextlib.redirect_stdout(capture):  # Redirect print statements to capture
        try:
            # Step 1: Clone the model repository if not already cloned
            if not os.path.exists(model_repo_dir):
                print(f"Cloning model repository from {model_repo_url}...")
                subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)

            if not os.path.exists(model_repo_dir):
                print(f"Directory {model_repo_dir} does not exist.")
                return _failure()

            # Work from the repo root while using its code; undone in `finally`.
            os.chdir(model_repo_dir)
            print(f"Changed working directory to {os.getcwd()}")
            print(f"Directory content: {os.listdir(os.getcwd())}")  # Debugging: Check repo content

            # Step 2: Add the cloned repo to sys.path for imports
            if model_repo_dir not in sys.path:
                sys.path.append(model_repo_dir)
            # Debugging: Print sys.path to ensure the cloned repo is in the path
            print(f"sys.path: {sys.path}")

            # Step 3: Dynamically import the model after cloning
            try:
                from lwm_model import LWM  # Custom model in the cloned repo
                print("Successfully imported LWM model.")
            except ImportError as e:
                print(f"Error importing LWM model: {e}")
                print("Make sure lwm_model.py exists in the cloned repository.")
                return _failure()

            # Step 4: Check if GPU is available and set the device
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f"Using device: {device}")
            # Load the model from the cloned repository
            model = LWM.from_pretrained(device=device)

            # Step 5: Import the tokenizer
            try:
                from input_preprocess import tokenizer
            except ImportError as e:
                print(f"Error importing tokenizer: {e}")
                return _failure()

            # Step 6: Load the uploaded .p file (wireless channel matrix).
            # SECURITY NOTE(review): pickle.load can execute arbitrary code
            # embedded in the file; only acceptable if uploads are trusted.
            file_path = getattr(uploaded_file, "name", uploaded_file)
            with open(file_path, 'rb') as f:
                manual_data = pickle.load(f)

            # Step 7: Tokenize the data if needed (or perform any necessary preprocessing)
            preprocessed_chs = tokenizer(manual_data=manual_data)

            # Step 8: Perform inference using the model
            from inference import lwm_inference, create_raw_dataset
            output_emb = lwm_inference(preprocessed_chs, 'channel_emb', model)
            output_raw = create_raw_dataset(preprocessed_chs, device)
            print(f"Output Embeddings Shape: {output_emb.shape}")
            print(f"Output Raw Shape: {output_raw.shape}")

            # NOTE(review): these outputs are wired to gr.Image components;
            # confirm Gradio can render them or convert them to images first.
            return output_emb, output_raw, capture.get_output()
        except Exception as e:
            # Top-level UI boundary: report the failure in the console output
            # instead of sending error strings to the image components.
            print(f"Error: {type(e).__name__}: {e}")
            return _failure()
        finally:
            os.chdir(original_cwd)  # Undo the chdir into the model repo
# Function to handle logic based on whether a file is uploaded or not
def los_nlos_classification(file, percentage_idx, complexity_idx):
    """Handle LoS/NLoS tab events.

    With an uploaded file, run inference via ``process_p_file``. Otherwise,
    show the predefined images for the current slider positions.

    Returns:
        A flat 3-tuple matching the tab's outputs:
        ``(raw image, embeddings image, console text)``.
    """
    if file is not None:
        return process_p_file(file, percentage_idx, complexity_idx)
    # Unpack so the result is (raw, embeddings, text), not ((raw, embeddings), text).
    raw_image, embeddings_image = display_predefined_images(percentage_idx, complexity_idx)
    # No inference ran, so there is no console output to show.
    return raw_image, embeddings_image, None
# Define the Gradio interface
with gr.Blocks(css="""
.vertical-slider input[type=range] {
writing-mode: bt-lr; /* IE */
-webkit-appearance: slider-vertical; /* WebKit */
width: 8px;
height: 200px;
}
.slider-container {
display: inline-block;
margin-right: 50px;
text-align: center;
}
""") as demo:
    # The CSS above uses *class* selectors, so sliders and their columns are
    # tagged with `elem_classes`. `elem_id` never matched those selectors, and
    # reusing one id on several elements is invalid HTML.

    # Contact Section
    gr.Markdown(
        """
## Contact
<div style="display: flex; align-items: center;">
<a target="_blank" href="https://www.wi-lab.net"><img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;"></a>&nbsp;&nbsp;
<a target="_blank" href="mailto:alikhani@asu.edu"><img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail " alt="Email"></a>&nbsp;&nbsp;
</div>
"""
    )

    # Tabs for Beam Prediction and LoS/NLoS Classification
    with gr.Tab("Beam Prediction Task"):
        gr.Markdown("### Beam Prediction Task")
        with gr.Row():
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Percentage of Data for Training")
                # Slider positions are indices into `percentage_values`.
                percentage_slider_bp = gr.Slider(
                    minimum=0, maximum=len(percentage_values) - 1, step=1, value=0,
                    interactive=True, elem_classes=["vertical-slider"],
                )
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Task Complexity")
                # Slider positions are indices into `complexity_values`.
                complexity_slider_bp = gr.Slider(
                    minimum=0, maximum=len(complexity_values) - 1, step=1, value=0,
                    interactive=True, elem_classes=["vertical-slider"],
                )
        with gr.Row():
            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
        # Moving either slider re-renders the pair of predefined images.
        for bp_trigger in (percentage_slider_bp, complexity_slider_bp):
            bp_trigger.change(
                fn=display_predefined_images,
                inputs=[percentage_slider_bp, complexity_slider_bp],
                outputs=[raw_img_bp, embeddings_img_bp],
            )

    with gr.Tab("LoS/NLoS Classification Task"):
        gr.Markdown("### LoS/NLoS Classification Task")
        file_input = gr.File(label="Upload .p File", file_types=[".p"])
        with gr.Row():
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Percentage of Data for Training")
                percentage_slider_los = gr.Slider(
                    minimum=0, maximum=len(percentage_values) - 1, step=1, value=0,
                    interactive=True, elem_classes=["vertical-slider"],
                )
            with gr.Column(elem_classes=["slider-container"]):
                gr.Markdown("Task Complexity")
                complexity_slider_los = gr.Slider(
                    minimum=0, maximum=len(complexity_values) - 1, step=1, value=0,
                    interactive=True, elem_classes=["vertical-slider"],
                )
        with gr.Row():
            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
        output_textbox = gr.Textbox(label="Console Output", lines=10)
        # A new upload or moving either slider re-runs the tab's handler.
        # NOTE(review): with a file uploaded, every slider move re-runs full
        # inference (including model loading).
        for los_trigger in (file_input, percentage_slider_los, complexity_slider_los):
            los_trigger.change(
                fn=los_nlos_classification,
                inputs=[file_input, percentage_slider_los, complexity_slider_los],
                outputs=[raw_img_los, embeddings_img_los, output_textbox],
            )
def main():
    """Launch the Gradio app held in the module-level ``demo``."""
    demo.launch()


# Only serve the app when run as a script, not when imported.
if __name__ == "__main__":
    main()