import subprocess |
import os |
import shutil |
import numpy as np |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
import matplotlib.pyplot as plt |
from torch.optim import Adam |
from torch.optim.lr_scheduler import MultiStepLR |
from torch.utils.data import DataLoader |
import csv, json, time |
from sklearn.metrics import f1_score |
from tqdm import tqdm |
def clone_dataset_scenario(repo_url, model_repo_dir="./LWM", scenarios_dir="scenarios"):
    """
    Clone a scenario dataset repository, including its Git LFS files.

    The repository is cloned into ``os.path.join(model_repo_dir, scenarios_dir)``.
    Any existing directory at that path is deleted first, so every call performs
    a fresh clone. On success the cloned files are left in place for later use.
    If a git command exits with a non-zero status, the error is printed (not
    raised) and any partial clone is removed.

    Args:
        repo_url (str): URL of the Git repository that holds the scenarios.
        model_repo_dir (str): Directory that will contain the scenarios folder.
            Replaced by "." when the current working directory is named "LWM".
        scenarios_dir (str): Name of the folder to clone into.

    Raises:
        OSError: propagated from ``subprocess.run`` / ``shutil`` (for example
            FileNotFoundError when the ``git`` executable is not installed).
    """
    if os.path.basename(os.getcwd()) == "LWM":
        # Presumably already inside the model repository: clone relative to it.
        model_repo_dir = "."
    scenarios_path = os.path.join(model_repo_dir, scenarios_dir)

    # Always start from a clean destination so stale data never mixes in.
    if os.path.exists(scenarios_path):
        shutil.rmtree(scenarios_path)
    parent_dir = os.path.dirname(scenarios_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    try:
        print(f"Cloning entire repository into {scenarios_path} ...")
        subprocess.run(["git", "clone", repo_url, scenarios_path], check=True)
        print("Pulling all files using Git LFS ...")
        # Run the LFS commands inside the clone via ``cwd`` rather than os.chdir:
        # the process-wide working directory is never modified, so relative paths
        # such as ``scenarios_path`` keep their meaning and nothing needs restoring.
        subprocess.run(["git", "lfs", "install"], cwd=scenarios_path, check=True)
        subprocess.run(["git", "lfs", "pull"], cwd=scenarios_path, check=True)
        print(f"Successfully cloned all scenarios into {scenarios_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error cloning scenarios: {str(e)}")
        # Do not leave a half-populated clone (e.g. un-pulled LFS pointers) behind.
        shutil.rmtree(scenarios_path, ignore_errors=True)
import sys

model_repo_url = "https://huggingface.co/wi-lab/lwm"
model_repo_dir = "./LWM"
# If this setup is re-run after the chdir below (e.g. re-executing a notebook
# cell), the working directory already is the model repo. Use it in place rather
# than cloning a second copy into ./LWM/LWM. This mirrors the same check inside
# clone_dataset_scenario.
if os.path.basename(os.getcwd()) == "LWM":
    model_repo_dir = "."
if not os.path.exists(model_repo_dir):
    print(f"Cloning model repository from {model_repo_url}...")
    subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
dataset_repo_url = "https://huggingface.co/datasets/wi-lab/lwm"
clone_dataset_scenario(dataset_repo_url, model_repo_dir)
if os.path.exists(model_repo_dir):
    os.chdir(model_repo_dir)
    print(f"Changed working directory to {os.getcwd()}")
    # When run as a script, sys.path[0] is the script's own directory and does
    # not follow os.chdir, so add the repo explicitly to make its modules
    # (input_preprocess, lwm_model, inference) importable.
    if os.getcwd() not in sys.path:
        sys.path.insert(0, os.getcwd())
else:
    print(f"Directory {model_repo_dir} does not exist. Please check if the repository is cloned properly.")
from input_preprocess import tokenizer
from lwm_model import lwm
# Scenario catalogue; edit scenario_idxs to choose which entries are tokenized.
scenario_names = np.array(
    [
        "city_18_denver",
        "city_15_indianapolis",
        "city_19_oklahoma",
        "city_12_fortworth",
        "city_11_santaclara",
        "city_7_sandiego",
    ]
)
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])  # every scenario in the list
selected_scenario_names = scenario_names[scenario_idxs]

# Passed straight through to tokenizer; how None is interpreted is defined there.
snr_db = None
preprocessed_chs, deepmimo_data = tokenizer(
    selected_scenario_names=selected_scenario_names,
    manual_data=None,
    gen_raw=True,
    snr_db=snr_db,
)
# Use the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Loading the LWM model on {device} ...")
lwm_model = lwm.from_pretrained(device=device)

from inference import lwm_inference, create_raw_dataset

# Downstream input representation: one of two LWM embedding variants, or raw channels.
input_types = ['cls_emb', 'channel_emb', 'raw']
selected_input_type = input_types[2]  # 'raw'
# NOTE(review): the training loop unpacks every batch as (data, target); confirm
# that create_raw_dataset / lwm_inference actually return labelled samples.
if selected_input_type not in ('cls_emb', 'channel_emb'):
    dataset = create_raw_dataset(preprocessed_chs, device)
else:
    dataset = lwm_inference(preprocessed_chs, selected_input_type, lwm_model, device)
# Per-sample geometry handed to res1dcnn for each input representation.
mapping = {
    'cls_emb': {'input_channels': 1, 'sequence_length': 64},
    'channel_emb': {'input_channels': 64, 'sequence_length': 128},
    'raw': {'input_channels': 16, 'sequence_length': 128},
}
input_type = selected_input_type
# An unrecognised input type falls back to the raw-channel geometry.
params = mapping[input_type] if input_type in mapping else mapping['raw']

n_beams = 16
num_classes = n_beams + 1  # one class more than n_beams (purpose of the extra class not shown here)
initial_lr = 0.001
class ResidualBlock(nn.Module):
    """1-D residual unit: conv-BN-ReLU, conv-BN, add the skip path, then ReLU.

    Both convolutions use kernel size 3 with padding 1, so the sequence length is
    preserved. When the channel count changes, the skip path is a 1x1 convolution
    followed by batch norm; otherwise it is an empty ``nn.Sequential`` (identity).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        skip = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + skip)
class res1dcnn(nn.Module):
    """1-D residual CNN classifier.

    ``forward`` takes a tensor laid out as (batch, sequence_length, input_channels)
    and transposes dims 1 and 2 so the convolutions see (batch, channels, length).
    The conv stack is a strided stem (conv + BN + ReLU + max-pool) followed by
    three stages of ``ResidualBlock`` (32, 64, 128 channels) and global average
    pooling. The head is fc(128) -> BN -> ReLU -> dropout(0.5) -> fc(num_classes),
    and it returns raw logits.

    Args:
        input_channels: number of channels per sequence position.
        sequence_length: number of positions; used only to size ``fc1``.
        num_classes: number of output logits.
    """

    def __init__(self, input_channels, sequence_length, num_classes):
        super().__init__()
        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(32, 32, 2)
        self.layer2 = self._make_layer(32, 64, 3)
        self.layer3 = self._make_layer(64, 128, 4)
        # Probe the conv stack with a dummy input to size fc1. The probe runs in
        # eval mode because in training mode every BatchNorm layer would fold
        # this all-zeros batch into its running statistics. Training mode would
        # also reject a length-1 sequence (a single value per channel).
        self.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, sequence_length)
            dummy_output = self.compute_conv_output(dummy_input)
        # Global average pooling leaves (1, 128, 1), so this equals layer3's width.
        self.flatten_size = dummy_output.numel()
        self.fc1 = nn.Linear(self.flatten_size, 128)
        self.bn_fc1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)
        # nn.Module starts in training mode; restore that for every submodule.
        self.train()

    def _make_layer(self, in_channels, out_channels, num_blocks):
        """Stack ``num_blocks`` residual blocks; only the first changes the width."""
        layers = [ResidualBlock(in_channels, out_channels)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def compute_conv_output(self, x):
        """Run the conv stack on (batch, channels, length); returns (batch, 128, 1)."""
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.adaptive_avg_pool1d(x, 1)
        return x

    def forward(self, x):
        x = x.transpose(1, 2)  # (batch, seq, channels) -> (batch, channels, seq)
        x = self.compute_conv_output(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_fc1(self.fc1(x)))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
from torch.utils.data import random_split

beam_model = res1dcnn(params['input_channels'], params['sequence_length'], num_classes).to(device)
optimizer = Adam(beam_model.parameters(), lr=initial_lr)
# Milestones [15, 35] lie inside num_epochs=50, so they are presumably epoch
# indices: the training loop should call scheduler.step() once per epoch.
scheduler = MultiStepLR(optimizer, milestones=[15, 35], gamma=0.1)
num_epochs = 50
batch_size = 32

# Hold out disjoint validation and test subsets. Without this, validation and
# test metrics would be measured on the very samples the model trains on.
# A fixed seed keeps the split reproducible across runs.
val_fraction = 0.15
test_fraction = 0.15
n_total = len(dataset)
n_val = int(val_fraction * n_total)
n_test = int(test_fraction * n_total)
n_train = n_total - n_val - n_test
train_set, val_set, test_set = random_split(
    dataset, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(42)
)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)
def plot_training_metrics(epochs, train_losses, val_losses, val_f1_scores, save_path=None):
    """Draw per-epoch training curves in a 12x5-inch figure with two panels.

    Left panel: training and validation loss. Right panel: validation weighted F1.
    The figure is written to ``save_path`` when that argument is truthy, and is
    then displayed with ``plt.show()``.
    """
    plt.figure(figsize=(12, 5))

    loss_ax = plt.subplot(1, 2, 1)
    for series, name in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
        loss_ax.plot(epochs, series, label=name, marker='o')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss Curve')
    loss_ax.legend()

    f1_ax = plt.subplot(1, 2, 2)
    f1_ax.plot(epochs, val_f1_scores, label='Validation Weighted F1', marker='o', color='green')
    f1_ax.set_xlabel('Epoch')
    f1_ax.set_ylabel('Weighted F1 Score')
    f1_ax.set_title('F1 Score Curve')
    f1_ax.legend()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()
criterion = nn.CrossEntropyLoss()
train_losses = []
val_losses = []
val_f1_scores = []
epochs_list = []


def _prepare_batch(data, target):
    """Move a batch to ``device`` and lay ``data`` out as res1dcnn.forward expects.

    'raw' samples are viewed as (batch, sequence_length, input_channels) from
    ``params``. 'cls_emb' samples get a trailing channel axis. Any other input
    type is passed through unchanged.
    """
    data, target = data.to(device), target.to(device)
    if input_type == 'raw':
        data = data.view(data.size(0), params['sequence_length'], params['input_channels'])
    elif input_type == 'cls_emb':
        data = data.unsqueeze(2)
    return data, target


for epoch in range(1, num_epochs + 1):
    # --- training pass ---
    beam_model.train()
    running_loss = 0.0
    for data, target in tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False):
        data, target = _prepare_batch(data, target)
        optimizer.zero_grad()
        outputs = beam_model(data)
        loss = criterion(outputs, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(beam_model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item() * data.size(0)
    # Step the LR schedule once per epoch, after the batch loop. The MultiStepLR
    # milestones count these calls, so stepping per batch would cut the learning
    # rate by 100x within the first few dozen batches.
    scheduler.step()
    train_loss = running_loss / len(train_loader.dataset)

    # --- validation pass (no gradients needed) ---
    beam_model.eval()
    val_running_loss = 0.0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for data, target in tqdm(val_loader, desc=f"Epoch {epoch} Validation", leave=False):
            data, target = _prepare_batch(data, target)
            outputs = beam_model(data)
            loss = criterion(outputs, target)
            val_running_loss += loss.item() * data.size(0)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    val_loss = val_running_loss / len(val_loader.dataset)
    val_f1 = f1_score(all_targets, all_preds, average='weighted')

    epochs_list.append(epoch)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_f1_scores.append(val_f1)
    print(f"Epoch {epoch}/{num_epochs}: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Weighted F1: {val_f1:.4f}")

plot_training_metrics(epochs_list, train_losses, val_losses, val_f1_scores, save_path="training_metrics.png")
# Final evaluation on the test loader: average cross-entropy loss and top-1 accuracy.
beam_model.eval()
test_running_loss = 0.0
correct = 0
total = 0
# Inference only: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    for data, target in tqdm(test_loader, desc="Testing"):
        data, target = data.to(device), target.to(device)
        # Same per-input-type reshaping as in training (see the training loop).
        if input_type == 'raw':
            data = data.view(data.size(0), params['sequence_length'], params['input_channels'])
        elif input_type == 'cls_emb':
            data = data.unsqueeze(2)
        outputs = beam_model(data)
        loss = criterion(outputs, target)
        test_running_loss += loss.item() * data.size(0)
        _, predicted = torch.max(outputs, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
test_loss = test_running_loss / len(test_loader.dataset)
accuracy = 100 * correct / total
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.2f}%")