Sadjad Alikhani committed on
Commit
ebfb25d
·
verified ·
1 Parent(s): 071c25a

upload required files

Browse files
Files changed (7) hide show
  1. input_preprocess.py +348 -0
  2. lwm_model.py +173 -0
  3. model.py +29 -0
  4. model_weights.pth +3 -0
  5. save_model.py +29 -0
  6. tokenizer.py +33 -0
  7. upload_to_huggingface.py +17 -0
input_preprocess.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 16:13:29 2024
4
+
5
+ This script generates preprocessed data from wireless communication scenarios,
6
+ including token generation, patch creation, and data sampling for machine learning models.
7
+
8
+ @author: salikha4
9
+ """
10
+
11
+ import numpy as np
12
+ import dataset_gen
13
+ import dataset_utils as dt
14
+ import os
15
+ from tqdm import tqdm
16
+ import time
17
+ import pickle
18
+ import DeepMIMOv3
19
+
20
+ vars_folder = 'variables/'
21
+ os.makedirs(vars_folder, exist_ok=True)
22
+
23
+ #%% Scenarios List
24
def scenarios_list():
    """Return the names of the supported scenarios as a NumPy string array.

    The position of a name in the array is the index callers pass in
    ``scenario_idxs``.
    """
    names = [
        'city_18_denver',
        'city_15_indianapolis',
        'city_19_oklahoma',
        'city_12_fortworth',
        'city_11_santaclara',
        'city_7_sandiego',
    ]
    return np.array(names)
30
+
31
+
32
+ #%% Token Generation
33
def gen_tokens(scenario_idxs, patch_gen=True, patch_size=16,
               gen_deepMIMO_data=False, gen_raw=False, save_data=False):
    """
    Generates tokens by preparing and preprocessing the dataset.

    Args:
        scenario_idxs (list): Indices into ``scenarios_list()``; must be non-empty.
        patch_gen (bool): If True, build patches with ``patch_makerv2``; otherwise
            load previously saved ``patch_<scenario>.p`` files. Defaults to True.
        patch_size (int): Requested patch width, used when generating patches.
            Defaults to 16.
        gen_deepMIMO_data (bool): Forwarded to ``patch_makerv2``. Defaults to False.
        gen_raw (bool): Forwarded to ``make_samplev2``; when True the masked
            positions are recorded but the inputs are left unmodified.
            Defaults to False.
        save_data (bool): Whether to save the preprocessed data (also forwarded
            to ``patch_makerv2``). Defaults to False.

    Returns:
        tuple: ``(preprocessed_data, sequence_length, element_length)`` where
        ``preprocessed_data`` holds one ``make_samplev2`` sample per user,
        ``sequence_length`` is the number of patches plus one (the [CLS] row)
        and ``element_length`` is the patch width found in the data.
    """

    vars_folder = 'variables/'
    os.makedirs(vars_folder, exist_ok=True)

    # Fetch scenarios
    scenario_list = scenarios_list()
    scenarios = scenario_list[scenario_idxs] if len(scenario_idxs) > 1 else [scenario_list[scenario_idxs[0]]]

    # Patch generation or loading
    if patch_gen:
        per_scenario = [patch_makerv2(patch_size=patch_size, scenario=scenario,
                                      gen_deepMIMO_data=gen_deepMIMO_data, save_patches=False,
                                      save_file_name=f'patch_{scenario}.p',
                                      norm_factor=1e6, save_data=save_data)
                        for scenario in scenarios]
    else:
        # The file name must be an f-string: without the prefix every scenario
        # resolved to the literal file 'patch_{scenario}.p'.
        per_scenario = [dt.load_var(vars_folder + f'patch_{scenario}.p')
                        for scenario in scenarios]
    patches = np.vstack(per_scenario)

    # Define dimensions from the data itself; loaded patches may have been
    # saved with a different width than the ``patch_size`` argument.
    patch_size = patches.shape[2]
    n_patches = patches.shape[1]
    n_masks_half = int(0.15 * n_patches / 2)
    sequence_length = n_patches + 1
    element_length = patch_size

    # Special tokens are built after the width is known so they always
    # stack cleanly with the patches in make_samplev2.
    word2id = {'[CLS]': 0.2 * np.ones((patch_size)), '[MASK]': 0.1 * np.ones((patch_size))}

    # Generate preprocessed data
    preprocessed_data = []
    for user_idx in tqdm(range(len(patches)), desc="Processing items"):
        sample = make_samplev2(user_idx, patches, word2id, n_patches, n_masks_half, patch_size, gen_raw=gen_raw)
        preprocessed_data.append(sample)

    if save_data:
        dt.save_var(preprocessed_data, vars_folder + 'preprocessed_data.p')

    return preprocessed_data, sequence_length, element_length
89
+
90
+
91
+ #%% Patch Creation
92
def patch_makerv2(patch_size=16, scenario=None, gen_deepMIMO_data=False,
                  save_patches=True, save_file_name=None, norm_factor=1,
                  save_data=False):
    """
    Build fixed-width patches from the channels of one scenario.

    Args:
        patch_size (int): Number of values per patch.
        scenario (str): Scenario name handed to ``DeepMIMO_data_gen``.
        gen_deepMIMO_data (bool): Forwarded to ``DeepMIMO_data_gen``.
        save_patches (bool): If True, store the patches under
            ``variables/<save_file_name>``.
        save_file_name (str): File name used when ``save_patches`` is True.
        norm_factor (int): Multiplier applied to the real/imaginary values.
        save_data (bool): Forwarded to ``DeepMIMO_data_gen``.

    Returns:
        patch (numpy array): float64 array of shape
        ``(n_selected_users, n_patches, patch_size)``.
    """

    vars_folder = 'variables/'
    os.makedirs(vars_folder, exist_ok=True)

    user_data = DeepMIMO_data_gen(scenario, gen_deepMIMO_data, save_data=save_data)['user']

    # Keep only users whose LoS flag is not -1.
    idxs = np.where(user_data['LoS'] != -1)[0]
    original_ch = user_data['channel'][idxs]

    # One row per user: all real parts followed by all imaginary parts.
    flat_channels = original_ch.reshape((original_ch.shape[0], -1)).astype(np.csingle)
    real_then_imag = np.hstack((flat_channels.real, flat_channels.imag)) * norm_factor

    # Cut each row into consecutive patches; a trailing remainder shorter
    # than patch_size is dropped.
    n_patches = real_then_imag.shape[1] // patch_size
    usable = real_then_imag[:, :n_patches * patch_size]
    patch = np.zeros((len(idxs), n_patches, patch_size))
    patch[...] = usable.reshape(len(idxs), n_patches, patch_size)

    if save_patches:
        dt.save_var(patch, vars_folder + save_file_name)

    return patch
132
+
133
+
134
+ #%% Data Generation for Scenario Areas
135
def DeepMIMO_data_gen(scenario, gen_deepMIMO_data, save_data=False):
    """
    Generate the data of one scenario, or load a previously saved copy.

    Args:
        scenario (str): Scenario name.
        gen_deepMIMO_data (bool): True to run ``DeepMIMOv3.generate_data``;
            False to load the pickle stored under ``vars_folder``.
        save_data (bool): When generating, also store the result as a pickle.

    Returns:
        data (dict): Loaded or generated data (first entry of the trimmed
        dataset when generated).
    """

    parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)

    if not gen_deepMIMO_data:
        return load_var(vars_folder + f'data_{scenario}_{n_ant_bs}_{n_ant_ue}_{n_subcarriers}.p')

    deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)
    users_per_row = row_column_users[scenario]['n_per_row']
    n_rows = len(parameters['user_rows'])
    # Step [1, 1] keeps every user of every row.
    uniform_idxs = uniform_sampling(deepMIMO_dataset, [1, 1], n_rows,
                                    users_per_row=users_per_row)
    data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]

    if save_data:
        save_var(data, vars_folder + f'data_{scenario}_{n_ant_bs}_{n_ant_ue}_{n_subcarriers}.p')

    return data
162
+
163
+ #%%%
164
def get_parameters(scenario):
    """
    Build the DeepMIMOv3 parameter set for a scenario.

    Args:
        scenario (str): Scenario name; it must be a key of the user-grid
            table below, otherwise the grid lookup raises ``KeyError``.

    Returns:
        tuple: ``(parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers)``.
    """

    n_ant_bs = 32
    n_ant_ue = 1
    n_subcarriers = 32
    scs = 30e3  # presumably the subcarrier spacing in Hz -- TODO confirm

    # (scenario, number of user rows, users per row)
    grid_table = (
        ('city_18_denver', 85, 82),
        ('city_15_indianapolis', 80, 79),
        ('city_19_oklahoma', 82, 75),
        ('city_12_fortworth', 86, 72),
        ('city_11_santaclara', 47, 114),
        ('city_7_sandiego', 71, 83),
    )
    row_column_users = {
        name: {'n_rows': n_rows, 'n_per_row': n_per_row}
        for name, n_rows, n_per_row in grid_table
    }

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario

    # Active base station per scenario.
    if scenario == 'O1_3p5':
        active_bs = 4
    elif scenario in ['city_18_denver', 'city_15_indianapolis']:
        active_bs = 3
    else:
        active_bs = 1
    parameters['active_BS'] = np.array([active_bs])

    # NOTE(review): 'O1_3p5' and 'Boston5G_3p5' have no entry in the grid
    # table, so this lookup raises KeyError for them and the Boston branch
    # below (which expects a [start, stop] pair) cannot currently be reached.
    n_rows = row_column_users[scenario]['n_rows']
    if scenario == 'Boston5G_3p5':
        parameters['user_rows'] = np.arange(n_rows[0], n_rows[1])
    else:
        parameters['user_rows'] = np.arange(n_rows)

    bs_antenna = parameters['bs_antenna']
    bs_antenna['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
    bs_antenna['rotation'] = np.array([0, 0, -135])  # (x,y,z)
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False

    ofdm = parameters['OFDM']
    ofdm['subcarriers'] = n_subcarriers
    ofdm['selected_subcarriers'] = np.arange(n_subcarriers)
    ofdm['bandwidth'] = scs * n_subcarriers / 1e9
    parameters['num_paths'] = 20

    return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
224
+
225
+
226
+ #%% Sample Generation
227
def make_samplev2(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
    """
    Build one masked sample for a single user.

    Row 0 of the input is the [CLS] token, followed by the user's patches.
    ``n_masks`` positions are drawn without replacement from the first half
    of the patches and mirrored into the second half, so both halves are
    masked at matching offsets.

    Args:
        user_idx (int): Index of the user in ``patch``.
        patch (numpy array): Patches, indexed by user first.
        word2id (dict): Special tokens; must provide '[CLS]' and '[MASK]'.
        n_patches (int): Number of patches per user.
        n_masks (int): Number of positions drawn in the first half.
        patch_size (int): Width of each patch.
        gen_raw (bool): If True, record the masked positions but leave
            ``input_ids`` untouched.

    Returns:
        sample (list): ``[input_ids, masked_tokens, masked_pos]`` where
        ``masked_tokens`` holds the original rows at ``masked_pos``.
    """

    input_ids = np.vstack((word2id['[CLS]'], patch[user_idx]))

    half = int(n_patches / 2)
    first_half_pos = np.random.choice(range(0, half), size=n_masks, replace=False)
    second_half_pos = first_half_pos + half
    # +1 shifts patch indices past the [CLS] row.
    masked_pos = np.hstack((first_half_pos, second_half_pos)) + 1

    masked_tokens = []
    for pos in masked_pos:
        masked_tokens.append(input_ids[pos].copy())
        if gen_raw:
            continue
        draw = np.random.rand()
        if draw < 0.1:
            # 10%: replace with uniform random values.
            input_ids[pos] = np.random.rand(patch_size)
        elif draw < 0.9:
            # 80%: replace with the [MASK] token; the remaining 10% stay as is.
            input_ids[pos] = word2id['[MASK]']

    return [input_ids, masked_tokens, masked_pos]
264
+
265
+
266
+ #%% Sampling and Data Selection
267
def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
    """
    Compute flat user indices on a regularly subsampled row/column grid.

    Args:
        dataset (dict): DeepMIMO dataset (not read by this function).
        sampling_div (list): Step sizes; element 0 is the step along the
            columns of a row, element 1 the step across rows.
        n_rows (int): Number of rows for user selection.
        users_per_row (int): Number of users per row.

    Returns:
        uniform_idxs (numpy array): Indices ``col + row * users_per_row`` of
        the selected users, ordered row by row.
    """
    col_step, row_step = sampling_div[0], sampling_div[1]
    cols = np.arange(users_per_row, step=col_step)
    rows = np.arange(n_rows, step=row_step)

    selected = []
    for row in rows:
        row_offset = row * users_per_row
        for col in cols:
            selected.append(col + row_offset)

    return np.array(selected)
285
+
286
def select_by_idx(dataset, idxs):
    """
    Selects a subset of the dataset based on the provided indices.

    For every entry of ``dataset`` the 'location' value is carried over
    unchanged and every array under 'user' is indexed with ``idxs``. Other
    top-level keys of an entry are not copied.

    Args:
        dataset (list): Per-base-station dicts with 'location' and 'user' keys.
        idxs (numpy array): Indices of users to select.

    Returns:
        dataset_t (list): Trimmed dataset based on selected indices.
    """
    dataset_t = []  # Trimmed dataset
    # Index-based access mirrors the original and works for any sequence.
    for bs_idx in range(len(dataset)):
        bs_data = dataset[bs_idx]
        trimmed = {}
        # The original repeated the two assignments below once per key of the
        # entry (re-indexing every user array each time); once is enough.
        # An entry without keys still yields an empty dict, as before.
        if bs_data:
            trimmed['location'] = bs_data['location']
            trimmed['user'] = {k: v[idxs] for k, v in bs_data['user'].items()}
        dataset_t.append(trimmed)

    return dataset_t
305
+
306
+ #%% Save and Load Utilities
307
def save_var(var, path):
    """
    Pickle a variable to disk.

    Args:
        var (object): Variable to be saved.
        path (str): Target path. Used as is when it ends with '.p';
            otherwise '.pickle' is appended.

    Returns:
        None
    """
    path_full = path
    if not path.endswith('.p'):
        path_full = path + '.pickle'

    with open(path_full, 'wb') as handle:
        pickle.dump(var, handle)
321
+
322
def load_var(path):
    """
    Read a pickled variable from disk.

    Args:
        path (str): Path of the file to load. Used as is when it ends with
            '.p'; otherwise '.pickle' is appended.

    Returns:
        var (object): Loaded variable.
    """
    path_full = path
    if not path.endswith('.p'):
        path_full = path + '.pickle'

    # NOTE(review): pickle.load executes code embedded in the file; only use
    # this on files this project wrote itself.
    with open(path_full, 'rb') as handle:
        return pickle.load(handle)
337
+
338
+ #%%
339
# NOTE(review): these statements sit at module level without an
# `if __name__ == "__main__":` guard, so they run whenever this module is
# imported, not only when it is executed as a script.
scenario_idxs = [0, 1, 2, 3, 4, 5]

# Generate patches for all six scenarios; gen_raw=True keeps the inputs
# unmasked, and nothing is written to disk (save_data=False).
preprocessed_data, max_len, element_length = gen_tokens(scenario_idxs,
                                                        patch_gen=True,
                                                        patch_size=16,
                                                        gen_deepMIMO_data=True,
                                                        gen_raw=True,
                                                        save_data=False)
347
+
348
+ #%%
lwm_model.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LWM (Large Wireless Model) Implementation and Loading
3
+
4
+ @author: salikha4
5
+
6
+ This module defines a Large Wireless Model (LWM) using PyTorch, including custom layers
7
+ for embedding, self-attention, and feed-forward networks. It also provides functionality
8
+ to load a pre-trained model from a checkpoint.
9
+
10
+ Dependencies:
11
+ - torch
12
+ - numpy
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+
20
+ # Constants
21
ELEMENT_LENGTH = 16      # width of one input element (patch)
MAX_LEN = 129            # size of the positional-embedding table
N_LAYERS = 12            # number of encoder layers
N_HEADS = 12             # attention heads per layer
D_MODEL = 64             # hidden width
D_FF = D_MODEL * 4       # feed-forward inner width (256)
# Integer division: 64 // 12 == 5, so the heads together span 60 of the
# 64 model dimensions; the attention projections are sized accordingly.
D_K = D_MODEL // N_HEADS
D_V = D_MODEL // N_HEADS
DROPOUT = 0.1
30
+
31
class LayerNormalization(nn.Module):
    """Normalize the last dimension, then apply a learned scale and shift.

    Uses ``Tensor.std`` with its default settings and adds ``eps`` to the
    standard deviation (not to the variance).
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))   # learned scale
        self.bias = nn.Parameter(torch.zeros(d_model))   # learned shift

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` over the last dim."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias
42
+
43
class Embedding(nn.Module):
    """Project input elements to ``d_model`` and add learned position embeddings."""

    def __init__(self, element_length, d_model, max_len):
        super().__init__()
        self.element_length = element_length
        self.d_model = d_model
        self.proj = nn.Linear(element_length, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.norm = LayerNormalization(d_model)

    def forward(self, x):
        """Embed ``x``; it is indexed as (batch, sequence, element) here.

        The sequence length must not exceed the ``max_len`` given at
        construction, since positions index the embedding table.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        positions = torch.arange(seq_len, dtype=torch.long, device=x.device)
        positions = positions.unsqueeze(0).expand(batch_size, seq_len)

        token_part = self.proj(x.float())
        position_part = self.pos_embed(positions)
        return self.norm(token_part + position_part)
59
+
60
class ScaledDotProductAttention(nn.Module):
    """Unmasked scaled dot-product attention."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V):
        """Return ``(context, attn)`` for the given queries, keys and values.

        Scores are ``Q @ K^T`` divided by the square root of the key
        dimension, read from ``K`` itself rather than from a module-level
        constant, so the layer is correct for any head size. ``attn`` is the
        softmax of the scores over the key axis and ``context`` is
        ``attn @ V``.
        """
        d_k = K.size(-1)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn
69
+
70
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection around the output."""

    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(D_MODEL, D_K * N_HEADS)
        self.W_K = nn.Linear(D_MODEL, D_K * N_HEADS)
        self.W_V = nn.Linear(D_MODEL, D_V * N_HEADS)
        self.linear = nn.Linear(N_HEADS * D_V, D_MODEL)
        # Not used in forward(); it still registers parameters, so removing
        # it would change the module's state_dict keys.
        self.norm = LayerNormalization(D_MODEL)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, Q, K, V):
        """Return ``(Q + dropout(projected context), attention weights)``."""
        batch_size = Q.size(0)

        def split_heads(projected, head_dim):
            # (batch, seq, heads * dim) -> (batch, heads, seq, dim)
            return projected.view(batch_size, -1, N_HEADS, head_dim).transpose(1, 2)

        q_s = split_heads(self.W_Q(Q), D_K)
        k_s = split_heads(self.W_K(K), D_K)
        v_s = split_heads(self.W_V(V), D_V)

        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s)

        # (batch, heads, seq, dim) -> (batch, seq, heads * dim)
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, N_HEADS * D_V)
        projected = self.linear(merged)
        return Q + self.dropout(projected), attn
90
+
91
class PoswiseFeedForwardNet(nn.Module):
    """Two-layer position-wise feed-forward block with a residual connection."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(D_MODEL, D_FF)
        self.fc2 = nn.Linear(D_FF, D_MODEL)
        self.dropout = nn.Dropout(DROPOUT)
        # Not used in forward(); kept because it registers parameters that
        # are part of the module's state_dict.
        self.norm = LayerNormalization(D_MODEL)

    def forward(self, x):
        """Return ``x + dropout(fc2(dropout(relu(fc1(x)))))``."""
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        output = self.fc2(hidden)
        return x + self.dropout(output)
102
+
103
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, layer norm, then feed-forward."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()
        self.norm = LayerNormalization(D_MODEL)

    def forward(self, enc_inputs):
        """Return ``(layer output, attention weights)`` for ``enc_inputs``."""
        # Self-attention: the same tensor serves as query, key and value.
        attended, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
        normalized = self.norm(attended)
        return self.pos_ffn(normalized), attn
115
+
116
class LWM(nn.Module):
    """Encoder stack with a head that reconstructs the masked input elements."""

    def __init__(self, element_length, d_model, max_len, n_layers):
        super().__init__()
        self.embedding = Embedding(element_length, d_model, max_len)

        # EncoderLayer() is built without arguments, so the width of the
        # encoder layers is not controlled by ``d_model`` here.
        encoder_stack = nn.ModuleList()
        for _ in range(n_layers):
            encoder_stack.append(EncoderLayer())
        self.layers = encoder_stack

        self.linear = nn.Linear(d_model, d_model)
        self.norm = LayerNormalization(d_model)

        # The decoder weight is initialised from the transposed input
        # projection weight and wrapped in its own Parameter.
        embed_weight = self.embedding.proj.weight
        d_model, n_dim = embed_weight.size()
        self.decoder = nn.Linear(d_model, n_dim, bias=False)
        self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
        self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

    def forward(self, input_ids, masked_pos):
        """Return ``(logits_lm, output)``.

        ``output`` is the result of the last encoder layer for the whole
        sequence; ``logits_lm`` is the decoder output at the sequence
        positions listed in ``masked_pos`` (indexed as (batch, n_masked)).
        """
        hidden = self.embedding(input_ids)
        for encoder in self.layers:
            hidden, _ = encoder(hidden)

        # Repeat each position across the feature axis so gather() picks the
        # full hidden vector at that position.
        gather_index = masked_pos.long().unsqueeze(2).expand(-1, -1, hidden.size(-1))
        h_masked = torch.gather(hidden, 1, gather_index)
        h_masked = self.norm(F.relu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        return logits_lm, hidden
142
+
143
def load_model(model, model_path, device=None):
    """
    Load checkpoint weights into an existing model and move it to a device.

    Args:
        model (torch.nn.Module): Instance whose parameters are overwritten;
            it is modified in place.
        model_path (str): Path to the checkpoint file (a state dict).
        device (torch.device, optional): Target device. When omitted, CUDA
            is used if available, otherwise the CPU.

    Returns:
        torch.nn.Module: The same ``model`` object, loaded and on ``device``.
    """
    if device is None:
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(backend)

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    weights = torch.load(model_path, map_location=device)
    model.load_state_dict(weights)
    model.to(device)
    return model
162
+
163
+ # Usage example
164
if __name__ == "__main__":
    # Pick the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'model_weights.pth'
    model_path = f'huggingFace/{model_name}'

    # Build the architecture from the module constants, then load weights.
    model = load_model(LWM(ELEMENT_LENGTH, D_MODEL, MAX_LEN, N_LAYERS), model_path, device)

    print(f"Model loaded successfully on {device}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 19:16:12 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ from transformers import PreTrainedModel, PretrainedConfig
9
+ from lwm_model import LWM
10
+
11
class WirelessConfig(PretrainedConfig):
    """Configuration holding the constructor arguments of the LWM network."""

    model_type = "lwm"

    def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12, **kwargs):
        # Remaining keyword arguments are handled by PretrainedConfig.
        super().__init__(**kwargs)
        self.element_length, self.d_model = element_length, d_model
        self.max_len, self.n_layers = max_len, n_layers
20
+
21
class WirelessChannelModel(PreTrainedModel):
    """PreTrainedModel wrapper that holds an LWM network as ``self.lwm``."""

    config_class = WirelessConfig

    def __init__(self, config):
        super().__init__(config)
        # The network is registered under the attribute name ``lwm``.
        self.lwm = LWM(
            config.element_length,
            config.d_model,
            config.max_len,
            config.n_layers,
        )

    def forward(self, input_ids, masked_pos):
        """Delegate to the wrapped LWM and return its result unchanged."""
        outputs = self.lwm(input_ids, masked_pos)
        return outputs
model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:838d9f35e9e1bfd46e4c3212d00fa069f5ea02a93c0f807d25399e755b2eebbc
3
+ size 2509918
save_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 19:16:37 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ import torch
9
+ from tokenizer import WirelessChannelTokenizer
10
+ from model import WirelessChannelModel, WirelessConfig
11
+
12
model_name = 'model_weights.pth'
model_path = f'huggingFace/{model_name}'

# Initialize model config
config = WirelessConfig()

# Initialize the model
model = WirelessChannelModel(config)

# Load pretrained weights. map_location='cpu' lets a checkpoint saved on a
# GPU machine load on a CPU-only one.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(model_path, map_location='cpu')

# WirelessChannelModel keeps the network under the attribute ``lwm``, so its
# own parameter names carry a 'lwm.' prefix. A checkpoint saved from a bare
# LWM has no such prefix and must be loaded into ``model.lwm``; a checkpoint
# saved from the wrapper is loaded into the wrapper itself.
if all(key.startswith('lwm.') for key in state_dict):
    model.load_state_dict(state_dict)
else:
    model.lwm.load_state_dict(state_dict)

# Initialize tokenizer (preprocessor)
tokenizer = WirelessChannelTokenizer(patch_size=16, max_len=129)

# Save the model and tokenizer for Hugging Face
model.save_pretrained("huggingFace/")
tokenizer.save_pretrained("huggingFace/")
tokenizer.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 19:15:23 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ from transformers import PreTrainedTokenizer
9
+ from input_preprocess import gen_tokens
10
+
11
class WirelessChannelTokenizer(PreTrainedTokenizer):
    """
    Hugging Face-style wrapper around the wireless channel preprocessing.

    Calling the tokenizer with scenario indices runs ``gen_tokens`` and
    returns the preprocessed samples.
    """

    def __init__(self, patch_size=16, max_len=129, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.max_len = max_len

    def preprocess_channels(self, scenario_idxs):
        """Generate data for the given scenarios and return the samples only.

        Patches are always generated (never loaded), inputs are left
        unmasked (``gen_raw=True``) and nothing is saved to disk.
        """
        options = dict(
            patch_gen=True,
            patch_size=self.patch_size,
            gen_deepMIMO_data=True,
            gen_raw=True,
            save_data=False,
        )
        # The sequence and element lengths are returned too but not needed.
        preprocessed_data, _sequence_length, _element_length = gen_tokens(scenario_idxs, **options)
        return preprocessed_data

    def __call__(self, scenario_idxs):
        """Alias for :meth:`preprocess_channels`."""
        return self.preprocess_channels(scenario_idxs)

    def save_pretrained(self, save_directory):
        """Save through the parent implementation; nothing extra is written."""
        super().save_pretrained(save_directory)
upload_to_huggingface.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Fri Sep 13 19:31:13 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ from huggingface_hub import HfApi
9
+
10
api = HfApi()

repo_id = "your_username/your_model_name"  # Your Hugging Face username and model name

# Repository visibility is set when the repository is created:
# upload_folder() has no `private` parameter, so passing it there fails.
# exist_ok=True makes the call a no-op if the repository already exists
# (an existing repository keeps its current visibility).
api.create_repo(
    repo_id=repo_id,
    private=True,  # Set to True for private, False for public
    exist_ok=True,
)

# Upload the folder containing both model and tokenizer
api.upload_folder(
    folder_path="path_to_save_model",  # The folder containing saved model/tokenizer
    repo_id=repo_id,
)