Sadjad Alikhani committed
Commit 43db74c · verified · 1 Parent(s): c4cd4f6

Upload input_preprocess.py

Files changed (1)
  1. input_preprocess.py +310 -0
input_preprocess.py ADDED
@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 16:13:29 2024

This script generates preprocessed data from wireless communication scenarios,
including token generation, patch creation, and data sampling for machine
learning models.

@author: salikha4
"""

import numpy as np
import os
from tqdm import tqdm
import time
import pickle
import DeepMIMOv3

#%% Scenarios List
def scenarios_list():
    """Returns an array of available scenario names."""
    return np.array([
        'city_18_denver', 'city_15_indianapolis', 'city_19_oklahoma',
        'city_12_fortworth', 'city_11_santaclara', 'city_7_sandiego'
    ])

#%% Token Generation
def tokenizer(selected_scenario_names=None, manual_data=None, gen_raw=True):
    """
    Generates tokens by preparing and preprocessing the dataset.

    Args:
        selected_scenario_names (list): Names of the scenarios to generate
            data for. Ignored when manual_data is provided.
        manual_data (array): Channel data supplied directly instead of being
            generated from the scenarios. Defaults to None.
        gen_raw (bool): Whether to keep masked tokens unperturbed (raw).
            Defaults to True.

    Returns:
        preprocessed_data (list): One [input_ids, masked_tokens, masked_pos]
            sample per user.
    """

    if manual_data is not None:
        patches = patch_maker(np.expand_dims(np.array(manual_data), axis=1))
    else:
        # Generate DeepMIMO data for each scenario and keep only valid users
        deepmimo_data = [DeepMIMO_data_gen(scenario_name) for scenario_name in selected_scenario_names]
        n_scenarios = len(selected_scenario_names)

        cleaned_deepmimo_data = [deepmimo_data_cleaning(deepmimo_data[scenario_idx])
                                 for scenario_idx in range(n_scenarios)]

        patches = [patch_maker(cleaned_deepmimo_data[scenario_idx]) for scenario_idx in range(n_scenarios)]
        patches = np.vstack(patches)

    # Define dimensions
    patch_size = patches.shape[2]
    n_patches = patches.shape[1]
    n_masks_half = int(0.15 * n_patches / 2)  # 15% of patches masked, split across real/imag halves
    sequence_length = n_patches + 1
    element_length = patch_size

    # Special tokens are constant-valued patches
    word2id = {'[CLS]': 0.2 * np.ones((patch_size)), '[MASK]': 0.1 * np.ones((patch_size))}

    # Generate preprocessed channels
    preprocessed_data = []
    for user_idx in tqdm(range(len(patches)), desc="Processing items"):
        sample = make_sample(user_idx, patches, word2id, n_patches, n_masks_half, patch_size, gen_raw=gen_raw)
        preprocessed_data.append(sample)

    return preprocessed_data

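# Output structure (sketch): each sample is a list
# [input_ids, masked_tokens, masked_pos], where input_ids has shape
# (n_patches + 1, patch_size) with the [CLS] patch in row 0, masked_tokens
# holds the original values of the masked patches, and masked_pos gives
# their row indices in input_ids.
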
#%%
def deepmimo_data_cleaning(deepmimo_data):
    """Keeps only the channels of users with a valid link (LoS != -1)."""
    idxs = np.where(deepmimo_data['user']['LoS'] != -1)[0]
    cleaned_deepmimo_data = deepmimo_data['user']['channel'][idxs]
    return cleaned_deepmimo_data

#%% Patch Creation
def patch_maker(original_ch, patch_size=16, norm_factor=1e6):
    """
    Creates fixed-size patches from the channel data.

    Args:
        original_ch (numpy array): Complex channel matrices, one per user.
        patch_size (int): Size of each patch. Defaults to 16.
        norm_factor (float): Normalization factor for the channels. Defaults to 1e6.

    Returns:
        patch (numpy array): Generated patches of shape (n_users, n_patches, patch_size).
    """
    # Flatten each user's channel and stack real and imaginary parts
    flat_channels = original_ch.reshape((original_ch.shape[0], -1)).astype(np.csingle)
    flat_channels_complex = np.hstack((flat_channels.real, flat_channels.imag)) * norm_factor

    # Create patches
    n_patches = flat_channels_complex.shape[1] // patch_size
    patch = np.zeros((len(flat_channels_complex), n_patches, patch_size))
    for idx in range(n_patches):
        patch[:, idx, :] = flat_channels_complex[:, idx * patch_size:(idx + 1) * patch_size]

    return patch

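# Shape sketch (hypothetical numbers): a (100, 32, 1, 32) complex channel array
# flattens to 1024 complex values per user, i.e. 2048 reals after stacking the
# real and imaginary parts, giving 2048 // 16 = 128 patches of length 16:
#
#     ch = np.random.randn(100, 32, 1, 32) + 1j * np.random.randn(100, 32, 1, 32)
#     patches = patch_maker(ch)  # -> shape (100, 128, 16)
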
#%% Data Generation for Scenario Areas
def DeepMIMO_data_gen(scenario):
    """
    Generates the DeepMIMO dataset for a given scenario.

    Args:
        scenario (str): Scenario name.

    Returns:
        data (dict): Generated data for the active base station.
    """

    parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)

    deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)
    # A [1, 1] sampling step keeps every user on the grid
    uniform_idxs = uniform_sampling(deepMIMO_dataset, [1, 1], len(parameters['user_rows']),
                                    users_per_row=row_column_users[scenario]['n_per_row'])
    data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]

    return data

#%%%
def get_parameters(scenario):

    n_ant_bs = 32
    n_ant_ue = 1
    n_subcarriers = 32
    scs = 30e3

    row_column_users = {
        'city_18_denver': {'n_rows': 85, 'n_per_row': 82},
        'city_15_indianapolis': {'n_rows': 80, 'n_per_row': 79},
        'city_19_oklahoma': {'n_rows': 82, 'n_per_row': 75},
        'city_12_fortworth': {'n_rows': 86, 'n_per_row': 72},
        'city_11_santaclara': {'n_rows': 47, 'n_per_row': 114},
        'city_7_sandiego': {'n_rows': 71, 'n_per_row': 83}
    }

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario

    if scenario == 'O1_3p5':
        parameters['active_BS'] = np.array([4])
    elif scenario in ['city_18_denver', 'city_15_indianapolis']:
        parameters['active_BS'] = np.array([3])
    else:
        parameters['active_BS'] = np.array([1])

    if scenario == 'Boston5G_3p5':
        # This branch expects 'n_rows' to be a (start, end) pair for the scenario
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                            row_column_users[scenario]['n_rows'][1])
    else:
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])

    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)

    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9  # GHz
    parameters['num_paths'] = 20

    return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers

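# Example (illustrative): inspecting the generated configuration for one city.
#
#     params, rcu, n_bs, n_ue, n_sc = get_parameters('city_18_denver')
#     params['user_rows']            # -> array([0, 1, ..., 84])
#     params['OFDM']['bandwidth']    # -> 0.00096 (GHz, i.e. 30 kHz x 32 subcarriers)
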
#%% Sample Generation
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_raw=False):
    """
    Generates a sample for each user, including masking and tokenizing.

    Args:
        user_idx (int): Index of the user.
        patch (numpy array): Patches data.
        word2id (dict): Dictionary for special tokens.
        n_patches (int): Number of patches.
        n_masks (int): Number of masks per half (real and imaginary).
        patch_size (int): Size of each patch.
        gen_raw (bool): Whether to keep masked tokens unperturbed.

    Returns:
        sample (list): [input_ids, masked_tokens, masked_pos] for the user.
    """

    tokens = patch[user_idx]
    input_ids = np.vstack((word2id['[CLS]'], tokens))

    # Mask matching positions in the real and imaginary halves of the sequence
    real_tokens_size = int(n_patches / 2)
    masks_pos_real = np.random.choice(range(0, real_tokens_size), size=n_masks, replace=False)
    masks_pos_imag = masks_pos_real + real_tokens_size
    masked_pos = np.hstack((masks_pos_real, masks_pos_imag)) + 1  # +1 skips the [CLS] row

    masked_tokens = []
    for pos in masked_pos:
        original_masked_tokens = input_ids[pos].copy()
        masked_tokens.append(original_masked_tokens)
        if not gen_raw:
            # BERT-style corruption: 10% random patch, 80% [MASK], 10% unchanged
            rnd_num = np.random.rand()
            if rnd_num < 0.1:
                input_ids[pos] = np.random.rand(patch_size)
            elif rnd_num < 0.9:
                input_ids[pos] = word2id['[MASK]']

    return [input_ids, masked_tokens, masked_pos]

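# Masking sketch (hypothetical numbers): with n_patches = 128 and n_masks = 9,
# nine positions are drawn from the first 64 patches (the real half), the same
# offsets shifted by 64 mask the matching imaginary patches, and the final +1
# accounts for the [CLS] row prepended to input_ids, so 18 of the 129 rows are
# selected for masking.
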
#%% Sampling and Data Selection
def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
    """
    Performs uniform sampling on the dataset.

    Args:
        dataset (dict): DeepMIMO dataset.
        sampling_div (list): Step sizes along the [x, y] dimensions.
        n_rows (int): Number of rows for user selection.
        users_per_row (int): Number of users per row.

    Returns:
        uniform_idxs (numpy array): Indices of the selected samples.
    """
    cols = np.arange(users_per_row, step=sampling_div[0])
    rows = np.arange(n_rows, step=sampling_div[1])
    uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])

    return uniform_idxs

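# Example (illustrative): with sampling_div = [1, 1] every user on the grid is
# kept, so a 2-row, 3-users-per-row grid yields all six indices (the dataset
# argument is unused by this function):
#
#     uniform_sampling(None, [1, 1], n_rows=2, users_per_row=3)
#     # -> array([0, 1, 2, 3, 4, 5])
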
def select_by_idx(dataset, idxs):
    """
    Selects a subset of the dataset based on the provided indices.

    Args:
        dataset (dict): Dataset to trim.
        idxs (numpy array): Indices of users to select.

    Returns:
        dataset_t (list): Trimmed dataset based on selected indices.
    """
    dataset_t = []  # Trimmed dataset
    for bs_idx in range(len(dataset)):
        dataset_t.append({})
        dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
        dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}

    return dataset_t

#%% Save and Load Utilities
def save_var(var, path):
    """
    Saves a variable to a pickle file.

    Args:
        var (object): Variable to be saved.
        path (str): Path to save the file.

    Returns:
        None
    """
    path_full = path if path.endswith('.p') else (path + '.pickle')
    with open(path_full, 'wb') as handle:
        pickle.dump(var, handle)

def load_var(path):
    """
    Loads a variable from a pickle file.

    Args:
        path (str): Path of the file to load.

    Returns:
        var (object): Loaded variable.
    """
    path_full = path if path.endswith('.p') else (path + '.pickle')
    with open(path_full, 'rb') as handle:
        var = pickle.load(handle)

    return var

#%%
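#%% Example usage
# Illustrative sketch: assumes the scenario ray-tracing files are available
# under ./scenarios so DeepMIMOv3 can load them.
if __name__ == '__main__':
    names = scenarios_list()[:1]
    preprocessed = tokenizer(selected_scenario_names=list(names), gen_raw=True)
    save_var(preprocessed, './preprocessed_data')
    restored = load_var('./preprocessed_data')
    print(len(restored), 'samples')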