# lwm/utils/beamforming.py
# Uploaded by wi-lab ("upload side scripts"), commit 713dc9d (verified).
#%% PACKAGES & MODULES
import numpy as np
import torch
from input_preprocess import DeepMIMO_data_gen, deepmimo_data_cleaning, tokenizer
from inference import lwm_inference, create_raw_dataset
from lwm_model import lwm
#%% DEEPMIMO DATA GENERATION
# Candidate DeepMIMO city scenarios. `bf_scenario_idx` picks the one used for
# the beamforming task (index 3 -> "city_12_fortworth").
scenario_names = np.array([
    "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
    "city_12_fortworth", "city_11_santaclara", "city_7_sandiego",
])
bf_scenario_idx = 3
scenario_idxs = np.array([bf_scenario_idx])
selected_scenario_names = scenario_names[scenario_idxs]
# Generate DeepMIMO data for every selected scenario, then pass each result
# through `deepmimo_data_cleaning` (iterate the list directly, no index juggling).
deepmimo_data = [DeepMIMO_data_gen(scenario_name) for scenario_name in selected_scenario_names]
cleaned_deepmimo_data = [deepmimo_data_cleaning(scenario_data) for scenario_data in deepmimo_data]
#%% FUNCTION FOR MRT BEAMFORMING
def compute_mrt_beamforming(channel_data, snr_db=None):
    """Compute one unit-norm MRT (conjugate) beamforming vector per user.

    Only ``channel_data[0]`` is used, and from it only index 0 of the second
    axis. For each user, the channel is averaged over the remaining last axis.
    The average is conjugated and normalized to unit Euclidean norm over the
    antenna axis.

    Args:
        channel_data: Sequence whose first element is an array/tensor of shape
            ``(N, C, A, S)``. Presumably (users, 1, antennas, subcarriers) as
            produced by the DeepMIMO cleaning step; TODO confirm.
        snr_db: If not ``None``, complex Gaussian noise is added to each
            user's channel first. The noise power for a user is that user's
            mean channel power divided by ``10 ** (snr_db / 10)``.

    Returns:
        Complex tensor of shape ``(N, A, 1)``. A user whose averaged channel is
        all zeros gets a zero vector (rather than NaNs). ``N == 0`` yields an
        empty ``(0, A, 1)`` tensor.
    """
    # Process all users in one batch. This replaces the per-user loop, whose
    # final torch.stack failed on an empty user list.
    channels = torch.as_tensor(channel_data[0])[:, 0, :, :]

    if snr_db is not None:
        snr_linear = 10 ** (snr_db / 10)
        # Per-user mean power over every non-user axis -> per-user noise power.
        per_user_dims = tuple(range(1, channels.dim()))
        noise_power = (
            torch.mean(torch.abs(channels) ** 2, dim=per_user_dims, keepdim=True)
            / snr_linear
        )
        # Same construction as before: the total noise power equals noise_power
        # for both real and complex inputs, since randn_like draws unit-variance
        # samples in either case.
        noise = torch.sqrt(noise_power / 2) * (
            torch.randn_like(channels) + 1j * torch.randn_like(channels)
        )
        channels = channels + noise

    # Average over the last per-user axis, then conjugate: (N, A, 1).
    h_conj = torch.conj(torch.mean(channels, dim=2, keepdim=True))
    norms = torch.linalg.vector_norm(h_conj, dim=1, keepdim=True)
    # Guard against division by a zero norm, which would otherwise emit NaNs.
    return torch.where(norms > 0, h_conj / norms, torch.zeros_like(h_conj))
#%% GENERATE BEAMFORMING VECTORS
beamforming_vectors = compute_mrt_beamforming(cleaned_deepmimo_data)
#%% GENERATE LWM EMBEDDINGS FROM MASKED INPUT CHANNELS
# gen_raw=False: per the original author's note, the tokenizer then masks 15%
# of the input patches, so LWM acts as a denoiser.
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False,
)
# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Loading the LWM model on {device} ...")
model = lwm.from_pretrained(device=device)
# The first two input types go through `lwm_inference` with the model.
# 'raw' builds the dataset from preprocessed_chs without the model.
input_types = ['cls_emb', 'channel_emb', 'raw']
selected_input_type = input_types[1]  # 'channel_emb'
if selected_input_type in input_types[:2]:
    dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
else:
    dataset = create_raw_dataset(preprocessed_chs, device)