Sadjad Alikhani committed
Commit 737ab45
Parent: a377f74

Upload 3 files

Files changed (3)
  1. inference.py +0 -88
  2. input_preprocess.py +89 -3
  3. lwm_model.py +153 -151
inference.py CHANGED
@@ -23,7 +23,6 @@ import numpy as np
 #from lwm_model import LWM, load_model
 import warnings
 warnings.filterwarnings('ignore')
-from input_preprocess import *

 # Device configuration
 device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
@@ -68,90 +67,3 @@ def create_raw_dataset(data, device):
     input_data = torch.tensor(input_ids, device=device)[:, 1:]
     return input_data.float()

-
-def label_gen(task, data, scenario, n_beams=64):
-
-    idxs = np.where(data['user']['LoS'] != -1)[0]
-
-    if task == 'LoS/NLoS Classification':
-        label = data['user']['LoS'][idxs]
-    elif task == 'Beam Prediction':
-        parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
-        n_users = len(data['user']['channel'])
-        n_subbands = 1
-        fov = 120
-
-        # Setup Beamformers
-        beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)
-
-        F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
-                                    phi=azi*np.pi/180,
-                                    kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
-                       for azi in beam_angles])
-
-        full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
-        for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
-            if data['user']['LoS'][ue_idx] == -1:
-                full_dbm[:, :, ue_idx] = np.nan
-            else:
-                chs = F1 @ data['user']['channel'][ue_idx]
-                full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
-                full_dbm[:, :, ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)
-
-        best_beams = np.argmax(np.mean(full_dbm, axis=1), axis=0)
-        best_beams = best_beams.astype(float)
-        best_beams[np.isnan(full_dbm[0, 0, :])] = np.nan
-        max_bf_pwr = np.max(np.mean(full_dbm, axis=1), axis=0)
-
-        label = best_beams[idxs]
-
-    return label.astype(int)
-
-
-def steering_vec(array, phi=0, theta=0, kd=np.pi):
-    # phi = azimuth
-    # theta = elevation
-    idxs = DeepMIMOv3.ant_indices(array)
-    resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
-    return resp / np.linalg.norm(resp)
-
-
-def evaluate(model, dataloader):
-
-    model.eval()
-    running_loss = 0.0
-    outputs = []
-    criterionMCM = nn.MSELoss()
-
-    with torch.no_grad():
-        for batch in dataloader:
-            input_ids = batch[0]
-            masked_tokens = batch[1]
-            masked_pos = batch[2]
-
-            logits_lm, output = model(input_ids, masked_pos)
-
-            output_batch_preproc = output
-            outputs.append(output_batch_preproc)
-
-            loss_lm = criterionMCM(logits_lm, masked_tokens)
-            loss = loss_lm / torch.var(masked_tokens)
-            running_loss += loss.item()
-
-    average_loss = running_loss / len(dataloader)
-    output_total = torch.cat(outputs, dim=0)
-
-    return average_loss, output_total
-
-
-def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
-    labels = []
-    for scenario_idx in scenario_idxs:
-        scenario_name = scenarios_list()[scenario_idx]
-        # data = DeepMIMO_data_gen(scenario_name)
-        data = deepmimo_data[scenario_idx]
-        labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
-
-    preprocessed_chs = [preprocessed_chs[i] + [labels[i]] for i in range(len(preprocessed_chs))]
-
-    return preprocessed_chs
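These four helpers are not deleted outright; they move nearly verbatim (the unused max_bf_pwr assignment becomes a comment) into input_preprocess.py below. A minimal caller-side sketch of the resulting import change, assuming both modules sit on the import path:

# Hypothetical caller update after this commit: the helpers now live in
# input_preprocess.py rather than inference.py, so imports move with them.
from input_preprocess import label_gen, steering_vec, evaluate, label_prepend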
 
input_preprocess.py CHANGED
@@ -59,8 +59,8 @@ def tokenizer(selected_scenario_names=None, manual_data=None, gen_raw=True):
     patch_size = patches.shape[2]
     n_patches = patches.shape[1]
     n_masks_half = int(0.15 * n_patches / 2)
-    sequence_length = n_patches + 1
-    element_length = patch_size
+    # sequence_length = n_patches + 1
+    # element_length = patch_size

     word2id = {'[CLS]': 0.2 * np.ones((patch_size)), '[MASK]': 0.1 * np.ones((patch_size))}

@@ -307,4 +307,90 @@ def load_var(path):

     return var

-#%%
+#%%
+
+def label_gen(task, data, scenario, n_beams=64):
+
+    idxs = np.where(data['user']['LoS'] != -1)[0]
+
+    if task == 'LoS/NLoS Classification':
+        label = data['user']['LoS'][idxs]
+    elif task == 'Beam Prediction':
+        parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers = get_parameters(scenario)
+        n_users = len(data['user']['channel'])
+        n_subbands = 1
+        fov = 120
+
+        # Setup Beamformers
+        beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)
+
+        F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
+                                    phi=azi*np.pi/180,
+                                    kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
+                       for azi in beam_angles])
+
+        full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
+        for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
+            if data['user']['LoS'][ue_idx] == -1:
+                full_dbm[:, :, ue_idx] = np.nan
+            else:
+                chs = F1 @ data['user']['channel'][ue_idx]
+                full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
+                full_dbm[:, :, ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)
+
+        best_beams = np.argmax(np.mean(full_dbm, axis=1), axis=0)
+        best_beams = best_beams.astype(float)
+        best_beams[np.isnan(full_dbm[0, 0, :])] = np.nan
+        # max_bf_pwr = np.max(np.mean(full_dbm, axis=1), axis=0)
+
+        label = best_beams[idxs]
+
+    return label.astype(int)
+
+def steering_vec(array, phi=0, theta=0, kd=np.pi):
+    # phi = azimuth
+    # theta = elevation
+    idxs = DeepMIMOv3.ant_indices(array)
+    resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
+    return resp / np.linalg.norm(resp)
+
+
+def evaluate(model, dataloader):
+
+    model.eval()
+    running_loss = 0.0
+    outputs = []
+    criterionMCM = nn.MSELoss()
+
+    with torch.no_grad():
+        for batch in dataloader:
+            input_ids = batch[0]
+            masked_tokens = batch[1]
+            masked_pos = batch[2]
+
+            logits_lm, output = model(input_ids, masked_pos)
+
+            output_batch_preproc = output
+            outputs.append(output_batch_preproc)
+
+            loss_lm = criterionMCM(logits_lm, masked_tokens)
+            loss = loss_lm / torch.var(masked_tokens)
+            running_loss += loss.item()
+
+    average_loss = running_loss / len(dataloader)
+    output_total = torch.cat(outputs, dim=0)
+
+    return average_loss, output_total
+
+
+def label_prepend(deepmimo_data, preprocessed_chs, task, scenario_idxs, n_beams=64):
+    labels = []
+    for scenario_idx in scenario_idxs:
+        scenario_name = scenarios_list()[scenario_idx]
+        # data = DeepMIMO_data_gen(scenario_name)
+        data = deepmimo_data[scenario_idx]
+        labels.extend(label_gen(task, data, scenario_name, n_beams=n_beams))
+
+    preprocessed_chs = [preprocessed_chs[i] + [labels[i]] for i in range(len(preprocessed_chs))]
+
+    return preprocessed_chs
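With evaluate() now living here, a quick smoke test can exercise the masked-channel-modeling loss end to end. A minimal sketch, assuming random stand-in tensors shaped to the LWM defaults (129-token sequences of 16-dim patches, 10 masked positions per sample); the batch sizes and shapes below are illustrative, not part of the commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from input_preprocess import evaluate
from lwm_model import LWM

# Illustrative stand-ins: [batch, seq_len, element_length] inputs, ground-truth
# patches for the masked positions, and the masked position indices themselves.
input_ids = torch.randn(32, 129, 16)
masked_tokens = torch.randn(32, 10, 16)
masked_pos = torch.randint(1, 129, (32, 10))

loader = DataLoader(TensorDataset(input_ids, masked_tokens, masked_pos), batch_size=8)
model = LWM()  # randomly initialized; swap in LWM.from_pretrained(...) for real use
avg_loss, embeddings = evaluate(model, loader)
print(avg_loss, embeddings.shape)  # scalar loss and a (32, 129, 64) embedding tensor

Note the returned loss is the MSE over the masked patches divided by their variance, i.e. a normalized MSE rather than a raw one.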
lwm_model.py CHANGED
@@ -1,151 +1,153 @@
 # -*- coding: utf-8 -*-
 """
 Created on Sun Sep 15 19:55:23 2024

 @author: salikha4
 """

-import os
+# import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
+# from inference import *
+# from input_preprocess import *


 ELEMENT_LENGTH = 16
 D_MODEL = 64
 MAX_LEN = 129
 N_LAYERS = 12
 N_HEADS = 12
 D_FF = D_MODEL * 4
 D_K = D_MODEL // N_HEADS
 D_V = D_MODEL // N_HEADS
 DROPOUT = 0.1

 class LayerNormalization(nn.Module):
     def __init__(self, d_model: int, eps: float = 1e-6) -> None:
         super().__init__()
         self.eps = eps
         self.alpha = nn.Parameter(torch.ones(d_model))
         self.bias = nn.Parameter(torch.zeros(d_model))

     def forward(self, x):
         mean = x.mean(dim=-1, keepdim=True)
         std = x.std(dim=-1, keepdim=True)
         return self.alpha * (x - mean) / (std + self.eps) + self.bias

 class Embedding(nn.Module):
     def __init__(self, element_length, d_model, max_len):
         super().__init__()
         self.element_length = element_length
         self.d_model = d_model
         self.proj = nn.Linear(element_length, d_model)
         self.pos_embed = nn.Embedding(max_len, d_model)
         self.norm = LayerNormalization(d_model)

     def forward(self, x):
         seq_len = x.size(1)
         pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
         pos = pos.unsqueeze(0).expand_as(x[:, :, 0])
         tok_emb = self.proj(x.float())
         embedding = tok_emb + self.pos_embed(pos)
         return self.norm(embedding)

 class ScaledDotProductAttention(nn.Module):
     def __init__(self):
         super().__init__()

     def forward(self, Q, K, V):
         scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(D_K)
         attn = F.softmax(scores, dim=-1)
         context = torch.matmul(attn, V)
         return context, attn

 class MultiHeadAttention(nn.Module):
     def __init__(self):
         super().__init__()
         self.W_Q = nn.Linear(D_MODEL, D_K * N_HEADS)
         self.W_K = nn.Linear(D_MODEL, D_K * N_HEADS)
         self.W_V = nn.Linear(D_MODEL, D_V * N_HEADS)
         self.linear = nn.Linear(N_HEADS * D_V, D_MODEL)
         self.norm = LayerNormalization(D_MODEL)
         self.dropout = nn.Dropout(DROPOUT)

     def forward(self, Q, K, V):
         residual, batch_size = Q, Q.size(0)
         q_s = self.W_Q(Q).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
         k_s = self.W_K(K).view(batch_size, -1, N_HEADS, D_K).transpose(1, 2)
         v_s = self.W_V(V).view(batch_size, -1, N_HEADS, D_V).transpose(1, 2)

         context, attn = ScaledDotProductAttention()(q_s, k_s, v_s)
         output = context.transpose(1, 2).contiguous().view(batch_size, -1, N_HEADS * D_V)
         output = self.linear(output)
         return residual + self.dropout(output), attn

 class PoswiseFeedForwardNet(nn.Module):
     def __init__(self):
         super().__init__()
         self.fc1 = nn.Linear(D_MODEL, D_FF)
         self.fc2 = nn.Linear(D_FF, D_MODEL)
         self.dropout = nn.Dropout(DROPOUT)
         self.norm = LayerNormalization(D_MODEL)

     def forward(self, x):
         output = self.fc2(self.dropout(F.relu(self.fc1(x))))
         return x + self.dropout(output)

 class EncoderLayer(nn.Module):
     def __init__(self):
         super().__init__()
         self.enc_self_attn = MultiHeadAttention()
         self.pos_ffn = PoswiseFeedForwardNet()
         self.norm = LayerNormalization(D_MODEL)

     def forward(self, enc_inputs):
         attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
         attn_outputs = self.norm(attn_outputs)
         enc_outputs = self.pos_ffn(attn_outputs)
         return enc_outputs, attn

 class LWM(torch.nn.Module):
     def __init__(self, element_length=16, d_model=64, max_len=129, n_layers=12):
         super().__init__()
         self.embedding = Embedding(element_length, d_model, max_len)
         self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
         self.linear = nn.Linear(d_model, d_model)
         self.norm = LayerNormalization(d_model)

         embed_weight = self.embedding.proj.weight
         d_model, n_dim = embed_weight.size()
         self.decoder = nn.Linear(d_model, n_dim, bias=False)
         self.decoder.weight = nn.Parameter(embed_weight.transpose(0, 1))
         self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

     @classmethod
     def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda', use_auth_token=None):
         # Define model
         model = cls().to(device)

         # Download model weights using Hugging Face Hub
         # ckpt_path = hf_hub_download(repo_id="sadjadalikhani/LWM", filename=ckpt_name, use_auth_token=use_auth_token)
         ckpt_path = ckpt_name

         # Load the model weights
         model.load_state_dict(torch.load(ckpt_path, map_location=device))
         print(f"Model loaded successfully from {ckpt_path} to {device}")

         return model

     def forward(self, input_ids, masked_pos):
         # Forward pass
         output = self.embedding(input_ids)
         for layer in self.layers:
             output, _ = layer(output)

         masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
         h_masked = torch.gather(output, 1, masked_pos)
         h_masked = self.norm(F.relu(self.linear(h_masked)))
         logits_lm = self.decoder(h_masked) + self.decoder_bias

         return logits_lm, output
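Finally, a minimal load-and-run sketch against this model definition. Assumptions: the checkpoint file (the from_pretrained default 'model_weights.pth') already exists locally, since the hf_hub_download path is commented out in this commit, and the tensor shapes below are illustrative:

import torch
from lwm_model import LWM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Loads weights from a local file; the Hugging Face Hub download is commented
# out in this version, so 'model_weights.pth' must already be present.
model = LWM.from_pretrained(ckpt_name='model_weights.pth', device=device)

channels = torch.randn(4, 129, 16, device=device)       # [CLS] patch + 128 channel patches
masked_pos = torch.randint(1, 129, (4, 10), device=device)
with torch.no_grad():
    logits_lm, embeddings = model(channels, masked_pos)
print(logits_lm.shape, embeddings.shape)  # (4, 10, 16) and (4, 129, 64)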