# Sadjad Alikhani
# Update app.py
# 1012e18 verified
# raw
# history blame
# 10.2 kB
import gradio as gr
import os
from PIL import Image
import numpy as np
import pickle
import io
import sys
import torch
import torch
import subprocess
# Folders that hold the predefined images, one sub-folder per image kind
# under "images".
RAW_PATH, EMBEDDINGS_PATH, GENERATED_PATH = (
    os.path.join("images", subfolder)
    for subfolder in ("raw", "embeddings", "generated")
)

# Discrete values the sliders map onto: a slider position is used as an
# index into these lists.
percentage_values = [10, 30, 50, 70, 100]
complexity_values = [16, 32]
# Custom class to capture print output
class PrintCapture(io.StringIO):
    """In-memory text stream that also records every chunk written to it.

    Used as a temporary replacement for ``sys.stdout`` so that ``print``
    output can be collected and returned to the caller afterwards.
    """

    def __init__(self):
        super().__init__()
        # Chunks in the order they were written; joined by get_output().
        self.output = []

    def write(self, txt):
        """Record *txt* and forward it to the underlying ``StringIO`` buffer.

        Returns:
            The number of characters written, as reported by
            ``io.StringIO.write`` (text streams are expected to return this
            count rather than ``None``).
        """
        self.output.append(txt)
        return super().write(txt)

    def get_output(self):
        """Return everything written so far as a single string."""
        return ''.join(self.output)
# Function to load and display predefined images based on user selection
#def display_predefined_images(percentage_idx, complexity_idx):
# percentage = percentage_values[percentage_idx]
# complexity = complexity_values[complexity_idx]
# raw_image_path = os.path.join(RAW_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
# embeddings_image_path = os.path.join(EMBEDDINGS_PATH, f"percentage_{percentage}_complexity_{complexity}.png")
# raw_image = Image.open(raw_image_path)
# embeddings_image = Image.open(embeddings_image_path)
# return raw_image, embeddings_image
def display_predefined_images(percentage_idx, complexity_idx):
    """Load the predefined raw/embeddings images for a slider selection.

    Args:
        percentage_idx: Index into ``percentage_values``.
        complexity_idx: Index into ``complexity_values``.

    Returns:
        A ``(raw_image, embeddings_image)`` pair of PIL images, or
        ``(None, None)`` when either image file is missing on disk.
    """
    # Both folders use the same file name for a given selection.
    file_name = (
        f"percentage_{percentage_values[percentage_idx]}"
        f"_complexity_{complexity_values[complexity_idx]}.png"
    )
    raw_file = os.path.join(RAW_PATH, file_name)
    embeddings_file = os.path.join(EMBEDDINGS_PATH, file_name)

    # Nothing is returned unless both images are available.
    if not (os.path.exists(raw_file) and os.path.exists(embeddings_file)):
        return None, None

    return Image.open(raw_file), Image.open(embeddings_file)
# Function to load the pre-trained model from your cloned repository
def load_custom_model():
    """Instantiate the LWM model and switch it to evaluation mode.

    The ``lwm_model`` module is imported lazily, at call time, so it only has
    to be importable when this function is actually used.

    Returns:
        The ``LWM`` instance after ``eval()`` has been called on it.
    """
    # Deferred import: resolved when the function runs, not at module load.
    from lwm_model import LWM

    network = LWM()  # Default construction; adjust here if LWM needs arguments
    network.eval()
    return network
import importlib.util
# Function to dynamically load a Python module from a given file path
def load_module_from_path(module_name, file_path):
    """Import a Python source file as a module object.

    Args:
        module_name: Name to give the resulting module (its ``__name__``).
        file_path: Path to the source file to execute.

    Returns:
        The freshly executed module. It is not registered in
        ``sys.modules``, so every call executes the file again.

    Raises:
        ImportError: If no import spec/loader can be derived from
            *file_path* (for example, an unsupported file extension).
        Exception: Whatever executing the module's top-level code raises
            (including ``FileNotFoundError`` for a missing file).
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location returns None when it cannot pick a loader;
    # fail with a clear message instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load module {module_name!r} from {file_path!r}",
            name=module_name,
            path=str(file_path),
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
# Function to process the uploaded .p file and perform inference using the custom model
def process_p_file(uploaded_file, percentage_idx, complexity_idx):
    """Run LWM inference on an uploaded pickle file.

    Clones the model repository on first use, loads ``lwm_model.py``,
    ``input_preprocess.py`` and ``inference.py`` from it, tokenizes the
    uploaded data and computes both the embedding and raw outputs. Everything
    printed while this runs is captured and returned as text.

    Args:
        uploaded_file: The uploaded file, either an object exposing the path
            as ``.name`` or the path itself.
        percentage_idx: Currently unused; kept for the event-handler signature.
        complexity_idx: Currently unused; kept for the event-handler signature.

    Returns:
        A 3-tuple. On success: ``(output_emb, output_raw, captured_output)``.
        On any failure: ``(error_message, error_message, captured_output)``.
        NOTE(review): the first two slots are wired to image components by the
        UI; confirm they can display these values.
    """
    capture = PrintCapture()
    # Remember the state we are about to change so that it can be restored
    # exactly, whatever happens below.
    previous_stdout = sys.stdout
    original_cwd = os.getcwd()
    sys.stdout = capture  # Redirect print statements to capture
    try:
        model_repo_url = "https://huggingface.co/sadjadalikhani/LWM"
        # Resolved once against the starting directory, so the path stays
        # valid after the chdir below and across repeated calls.
        model_repo_dir = os.path.abspath("./LWM")

        # Step 1: Clone the repository if not already done
        if not os.path.exists(model_repo_dir):
            print(f"Cloning model repository from {model_repo_url}...")
            subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)

        # Step 2: Verify the repository was cloned and change the working directory
        if not os.path.isdir(model_repo_dir):
            print(f"Directory {model_repo_dir} does not exist.")
            raise FileNotFoundError(f"Directory {model_repo_dir} does not exist.")
        # NOTE(review): presumably the repository code resolves files relative
        # to the working directory; the chdir is kept for that reason and is
        # undone in ``finally``. os.chdir affects the whole process.
        os.chdir(model_repo_dir)
        print(f"Changed working directory to {os.getcwd()}")
        print(f"Directory content: {os.listdir(os.getcwd())}")  # Debugging: Check repo content

        # Step 3: Dynamically load lwm_model.py, input_preprocess.py, and inference.py
        loaded = {}
        for module_name in ("lwm_model", "input_preprocess", "inference"):
            module_path = os.path.join(model_repo_dir, f"{module_name}.py")
            if not os.path.exists(module_path):
                # Routed through the handler below so every failure produces
                # the same 3-tuple shape.
                raise FileNotFoundError(f"Error: {module_name}.py not found at {module_path}")
            loaded[module_name] = load_module_from_path(module_name, module_path)
        lwm_model = loaded["lwm_model"]
        input_preprocess = loaded["input_preprocess"]
        inference = loaded["inference"]

        # Step 4: Load the model from lwm_model module
        device = 'cpu'
        print(f"Loading the LWM model on {device}...")
        model = lwm_model.LWM.from_pretrained(device=device)

        # Step 5: Tokenize the data using the tokenizer from input_preprocess
        # Accept both an upload object with a ``.name`` path and a bare path.
        upload_path = getattr(uploaded_file, "name", uploaded_file)
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. This is user-uploaded data; only run this with trusted
        # uploads or switch to a safe serialization format.
        with open(upload_path, 'rb') as f:
            manual_data = pickle.load(f)
        preprocessed_chs = input_preprocess.tokenizer(manual_data=manual_data)

        # Step 6: Perform inference using the functions from inference.py
        output_emb = inference.lwm_inference(preprocessed_chs, 'channel_emb', model)
        output_raw = inference.create_raw_dataset(preprocessed_chs, device)
        print(f"Output Embeddings Shape: {output_emb.shape}")
        print(f"Output Raw Shape: {output_raw.shape}")
        return output_emb, output_raw, capture.get_output()
    except Exception as e:
        # Top-level boundary for the UI callback: report the failure instead
        # of letting it propagate.
        return str(e), str(e), capture.get_output()
    finally:
        # Restore stdout first so it is reset even if the chdir back fails.
        sys.stdout = previous_stdout
        os.chdir(original_cwd)
# Function to handle logic based on whether a file is uploaded or not
def los_nlos_classification(file, percentage_idx, complexity_idx):
    """Produce the three outputs of the LoS/NLoS tab.

    Args:
        file: The uploaded file, or ``None`` when nothing has been uploaded.
        percentage_idx: Slider index forwarded to the selected handler.
        complexity_idx: Slider index forwarded to the selected handler.

    Returns:
        A 3-tuple ``(raw, embeddings, console_output)``. With an upload this
        is the result of ``process_p_file``; without one it is the predefined
        image pair followed by ``None`` for the console output.
    """
    if file is not None:
        return process_p_file(file, percentage_idx, complexity_idx)

    # Flatten the image pair so the result has one value per output slot,
    # rather than a nested ``((raw, embeddings), None)`` tuple.
    raw_image, embeddings_image = display_predefined_images(percentage_idx, complexity_idx)
    return raw_image, embeddings_image, None
# Define the Gradio interface
# NOTE(review): the stylesheet below targets the CSS classes
# ``.vertical-slider`` and ``.slider-container``, while the components set
# those names through ``elem_id``; confirm whether ``elem_classes`` was meant.
# NOTE(review): the nesting of the layout below was reconstructed from a
# source without indentation; verify the placement of each component.
with gr.Blocks(css="""
.vertical-slider input[type=range] {
writing-mode: bt-lr; /* IE */
-webkit-appearance: slider-vertical; /* WebKit */
width: 8px;
height: 200px;
}
.slider-container {
display: inline-block;
margin-right: 50px;
text-align: center;
}
""") as demo:

    # Contact section shown above the tabs
    gr.Markdown(
        """
## Contact
<div style="display: flex; align-items: center;">
<a target="_blank" href="https://www.wi-lab.net"><img src="https://www.wi-lab.net/wp-content/uploads/2021/08/WI-name.png" alt="Wireless Model" style="height: 30px;"></a>&nbsp;&nbsp;
<a target="_blank" href="mailto:alikhani@asu.edu"><img src="https://img.shields.io/badge/email-alikhani@asu.edu-blue.svg?logo=gmail " alt="Email"></a>&nbsp;&nbsp;
</div>
"""
    )

    # Tab 1: beam prediction, driven by two sliders only
    with gr.Tab("Beam Prediction Task"):
        gr.Markdown("### Beam Prediction Task")

        with gr.Row():
            with gr.Column(elem_id="slider-container"):
                gr.Markdown("Percentage of Data for Training")
                percentage_slider_bp = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
            with gr.Column(elem_id="slider-container"):
                gr.Markdown("Task Complexity")
                complexity_slider_bp = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")

        with gr.Row():
            raw_img_bp = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_bp = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)

        # Either slider refreshes both images from the same pair of inputs.
        for bp_trigger in (percentage_slider_bp, complexity_slider_bp):
            bp_trigger.change(
                fn=display_predefined_images,
                inputs=[percentage_slider_bp, complexity_slider_bp],
                outputs=[raw_img_bp, embeddings_img_bp],
            )

    # Tab 2: LoS/NLoS classification, with an optional file upload
    with gr.Tab("LoS/NLoS Classification Task"):
        gr.Markdown("### LoS/NLoS Classification Task")
        file_input = gr.File(label="Upload .p File", file_types=[".p"])

        with gr.Row():
            with gr.Column(elem_id="slider-container"):
                gr.Markdown("Percentage of Data for Training")
                percentage_slider_los = gr.Slider(minimum=0, maximum=4, step=1, value=0, interactive=True, elem_id="vertical-slider")
            with gr.Column(elem_id="slider-container"):
                gr.Markdown("Task Complexity")
                complexity_slider_los = gr.Slider(minimum=0, maximum=1, step=1, value=0, interactive=True, elem_id="vertical-slider")

        with gr.Row():
            raw_img_los = gr.Image(label="Raw Channels", type="pil", width=300, height=300, interactive=False)
            embeddings_img_los = gr.Image(label="Embeddings", type="pil", width=300, height=300, interactive=False)
            output_textbox = gr.Textbox(label="Console Output", lines=10)

        # The upload and both sliders all re-run the same handler, in the
        # same registration order as before: file, percentage, complexity.
        for los_trigger in (file_input, percentage_slider_los, complexity_slider_los):
            los_trigger.change(
                fn=los_nlos_classification,
                inputs=[file_input, percentage_slider_los, complexity_slider_los],
                outputs=[raw_img_los, embeddings_img_los, output_textbox],
            )

# Launch the app only when executed as a script
if __name__ == "__main__":
    demo.launch()