import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from inference import prepare_for_lwm
from input_preprocess import tokenizer
from lwm_model import lwm

import numpy as np
import DeepMIMOv3


def get_parameters(scenario):
    """Build the DeepMIMOv3 generation parameters for a given scenario."""
    n_ant_bs = 32        # BS antennas (uniform linear array)
    n_ant_ue = 1         # single-antenna users
    n_subcarriers = 32   # OFDM subcarriers kept per channel
    scs = 30e3           # subcarrier spacing in Hz

    # User-grid size per scenario (number of user rows x users per row).
    # 'Boston5G_3p5' stores a [start, end) row interval instead of a row count.
    row_column_users = {
        'asu_campus1':          {'n_rows': 321,         'n_per_row': 411},
        'Boston5G_3p5':         {'n_rows': [812, 1622], 'n_per_row': 595},
        'city_0_newyork':       {'n_rows': 44,          'n_per_row': 117},
        'city_1_losangeles':    {'n_rows': 57,          'n_per_row': 81},
        'city_2_chicago':       {'n_rows': 56,          'n_per_row': 80},
        'city_3_houston':       {'n_rows': 62,          'n_per_row': 81},
        'city_4_phoenix':       {'n_rows': 79,          'n_per_row': 86},
        'city_5_philadelphia':  {'n_rows': 96,          'n_per_row': 66},
        'city_6_miami':         {'n_rows': 80,          'n_per_row': 87},
        'city_8_dallas':        {'n_rows': 83,          'n_per_row': 76},
        'city_9_sanfrancisco':  {'n_rows': 79,          'n_per_row': 83},
        'city_10_austin':       {'n_rows': 102,         'n_per_row': 55},
        'city_13_columbus':     {'n_rows': 71,          'n_per_row': 96},
        'city_17_seattle':      {'n_rows': 74,          'n_per_row': 82},
        'O1_3p5':               {'n_rows': 5203,        'n_per_row': 181},
        'city_18_denver':       {'n_rows': 85,          'n_per_row': 82},
        'city_15_indianapolis': {'n_rows': 80,          'n_per_row': 79},
        'city_19_oklahoma':     {'n_rows': 82,          'n_per_row': 75},
        'city_12_fortworth':    {'n_rows': 86,          'n_per_row': 72},
        'city_11_santaclara':   {'n_rows': 47,          'n_per_row': 114},
        'city_7_sandiego':      {'n_rows': 71,          'n_per_row': 83},
    }

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario

    # Scenario-specific active base station (note: 'city_14_charlotte' has no
    # entry in row_column_users above, so only its BS index is handled here).
    if scenario == 'O1_3p5':
        parameters['active_BS'] = np.array([4])
    elif scenario in ['city_14_charlotte', 'city_18_denver', 'city_15_indianapolis']:
        parameters['active_BS'] = np.array([3])
    else:
        parameters['active_BS'] = np.array([1])

    if scenario == 'Boston5G_3p5':
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                            row_column_users[scenario]['n_rows'][1])
    else:
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])

    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])    # horizontal ULA
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # degrees
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)
    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9    # GHz (30 kHz x 32 = 0.96 MHz)
    parameters['num_paths'] = 20

    return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
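

# Illustrative usage (a sketch, not part of the training flow below). The
# tokenizer is assumed to drive DeepMIMOv3 generation internally; this shows
# roughly how the parameters above map to raw channels. generate_raw_channels
# is a hypothetical helper named here for illustration only.
def generate_raw_channels(scenario):
    params, _, _, _, _ = get_parameters(scenario)
    dataset = DeepMIMOv3.generate_data(params)
    # With a single active BS, dataset[0]['user']['channel'] holds the per-user
    # frequency-domain channels (users x UE antennas x BS antennas x subcarriers).
    return dataset[0]['user']['channel']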


# Transformer and training hyperparameters.
n_epochs = 100
n_layers = 12             # transformer encoder layers
n_heads = 12              # attention heads per layer
d_model = 64              # embedding dimension
d_ff = d_model * 4        # feed-forward hidden dimension
d_k = d_model // n_heads  # per-head key dim (integer division: 64 // 12 = 5)
d_v = d_model // n_heads  # per-head value dim
dropout = 0.1
max_len = 129             # sequence length: 128 channel patches + 1 CLS token
element_length = 16       # real values per patch token
batch_size = 64
train_ratio = 0.7
val_ratio = 0.2           # remaining 10% is held out for testing
device = 'cuda' if torch.cuda.is_available() else 'cpu'
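
# Sanity check on the token geometry (a sketch of the assumed patching scheme):
# a 32x32 complex channel flattens to 32 * 32 * 2 = 2048 real values, split
# into patches of element_length reals, plus one CLS token.
n_patches = (32 * 32 * 2) // element_length  # 128 channel patches
assert max_len == n_patches + 1, "max_len should equal patches + CLS"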

# Pretraining scenarios (DeepMIMO city scenarios).
scenario_names = np.array([
    "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
    "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
])

scenario_idxs = np.array([0, 1, 2, 3, 4, 5])  # use all six scenarios
selected_scenario_names = scenario_names[scenario_idxs]

# Tokenize the channels into masked patch sequences for pretraining.
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False)
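
# Each preprocessed sample is assumed to be an (input_ids, masked_tokens,
# masked_pos) triple; that is the order in which train() and validate()
# unpack batches below.
print(f"Tokenized samples: {len(preprocessed_chs)}")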

# 70/20/10 train/validation/test split.
train_size = int(train_ratio * len(preprocessed_chs))
val_size = int(val_ratio * len(preprocessed_chs))
test_size = len(preprocessed_chs) - val_size - train_size

train_data, val_data, test_data = torch.utils.data.random_split(
    preprocessed_chs, [train_size, val_size, test_size]
)

train_loader = prepare_for_lwm(train_data, device, batch_size=batch_size, shuffle=True)
val_loader = prepare_for_lwm(val_data, device, batch_size=batch_size, shuffle=True)
test_loader = prepare_for_lwm(test_data, device, batch_size=batch_size, shuffle=True)
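
# Quick shape check on one batch (illustrative; uncomment to run). The field
# order matches how train()/validate() unpack batches: input_ids,
# masked_tokens, masked_pos.
# sample_batch = next(iter(train_loader))
# print([t.shape for t in sample_batch])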

# Instantiate the LWM model; optionally resume from a saved checkpoint.
load_model = False

model = lwm()
model.to(device)

if load_model:
    model_name = 'models/pretrained_model.pth'
    model.load_state_dict(torch.load(model_name))
    print(f"Model loaded from {model_name}")

criterionMLM = nn.MSELoss()  # kept for reference; train()/validate() build their own criterion

# Optimizer and learning-rate schedule. With adaptive_lr=True, the LR is
# reduced when the validation loss plateaus (stepped once per epoch in the
# loop below); otherwise StepLR decays the LR by gamma every step_size calls
# to scheduler.step() (here, every 10 training batches, since train() steps
# the scheduler per batch).
adaptive_lr = False

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = (
    optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    if adaptive_lr
    else StepLR(optimizer, step_size=10, gamma=0.9)
)

training_loss = []
validation_loss = []


def train(model, dataloader, optimizer, scheduler=None, device="cuda"):
    """Run one epoch of masked-channel-modeling pretraining."""
    model.train()
    running_loss = 0.0
    criterionMCM = nn.MSELoss()

    for batch in dataloader:
        input_ids = batch[0].to(device)      # masked patch sequences
        masked_tokens = batch[1].to(device)  # ground-truth values of masked patches
        masked_pos = batch[2].to(device)     # positions of the masked patches

        optimizer.zero_grad()

        logits_lm, _ = model(input_ids, masked_pos)
        # Normalize the MSE by the variance of the masked targets, giving an
        # NMSE-style loss that is insensitive to per-batch channel power.
        loss_lm = criterionMCM(logits_lm, masked_tokens)
        loss = loss_lm / torch.var(masked_tokens)

        loss.backward()
        optimizer.step()

        # Per-batch stepping only applies to epoch-agnostic schedulers such as
        # StepLR; ReduceLROnPlateau requires a metric and is stepped once per
        # epoch on the validation loss in the loop below.
        if scheduler is not None and not isinstance(
                scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step()

        running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    return average_loss


def validate(model, dataloader, device="cuda"):
    """Compute the average masked-channel loss without gradient updates."""
    model.eval()
    running_loss = 0.0
    criterionMCM = nn.MSELoss()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch[0].to(device)
            masked_tokens = batch[1].to(device)
            masked_pos = batch[2].to(device)

            logits_lm, _ = model(input_ids, masked_pos)
            loss_lm = criterionMCM(logits_lm, masked_tokens)
            loss = loss_lm / torch.var(masked_tokens)

            running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    return average_loss


for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")

    train_loss = train(model, train_loader, optimizer, scheduler, device)
    training_loss.append(train_loss)
    print(f"Training Loss: {train_loss:.4f}")

    if val_loader is not None:
        val_loss = validate(model, val_loader, device)
        validation_loss.append(val_loss)
        print(f"Validation Loss: {val_loss:.4f}")

        # ReduceLROnPlateau (adaptive_lr=True) is stepped here on the
        # validation loss; StepLR is stepped per batch inside train().
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_loss)
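
# Optionally persist the pretrained weights (a sketch; the path mirrors the
# checkpoint loaded above when load_model=True, adjust as needed).
# torch.save(model.state_dict(), 'models/pretrained_model.pth')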