# lwm/utils/pretraining.py
#%% PACKAGES & MODULES
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from inference import prepare_for_lwm
from input_preprocess import tokenizer
from lwm_model import lwm
import numpy as np
import DeepMIMOv3
#%% PRE-TRAINING SCENARIO CONFIG
def get_parameters(scenario):

    n_ant_bs = 32
    n_ant_ue = 1
    n_subcarriers = 32
    scs = 30e3

    # Number of user rows and users per row for each pre-training scenario.
    row_column_users = {
        'asu_campus1': {'n_rows': 321, 'n_per_row': 411},
        'Boston5G_3p5': {'n_rows': [812, 1622], 'n_per_row': 595},
        'city_0_newyork': {'n_rows': 44, 'n_per_row': 117},
        'city_1_losangeles': {'n_rows': 57, 'n_per_row': 81},
        'city_2_chicago': {'n_rows': 56, 'n_per_row': 80},
        'city_3_houston': {'n_rows': 62, 'n_per_row': 81},
        'city_4_phoenix': {'n_rows': 79, 'n_per_row': 86},
        'city_5_philadelphia': {'n_rows': 96, 'n_per_row': 66},
        'city_6_miami': {'n_rows': 80, 'n_per_row': 87},
        'city_8_dallas': {'n_rows': 83, 'n_per_row': 76},
        'city_9_sanfrancisco': {'n_rows': 79, 'n_per_row': 83},
        'city_10_austin': {'n_rows': 102, 'n_per_row': 55},
        'city_13_columbus': {'n_rows': 71, 'n_per_row': 96},
        'city_17_seattle': {'n_rows': 74, 'n_per_row': 82},
        'O1_3p5': {'n_rows': 5203, 'n_per_row': 181},
        'city_18_denver': {'n_rows': 85, 'n_per_row': 82},
        'city_15_indianapolis': {'n_rows': 80, 'n_per_row': 79},
        'city_19_oklahoma': {'n_rows': 82, 'n_per_row': 75},
        'city_12_fortworth': {'n_rows': 86, 'n_per_row': 72},
        'city_11_santaclara': {'n_rows': 47, 'n_per_row': 114},
        'city_7_sandiego': {'n_rows': 71, 'n_per_row': 83},
    }

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario

    # NOTE: 'city_14_charlotte' is handled here but has no entry in
    # row_column_users, so selecting it would raise a KeyError below.
    if scenario == 'O1_3p5':
        parameters['active_BS'] = np.array([4])
    elif scenario in ['city_14_charlotte', 'city_18_denver', 'city_15_indianapolis']:
        parameters['active_BS'] = np.array([3])
    else:
        parameters['active_BS'] = np.array([1])

    # Boston5G_3p5 stores a [start, end) row range; the others store a row count.
    if scenario == 'Boston5G_3p5':
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                            row_column_users[scenario]['n_rows'][1])
    else:
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])

    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)
    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9  # Hz -> GHz
    parameters['num_paths'] = 20

    return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
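#%% EXAMPLE: RAW CHANNEL GENERATION (illustrative)
# A minimal sketch of how the parameters above feed DeepMIMO, assuming the
# standard DeepMIMOv3.generate_data() entry point. The tokenizer used below
# is expected to drive this step internally, so the flag is off by default.
run_generation_example = False
if run_generation_example:
    example_params, _, _, _, _ = get_parameters('city_18_denver')
    example_dataset = DeepMIMOv3.generate_data(example_params)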
#%% PARAMETERS
n_epochs = 100
n_layers = 12
n_heads = 12
d_model = 64
d_ff = d_model * 4
d_k = d_model // n_heads
d_v = d_model // n_heads
dropout = 0.1
max_len = 129
element_length = 16
batch_size = 64
train_ratio = 0.7
val_ratio = 0.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
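# Sanity check (illustrative): a 32x32 BS-antenna-by-subcarrier channel,
# split into real and imaginary parts, yields 32 * 32 * 2 = 2048 scalars; at
# element_length = 16 scalars per patch that is 128 patches, plus one CLS
# token -> max_len = 129. The real/imaginary split is an assumption about
# how the tokenizer flattens complex channels.
assert (32 * 32 * 2) // element_length + 1 == max_len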
#%% PRE-TRAINING DATA GENERATION
# The following DeepMIMO scenarios are not, on their own, enough to pre-train
# a Transformer-based foundation model like LWM. Add more scenarios for more
# effective pre-training. Instructions for reproducing the actual dataset
# used to pre-train LWM are available on the Hugging Face forum.
scenario_names = np.array([
"city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
"city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
])
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])
selected_scenario_names = scenario_names[scenario_idxs]
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=False
)
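# Each element of preprocessed_chs is one tokenized channel sample (this
# structure is assumed from how the training loop below unpacks batches into
# input_ids, masked_tokens, and masked_pos).
print(f"Total pre-training samples: {len(preprocessed_chs)}")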
#%% DATALOADER
train_size = int(train_ratio * len(preprocessed_chs))
val_size = int(val_ratio * len(preprocessed_chs))
test_size = len(preprocessed_chs) - val_size - train_size
train_data, val_data, test_data = torch.utils.data.random_split(
    preprocessed_chs, [train_size, val_size, test_size]
)
train_loader = prepare_for_lwm(train_data, device, batch_size=batch_size, shuffle=True)
val_loader = prepare_for_lwm(val_data, device, batch_size=batch_size, shuffle=False)
test_loader = prepare_for_lwm(test_data, device, batch_size=batch_size, shuffle=False)
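# Optional shape check (illustrative): peek at one training batch to confirm
# the (input_ids, masked_tokens, masked_pos) layout the train loop expects.
inspect_batch = False
if inspect_batch:
    sample_batch = next(iter(train_loader))
    for name, tensor in zip(('input_ids', 'masked_tokens', 'masked_pos'), sample_batch):
        print(f"{name}: {tuple(tensor.shape)}")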
# %% Model
load_model = False
model = lwm()
model.to(device)
if load_model:
    model_name = 'models/pretrained_model.pth'
    model.load_state_dict(torch.load(model_name, map_location=device))
    print(f"Model loaded from {model_name}")
# Loss function
criterionMLM = nn.MSELoss()
# %% Optimizer and Scheduler
adaptive_lr = False
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = (
    optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    if adaptive_lr
    else StepLR(optimizer, step_size=10, gamma=0.9)
)
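# Note (illustrative): StepLR with step_size=10 and gamma=0.9 decays the
# learning rate by 10% every 10 scheduler steps, while ReduceLROnPlateau
# waits for the monitored metric (here, validation loss) to stop improving.
# A small helper to read the current learning rate for logging:
def get_lr(optimizer):
    # Assumes a single parameter group, as created by optim.Adam above.
    return optimizer.param_groups[0]['lr']
# Example: print(f"lr = {get_lr(optimizer):.2e}")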
# %% Training
training_loss = []
validation_loss = []
def train(model, dataloader, optimizer, device="cuda"):
    model.train()
    running_loss = 0.0
    criterionMCM = nn.MSELoss()

    for idx, batch in enumerate(dataloader):
        input_ids = batch[0].to(device)
        masked_tokens = batch[1].to(device)
        masked_pos = batch[2].to(device)

        optimizer.zero_grad()

        logits_lm, _ = model(input_ids, masked_pos)
        loss_lm = criterionMCM(logits_lm, masked_tokens)
        # Normalize the MSE by the variance of the masked targets so the
        # loss is comparable across channels with different power levels.
        loss = loss_lm / torch.var(masked_tokens)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    return average_loss
def validate(model, dataloader, device="cuda"):
    model.eval()
    running_loss = 0.0
    criterionMCM = nn.MSELoss()

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            input_ids = batch[0].to(device)
            masked_tokens = batch[1].to(device)
            masked_pos = batch[2].to(device)

            logits_lm, _ = model(input_ids, masked_pos)
            loss_lm = criterionMCM(logits_lm, masked_tokens)
            loss = loss_lm / torch.var(masked_tokens)

            running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    return average_loss
# %% Training Loop
for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")

    # Training step
    train_loss = train(model, train_loader, optimizer, device=device)
    training_loss.append(train_loss)
    print(f"Training Loss: {train_loss:.4f}")

    # Validation step
    if val_loader is not None:
        val_loss = validate(model, val_loader, device)
        validation_loss.append(val_loss)
        print(f"Validation Loss: {val_loss:.4f}")

    # Step the scheduler once per epoch: ReduceLROnPlateau needs the
    # validation loss, StepLR takes no argument.
    if adaptive_lr:
        scheduler.step(val_loss)
    else:
        scheduler.step()