LWM: Large Wireless Model
Welcome to the LWM (Large Wireless Model) repository! This project hosts a pre-trained model designed to process and extract features from wireless communication datasets, specifically the DeepMIMO dataset. Follow the instructions below to clone the repository, load the data, and perform inference with LWM.
How to Use
1. Clone the Repository
To get started, clone the Hugging Face repository to your local machine with the following Python code:
import subprocess
import os
import sys
import importlib.util
import torch
# Hugging Face public repository URL
repo_url = "https://huggingface.co/sadjadalikhani/LWM"
# Directory where the repo will be cloned
clone_dir = "./LWM"
# Step 1: Clone the repository if it hasn't been cloned already
if not os.path.exists(clone_dir):
    print(f"Cloning repository from {repo_url} into {clone_dir}...")
    result = subprocess.run(["git", "clone", repo_url, clone_dir], capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Error cloning repository: {result.stderr}")
        sys.exit(1)
    print(f"Repository cloned successfully into {clone_dir}")
else:
    print(f"Repository already cloned into {clone_dir}")
# Step 2: Add the cloned directory to Python path
sys.path.append(clone_dir)
# Step 3: Import necessary functions
def import_functions_from_file(module_name, file_path):
    try:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # Expose every public callable of the module in the current namespace
        for function_name in dir(module):
            if callable(getattr(module, function_name)) and not function_name.startswith("__"):
                globals()[function_name] = getattr(module, function_name)
        return module
    except FileNotFoundError:
        print(f"Error: {file_path} not found!")
        sys.exit(1)
# Step 4: Import functions from the repository
import_functions_from_file("lwm_model", os.path.join(clone_dir, "lwm_model.py"))
import_functions_from_file("inference", os.path.join(clone_dir, "inference.py"))
import_functions_from_file("load_data", os.path.join(clone_dir, "load_data.py"))
import_functions_from_file("input_preprocess", os.path.join(clone_dir, "input_preprocess.py"))
print("All required functions imported successfully.")
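The import helper above copies every public callable into the global namespace, which is convenient but can silently overwrite names. As a hedged alternative, the sketch below loads a file as a module object and leaves globals() untouched. The helper name load_module_from_path and the throwaway demo file are illustrative assumptions, not part of the LWM repository.

```python
import importlib.util
import os
import tempfile


def load_module_from_path(module_name, file_path):
    """Load a Python file as a module object without touching globals()."""
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"{file_path} not found")
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# Demonstration with a temporary file standing in for a repository file
# such as lwm_model.py (the file content here is hypothetical).
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo_module.py")
    with open(demo_path, "w") as handle:
        handle.write("def double(x):\n    return 2 * x\n")
    demo = load_module_from_path("demo_module", demo_path)
    demo_result = demo.double(21)

print(demo_result)
```

With this variant, repository functions would be reached through the returned module (for example, a module loaded from load_data.py and then its load_DeepMIMO_data attribute), so it is always clear which file a name came from.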
2. Load the DeepMIMO Dataset
Before tokenizing and processing the data, you need to load the DeepMIMO dataset. Below is a list of available datasets and their links for more information:
Dataset Overview
Dataset | City | Number of Users | DeepMIMO Page |
---|---|---|---|
Dataset 0 | Denver | 1354 | DeepMIMO City Scenario 18 |
Dataset 1 | Indianapolis | 3248 | DeepMIMO City Scenario 15 |
Dataset 2 | Oklahoma | 3455 | DeepMIMO City Scenario 19 |
Dataset 3 | Fort Worth | 1902 | DeepMIMO City Scenario 12 |
Dataset 4 | Santa Clara | 2689 | DeepMIMO City Scenario 11 |
Dataset 5 | San Diego | 2192 | DeepMIMO City Scenario 7 |
Operational Settings:
- Antennas at BS: 32
- Antennas at UEs: 1
- Subcarriers: 32
- Paths: 20
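As a rough sanity check on these settings, the sketch below counts the values in one user's channel. It assumes the channel is a complex array indexed by UE antenna, BS antenna, and subcarrier, which the settings list does not state explicitly; the constant names are illustrative.

```python
# Operational settings copied from the list above.
N_BS_ANTENNAS = 32
N_UE_ANTENNAS = 1
N_SUBCARRIERS = 32

# Assumption: one complex coefficient per (UE antenna, BS antenna, subcarrier).
complex_entries = N_UE_ANTENNAS * N_BS_ANTENNAS * N_SUBCARRIERS

# Splitting each complex coefficient into real and imaginary parts doubles the count.
real_values = 2 * complex_entries

print(complex_entries, real_values)
```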
Load Data Code:
Select and load specific datasets by adjusting dataset_idxs. In the example below, we select the first two datasets.
# Step 5: Load the DeepMIMO dataset
print("Loading the DeepMIMO dataset...")
# Load the DeepMIMO dataset
deepmimo_data = load_DeepMIMO_data()
# Select datasets to load
dataset_idxs = torch.arange(2) # Adjust the number of datasets as needed
print("DeepMIMO dataset loaded successfully.")
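To see what a given selection covers before tokenizing, the sketch below maps dataset indices to the cities and user counts from the Dataset Overview table. The dictionary and the summarize_selection helper are illustrative and not part of the repository; the numbers are copied from the table.

```python
# City names and user counts copied from the Dataset Overview table.
DATASET_USERS = {
    0: ("Denver", 1354),
    1: ("Indianapolis", 3248),
    2: ("Oklahoma", 3455),
    3: ("Fort Worth", 1902),
    4: ("Santa Clara", 2689),
    5: ("San Diego", 2192),
}


def summarize_selection(dataset_idxs):
    """Return the selected city names and their combined user count."""
    # int() also accepts the 0-d tensors produced by iterating torch.arange(...)
    idxs = [int(i) for i in dataset_idxs]
    unknown = [i for i in idxs if i not in DATASET_USERS]
    if unknown:
        raise ValueError(f"Unknown dataset indices: {unknown}")
    cities = [DATASET_USERS[i][0] for i in idxs]
    total_users = sum(DATASET_USERS[i][1] for i in idxs)
    return cities, total_users


# Same selection as torch.arange(2): the first two datasets.
cities, total_users = summarize_selection(range(2))
print(cities, total_users)
```

A non-contiguous selection works the same way, for example summarize_selection([0, 3, 5]).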
3. Tokenize the DeepMIMO Dataset
After loading the data, tokenize the selected DeepMIMO datasets. This step prepares the data for the model to process.
Tokenization Code:
# Step 6: Tokenize the dataset
print("Tokenizing the DeepMIMO dataset...")
# Tokenize the loaded datasets
preprocessed_chs = tokenizer(deepmimo_data, dataset_idxs, gen_raw=True)
print("Dataset tokenized successfully.")
4. Load the LWM Model
Once the dataset is tokenized, load the pre-trained LWM model using the following code:
# Step 7: Load the LWM model (with flexibility for the device)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading the LWM model on {device}...")
model = LWM.from_pretrained(device=device)
5. LWM Inference
Once the dataset is tokenized and the model is loaded, generate either raw channels or the inferred LWM embeddings by choosing the input type.
# Step 8: Generate the dataset for inference
input_type = ['cls_emb', 'channel_emb', 'raw'][1] # Modify input type as needed
dataset = dataset_gen(preprocessed_chs, input_type, model)
You can choose between:
- cls_emb: LWM CLS token embeddings
- channel_emb: LWM channel embeddings
- raw: Raw wireless channel data
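Selecting the input type by list position, as in Step 8, is easy to get wrong. A minimal sketch of choosing it by name with validation follows; the choose_input_type helper is a hypothetical convenience, not a function from the repository.

```python
# The three input types listed above.
VALID_INPUT_TYPES = ("cls_emb", "channel_emb", "raw")


def choose_input_type(name):
    """Return name unchanged if it is a supported input type, else raise."""
    if name not in VALID_INPUT_TYPES:
        raise ValueError(
            f"Unknown input type {name!r}; expected one of {VALID_INPUT_TYPES}"
        )
    return name


input_type = choose_input_type("channel_emb")
print(input_type)
```

The validated string can then be passed to dataset_gen in place of the indexed list expression.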
Post-processing for Downstream Task
1. Use the Dataset in Downstream Tasks
Finally, use the generated dataset for your downstream tasks, such as classification, prediction, or analysis.
# Step 9: Print results
print(f"Dataset generated with shape: {dataset.shape}")
print("Inference completed successfully.")
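Most downstream tasks start by splitting the generated samples into training and validation sets. The sketch below produces reproducible index lists using only the standard library. The helper name, the 80/20 split, and the sample count of 4602 (the combined user count of Datasets 0 and 1) are illustrative assumptions; the actual number of samples is whatever dataset.shape reports.

```python
import random


def split_indices(n_samples, val_fraction=0.2, seed=0):
    """Shuffle sample indices and split them into train and validation lists."""
    if n_samples < 2:
        raise ValueError("Need at least two samples to split")
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be strictly between 0 and 1")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_val = max(1, int(round(n_samples * val_fraction)))
    return indices[n_val:], indices[:n_val]


# Hypothetical sample count; replace with dataset.shape[0] in practice.
train_idx, val_idx = split_indices(4602, val_fraction=0.2, seed=0)
print(len(train_idx), len(val_idx))
```

Assuming dataset is a tensor whose first dimension indexes samples, dataset[train_idx] and dataset[val_idx] would then give the two subsets.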
Requirements
- Python 3.x
- PyTorch