|
|
|
"""
|
|
Created on Sun Sep 15 18:27:17 2024
|
|
|
|
@author: salikha4
|
|
"""
|
|
|
|
import os
|
|
import csv
|
|
import json
|
|
import shutil
|
|
import random
|
|
import argparse
|
|
from datetime import datetime
|
|
import pandas as pd
|
|
import time
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
|
from torch.optim import Adam
|
|
import numpy as np
|
|
from lwm_model import LWM, load_model
|
|
import warnings
|
|
warnings.filterwarnings('ignore')
|
|
from input_preprocess import *
|
|
|
|
|
|
# Preferred GPU index for dataset generation.
device_idx_ds = 3

# torch.device('cuda:N') can be created for any N, but moving a tensor onto a
# non-existent ordinal fails later ("invalid device ordinal"). Fall back to GPU 0
# when the preferred index is not present on this machine.
if torch.cuda.is_available():
    _gpu_idx = device_idx_ds if device_idx_ds < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{_gpu_idx}')
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
def dataset_gen(preprocessed_chs, input_type, scenario_idxs, lwm_model):
    """Build the model-input tensor for one of three input representations.

    Args:
        preprocessed_chs: Sequence of per-sample ``(input_ids, masked_tokens,
            masked_pos)`` records, as unpacked by ``prepare_for_LWM`` and
            ``create_raw_dataset``.
        input_type: ``'raw'`` returns the raw input tokens with the first token
            dropped. ``'cls_emb'`` returns the LWM output at index 0 of dim 1
            (presumably the CLS token). ``'channel_emb'`` returns the LWM
            outputs at indices 1 onward of dim 1.
        scenario_idxs: Unused here; kept so existing call sites keep working.
        lwm_model: Model run through ``evaluate``; only used for the embedding
            input types.

    Returns:
        A float32 tensor.

    Raises:
        ValueError: If ``input_type`` is not one of the supported values.
    """
    if input_type not in ('raw', 'cls_emb', 'channel_emb'):
        # Previously an unknown value fell through and crashed with
        # UnboundLocalError on the final return.
        raise ValueError(
            f"Unknown input_type {input_type!r}; "
            "expected 'raw', 'cls_emb' or 'channel_emb'"
        )

    if input_type == 'raw':
        # create_raw_dataset already returns a float32 tensor.
        return create_raw_dataset(preprocessed_chs, device)

    dataloader = prepare_for_LWM(preprocessed_chs, device)
    lwm_loss, embedding_data = evaluate(lwm_model, dataloader)
    print(f'LWM loss: {lwm_loss:.4f}')

    if input_type == 'cls_emb':
        embedding_data = embedding_data[:, 0]
    else:  # 'channel_emb'
        embedding_data = embedding_data[:, 1:]

    return embedding_data.float()
|
|
|
|
|
|
def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
    """Wrap LWM-tokenised samples in a DataLoader.

    Args:
        data: Sequence of ``(input_ids, masked_tokens, masked_pos)`` records.
            Each field must have the same length across samples so the
            fields can be stacked.
        device: Device the tensors are created on.
        batch_size: DataLoader batch size.
        shuffle: Whether the DataLoader shuffles samples.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` yielding
        ``(input_ids, masked_tokens, masked_pos)`` batches. The first two are
        float32 and ``masked_pos`` is int64.
    """
    input_ids, masked_tokens, masked_pos = zip(*data)

    # Stack each field into a single ndarray first: calling torch.tensor on a
    # sequence of ndarrays converts element by element and is very slow
    # (PyTorch emits a UserWarning about it).
    input_ids_tensor = torch.tensor(np.asarray(input_ids), device=device).float()
    masked_tokens_tensor = torch.tensor(np.asarray(masked_tokens), device=device).float()
    masked_pos_tensor = torch.tensor(np.asarray(masked_pos), device=device).long()

    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
|
|
|
|
def create_raw_dataset(data, device):
    """Create a dataset for raw channel data.

    Args:
        data: Sequence of ``(input_ids, masked_tokens, masked_pos)`` records.
            Only ``input_ids`` is used, and it must have the same length in
            every sample.
        device: Device the tensor is created on.

    Returns:
        A float32 tensor of the stacked ``input_ids`` with the first token of
        each sample (index 0 along dim 1) dropped.
    """
    input_ids, _, _ = zip(*data)

    # Stack into one ndarray first; torch.tensor over a tuple of ndarrays is
    # converted element by element and is very slow.
    stacked = torch.tensor(np.asarray(input_ids), device=device)
    return stacked[:, 1:].float()
|
|
|
|
|
|
def label_gen(task, data, scenario, n_beams=64):
    """Compute per-user labels for a downstream task.

    Only users whose ``data['user']['LoS']`` entry is not -1 are labelled. The
    result is aligned with those users in their original order.

    Args:
        task: ``'LoS/NLoS Classification'`` or ``'Beam Prediction'``.
        data: Scenario dict. Reads ``data['user']['LoS']`` and, for beam
            prediction, ``data['user']['channel']``.
        scenario: Scenario name, passed to ``get_parameters`` for beam
            prediction.
        n_beams: Number of codebook beams, spaced uniformly over a 120-degree
            field of view centred on 0 degrees.

    Returns:
        Integer ndarray. For classification it holds the users' ``LoS``
        entries. For beam prediction it holds the index of the codebook beam
        with the highest power, averaged over subbands.

    Raises:
        ValueError: If ``task`` is not a supported task name.
    """
    idxs = np.where(data['user']['LoS'] != -1)[0]

    if task == 'LoS/NLoS Classification':
        label = data['user']['LoS'][idxs]

    elif task == 'Beam Prediction':
        parameters, _, _, _, _ = get_parameters(scenario)
        n_users = len(data['user']['channel'])
        n_subbands = 1
        fov = 120

        # Beam steering azimuths in degrees, uniformly covering [-fov/2, fov/2].
        beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)

        # Codebook: one unit-norm steering vector per beam (azimuth in radians).
        F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
                                    phi=azi*np.pi/180,
                                    kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
                       for azi in beam_angles])

        # full_dbm[beam, subband, user] = 20*log10(|mean beamformed response|) + 30,
        # rounded to 0.1. Users with LoS == -1 are marked NaN.
        full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
        for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
            if data['user']['LoS'][ue_idx] == -1:
                full_dbm[:, :, ue_idx] = np.nan
            else:
                chs = F1 @ data['user']['channel'][ue_idx]
                full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
                full_dbm[:, :, ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)

        best_beams = np.argmax(np.mean(full_dbm, axis=1), axis=0)
        best_beams = best_beams.astype(float)
        best_beams[np.isnan(full_dbm[0, 0, :])] = np.nan

        label = best_beams[idxs]

    else:
        # Previously an unknown task left `label` unbound (UnboundLocalError).
        raise ValueError(
            f"Unknown task {task!r}; expected 'LoS/NLoS Classification' "
            "or 'Beam Prediction'"
        )

    return label.astype(int)
|
|
|
|
|
|
def steering_vec(array, phi=0, theta=0, kd=np.pi):
    """Return the unit-norm array response of ``array`` toward ``(phi, theta)``.

    Args:
        array: Antenna-array shape handed to ``DeepMIMOv3.ant_indices``
            (``label_gen`` passes ``parameters['bs_antenna']['shape']``).
        phi: Azimuth angle in radians (``label_gen`` converts degrees to
            radians before calling).
        theta: Angle that is offset by pi/2 before reaching
            ``DeepMIMOv3.array_response``. Presumably this converts an
            elevation angle to the zenith convention DeepMIMO expects; confirm.
        kd: Phase factor forwarded to ``array_response``. ``label_gen`` passes
            ``2*pi*spacing``, which suggests spacing in wavelengths; TODO confirm.

    Returns:
        The array response divided by its Euclidean norm.

    NOTE(review): ``DeepMIMOv3`` is not imported explicitly in this file.
    Presumably it arrives through ``from input_preprocess import *``; verify.
    """
    antenna_idxs = DeepMIMOv3.ant_indices(array)
    response = DeepMIMOv3.array_response(antenna_idxs, phi, theta + np.pi / 2, kd)
    magnitude = np.linalg.norm(response)
    return response / magnitude
|
|
|
|
|
|
def evaluate(model, dataloader):
    """Run ``model`` over ``dataloader`` in eval mode without gradients.

    Each batch is read as ``(input_ids, masked_tokens, masked_pos)`` and the
    model is called as ``model(input_ids, masked_pos)``. The call must return
    ``(logits_lm, output)``.

    Returns:
        A tuple ``(average_loss, output_total)``:

        - ``average_loss`` is the unweighted mean over batches of
          ``MSE(logits_lm, masked_tokens) / var(masked_tokens)``.
        - ``output_total`` is every batch's ``output`` concatenated along dim 0.
    """
    model.eval()
    mse = nn.MSELoss()
    batch_outputs = []
    loss_sum = 0.0

    with torch.no_grad():
        for batch in dataloader:
            tokens, targets, positions = batch[0], batch[1], batch[2]

            predictions, hidden = model(tokens, positions)
            batch_outputs.append(hidden)

            # Normalised MSE: divide by the variance of the masked targets.
            loss_sum += (mse(predictions, targets) / torch.var(targets)).item()

    return loss_sum / len(dataloader), torch.cat(batch_outputs, dim=0)
|
|
|
|
|
|
def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
    """Attach a task label to each preprocessed sample.

    Despite the name, the label is appended as the last element of each
    sample. Labels for all scenarios in ``scenario_idxs`` are generated with
    ``label_gen`` and concatenated in order, then paired positionally with
    ``preprocessed_chs``.

    Args:
        deepmimo_data: Indexable by scenario index; each entry is the scenario
            dict passed to ``label_gen``.
        preprocessed_chs: Sequence of per-sample lists (each must support
            ``+ [label]``).
        task: Task name forwarded to ``label_gen``.
        scenario_idxs: Scenario indices into ``scenarios_list()`` and
            ``deepmimo_data``.
        n_beams: Forwarded to ``label_gen``.

    Returns:
        A new list where sample ``i`` is ``preprocessed_chs[i] + [labels[i]]``.
        Labels beyond ``len(preprocessed_chs)`` are ignored.

    Raises:
        ValueError: If fewer labels are produced than there are samples.
    """
    # scenarios_list() does not depend on the loop variable; fetch it once.
    scenario_names = scenarios_list()

    labels = []
    for scenario_idx in scenario_idxs:
        labels.extend(label_gen(task, deepmimo_data[scenario_idx],
                                scenario_names[scenario_idx], n_beams=n_beams))

    if len(labels) < len(preprocessed_chs):
        raise ValueError(
            f"Got {len(labels)} labels for {len(preprocessed_chs)} samples; "
            "labels and preprocessed channels are misaligned"
        )

    # zip stops at len(preprocessed_chs) because labels is at least that long.
    return [sample + [label] for sample, label in zip(preprocessed_chs, labels)]