aisyahhrazak committed
Commit af53a9f
1 Parent(s): bb4e038

Upload MM_LLMs

Files changed:
- config.json (+6 -2)
- model.safetensors (+2 -2)
- modeling.py (+385 -0)
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "multimodal-tinyllama-whisper-small-siglip/checkpoint-
+  "_name_or_path": "multimodal-tinyllama-whisper-small-siglip/checkpoint-450",
   "architectures": [
     "MM_LLMs"
   ],
@@ -206,6 +206,10 @@
   },
   "audio_conv_kernel": 240,
   "audio_conv_stride": 220,
+  "auto_map": {
+    "AutoConfig": "modeling.MM_LLMs_Config",
+    "AutoModel": "modeling.MM_LLMs"
+  },
   "hidden_size": 2048,
   "image_config": {
     "_name_or_path": "google/siglip-base-patch16-224",
@@ -503,6 +507,6 @@
   },
   "model_type": "mm_llms",
   "n_frames": 6,
-  "torch_dtype": "
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.37.1"
 }
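The added auto_map entries point the Auto classes at the custom code in modeling.py, so the checkpoint can be loaded through AutoConfig/AutoModel once remote code is trusted. A minimal loading sketch, assuming the config and modeling.py sit in the same Hub repository (the repo id below is a placeholder, not part of this commit):

import torch
from transformers import AutoConfig, AutoModel

repo_id = "<namespace>/<model-repo>"  # placeholder: substitute the actual Hub repo id

# trust_remote_code=True lets transformers import modeling.MM_LLMs_Config and
# modeling.MM_LLMs as declared in the new auto_map block.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)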
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e2d41892dc7281cb5a6667ac3d6f59cd55dc4bdb2b62832347fe305ae63a40af
+size 3509162622
modeling.py
ADDED
@@ -0,0 +1,385 @@
from collections import Counter
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.nn import CrossEntropyLoss
import copy
import math
from typing import List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
from transformers import CONFIG_MAPPING
from transformers.modeling_outputs import BaseModelOutput
from transformers import GenerationConfig
from transformers import CLIPConfig, CLIPProcessor, CLIPModel, AutoModel
from transformers import WhisperConfig, WhisperPreTrainedModel, WhisperModel
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig


def most_frequent_element(tensor):
    flattened_list = tensor.flatten().tolist()
    counter = Counter(flattened_list)
    most_common_element = counter.most_common(1)[0][1]

    return most_common_element


class MM_LLMs_Config(PretrainedConfig):
    model_type = 'mm_llms'
    is_composition = True

    def __init__(self, attention_heads=8, image_conv_kernel=48, image_conv_stride=36,
                 audio_conv_kernel=240, audio_conv_stride=220,
                 image_config=None, audio_config=None, llm_config=None, **kwargs):

        self.image_config = image_config
        self.audio_config = audio_config
        self.llm_config = llm_config
        self.attention_heads = attention_heads
        self.image_conv_kernel = image_conv_kernel
        self.image_conv_stride = image_conv_stride
        self.audio_conv_kernel = audio_conv_kernel
        self.audio_conv_stride = audio_conv_stride

        if isinstance(self.image_config, dict):
            image_config["model_type"] = (
                image_config["model_type"] if "model_type" in image_config else "clip"
            )
            self.image_config = CONFIG_MAPPING[image_config["model_type"]](**image_config)
        if isinstance(self.audio_config, dict):
            audio_config["model_type"] = (
                audio_config["model_type"] if "model_type" in audio_config else "whisper"
            )
            self.audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
        if isinstance(self.llm_config, dict):
            llm_config["model_type"] = llm_config["model_type"] if "model_type" in llm_config else "llama"
            self.llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)

        self.hidden_size = max(
            self.llm_config.hidden_size,
            self.image_config.vision_config.hidden_size,
            self.audio_config.d_model,
        )

        super().__init__(**kwargs)


class MM_LLMs(PreTrainedModel):
    config_class = MM_LLMs_Config
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.image_encoder = AutoModel.from_config(config.image_config)

        self.audio_encoder = AutoModel.from_config(config.audio_config)

        self.llm = AutoModelForCausalLM.from_config(config.llm_config)

        attn_dropout = 0.1
        is_add_bias_kv = True
        is_add_zero_attn = True
        self.temporal_self_attention = nn.MultiheadAttention(
            config.image_config.text_config.hidden_size,
            config.attention_heads,
            dropout=attn_dropout,
            add_bias_kv=is_add_bias_kv,
            add_zero_attn=is_add_zero_attn)

        self.audio_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size,
                                                           config.attention_heads * 2,
                                                           dropout=attn_dropout,
                                                           add_bias_kv=is_add_bias_kv,
                                                           add_zero_attn=is_add_zero_attn)

        self.image_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size,
                                                           config.attention_heads * 2,
                                                           dropout=attn_dropout,
                                                           add_bias_kv=is_add_bias_kv,
                                                           add_zero_attn=is_add_zero_attn)

        self.transform_audio_to_hidden = nn.Linear(config.audio_config.d_model,
                                                   config.llm_config.hidden_size)
        self.transform_image_to_hidden = nn.Linear(config.image_config.text_config.hidden_size,
                                                   config.llm_config.hidden_size)

        self.project_image = nn.Conv1d(
            config.image_config.text_config.hidden_size,
            config.image_config.text_config.hidden_size,
            kernel_size=config.image_conv_kernel,
            stride=config.image_conv_stride)
        self.project_audio = nn.Conv1d(
            config.audio_config.d_model,
            config.audio_config.d_model,
            kernel_size=config.audio_conv_kernel,
            stride=config.audio_conv_stride)

        self.visual_projection = nn.Linear(
            self.image_encoder.vision_model.config.hidden_size,
            self.config.image_config.text_config.hidden_size,
            bias=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.layer_norm = nn.LayerNorm(config.image_config.text_config.hidden_size)
        self.softmax = nn.Softmax(dim=-1)

        self.sigmoid = nn.Sigmoid()

        self.loss_fct = CrossEntropyLoss()

        self.init_weights()

    def forward(self,
                input_ids: torch.LongTensor = None,
                image_index: torch.LongTensor = None,
                audio_index: torch.LongTensor = None,
                image_starts: torch.int = None,
                image_ends: torch.int = None,
                audio_starts: torch.int = None,
                audio_ends: torch.int = None,
                images: torch.FloatTensor = None,
                audios: torch.FloatTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                use_cache: Optional[bool] = None,
                return_dict: Optional[bool] = None):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        images = images.type(self.image_encoder.dtype) if images is not None else None
        audios = audios.type(self.audio_encoder.dtype) if audios is not None else None

        model_inputs = self.prepare_inputs_for_generation(
            input_ids=input_ids,
            image_index=image_index,
            audio_index=audio_index,
            image_starts=image_starts,
            image_ends=image_ends,
            audio_starts=audio_starts,
            audio_ends=audio_ends,
            images=images,
            audios=audios,
            attention_mask=attention_mask,
            labels=labels)

        outputs = self.llm(
            inputs_embeds=model_inputs['inputs_embeds'],
            attention_mask=model_inputs['attention_mask'],
            labels=model_inputs['labels'],
            return_dict=return_dict)

        return outputs

    def prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=None,
            inputs_embeds=None,
            images=None,
            audios=None,
            audio_starts=None,
            audio_ends=None,
            image_starts=None,
            image_ends=None,
            attention_mask=None,
            labels=None,
            audio_index=None,
            image_index=None,
            **kwargs):

        image_features = self.encode_image(
            images) if images is not None else None
        audio_features = self.encode_audio(
            audios) if audios is not None else None
        embed_tokens = self.llm.model.embed_tokens
        text_embeddings = embed_tokens(input_ids)

        token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(
            text_embeddings.size(0), 1, 1).transpose(0, 1)

        ingore_num = 0

        if audio_features is not None:

            audio_starts = embed_tokens(audio_starts).unsqueeze(1)
            audio_ends = embed_tokens(audio_ends).unsqueeze(1)

            audio_features = self.project_audio(
                audio_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()

            audio_features = self.transform_audio_to_hidden(audio_features)

            max_count = most_frequent_element(audio_index)

            seq_img = audio_features.shape[1]
            dim = token_embeddings.shape[2]

            new_audio = torch.zeros(
                (token_embeddings.shape[1], seq_img * max_count, dim),
                device=token_embeddings.device,
                dtype=token_embeddings.dtype)
            current_dim = 0
            for no, index in enumerate(audio_index):
                if no > 0 and audio_index[no - 1] == index:
                    current_dim += 1
                else:
                    current_dim = 0
                new_audio[index, current_dim * seq_img: (current_dim + 1) * seq_img] = audio_features[no]
            last_index = audio_index[0]

            audio_features = self.audio_align_attention(
                new_audio.transpose(0, 1),
                token_embeddings,
                token_embeddings)[0].transpose(0, 1).contiguous()

            # audio_features = add_positional_encoding(audio_features)

            audio_inputs = torch.cat(
                [torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)

            text_embeddings = torch.cat(
                [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],
                dim=1)

            ingore_num += (audio_inputs.size(1))

        if image_features is not None:

            image_starts = embed_tokens(image_starts).unsqueeze(1)
            image_ends = embed_tokens(image_ends).unsqueeze(1)

            image_features = self.project_image(
                image_features.transpose(1, 2).contiguous()).transpose(1, 2).contiguous()

            image_features = self.transform_image_to_hidden(image_features)

            max_count = most_frequent_element(image_index)

            seq_img = image_features.shape[1]
            dim = token_embeddings.shape[2]

            new_img = torch.zeros(
                (token_embeddings.shape[1], seq_img * max_count, dim),
                device=token_embeddings.device,
                dtype=token_embeddings.dtype)

            current_dim = 0
            for no, index in enumerate(image_index):
                if no > 0 and image_index[no - 1] == index:
                    current_dim += 1
                else:
                    current_dim = 0
                new_img[index, current_dim * seq_img: (current_dim + 1) * seq_img] = image_features[no]
            last_index = image_index[0]

            image_features = self.image_align_attention(
                new_img.transpose(0, 1),
                token_embeddings,
                token_embeddings)[0].transpose(0, 1).contiguous()

            # image_features = add_positional_encoding(image_features)

            image_inputs = torch.cat(
                [torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)

            text_embeddings = torch.cat(
                [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1),
                 text_embeddings[:, 1:, :]], dim=1)

            ingore_num += (image_inputs.size(1))

        if attention_mask is not None:
            attentionmask = torch.tensor([1] * ingore_num * text_embeddings.size(0),
                                         device=text_embeddings.device).view(text_embeddings.size(0), -1)
            attentionmask = torch.cat([attentionmask, attention_mask], dim=1)
        else:
            attention_mask = None

        if labels is not None:
            labels_ = torch.tensor([-100] * ingore_num * text_embeddings.size(0),
                                   device=text_embeddings.device).view(text_embeddings.size(0), -1)
            labels = torch.cat([labels_, labels], dim=1)
        else:
            labels = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "inputs_embeds": text_embeddings,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attentionmask,
                "labels": labels,
            }
        )
        return model_inputs

    def encode_audio(self, audios):
        audio_features = self.audio_encoder.encoder(audios)
        return audio_features[0]

    def encode_image(self, images):
        image_features = self.visual_projection(
            self.image_encoder.vision_model(images)[0])[:, 1:, :]
        return image_features


def create_positional_encoding(L, h):
    # Create a tensor to store the position encoding
    position_encoding = torch.zeros(L, h)

    # Fill the position encoding tensor
    for pos in range(L):
        for i in range(0, h, 2):
            div_term = torch.exp(torch.tensor(-(math.log(10000.0) / h * (2 * i))))
            position_encoding[pos, i] = torch.sin(pos * div_term)
            position_encoding[pos, i + 1] = torch.cos(pos * div_term)

    return position_encoding


def add_positional_encoding(tensor):
    N, L, h = tensor.size()  # batch size, sequence length, and feature dimension

    # Create position embedding tensor
    position_embedding = create_positional_encoding(L, h).to(tensor.device).to(tensor.dtype)

    # Expand position embedding to match input tensor dimensions
    position_embedding = position_embedding.unsqueeze(0).expand(N, -1, -1)

    # Add position embedding to the input tensor
    return tensor + position_embedding
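For orientation, a rough forward-pass sketch assuming model was loaded as in the earlier example. Every shape and token id below is an illustrative assumption (SigLIP 224x224 pixel values, Whisper log-mel features, placeholder boundary-token ids), not something defined by this upload:

import torch

input_ids = torch.randint(0, 32000, (1, 16))  # dummy prompt tokens; vocab size assumed

batch = {
    "input_ids": input_ids,
    "attention_mask": torch.ones(1, 16, dtype=torch.long),
    "labels": input_ids.clone(),
    "images": torch.randn(1, 3, 224, 224),     # assumed SigLIP pixel values
    "audios": torch.randn(1, 80, 3000),        # assumed Whisper log-mel features
    "image_index": torch.tensor([0]),          # image i belongs to batch sample image_index[i]
    "audio_index": torch.tensor([0]),
    "image_starts": torch.tensor([1]),         # placeholder ids for the modality boundary tokens
    "image_ends": torch.tensor([2]),
    "audio_starts": torch.tensor([3]),
    "audio_ends": torch.tensor([4]),
}

with torch.no_grad():
    outputs = model(**batch)  # audio/image features are spliced into the LLM input embeddings
print(outputs.logits.shape)   # (1, 16 + inserted audio/image positions, vocab_size)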