sudy-super committed on
Commit 935e343
1 parent: 61db68f

Upload 2 files

Files changed (2)
  1. modeling_co_encoder.py +592 -0
  2. tokenization_co_encoder.py +194 -0
modeling_co_encoder.py ADDED
@@ -0,0 +1,592 @@
1
+ # coding=utf-8
2
+ """PyTorch CoEncoder model."""
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+
13
+ from transformers import PreTrainedModel
14
+ from transformers.activations import ACT2FN
15
+ from transformers.image_processing_utils import select_best_resolution
16
+ from transformers.modeling_outputs import ModelOutput
17
+ from transformers.utils import (
18
+ add_start_docstrings,
19
+ add_start_docstrings_to_model_forward,
20
+ logging,
21
+ replace_return_docstrings,
22
+ is_flash_attn_2_available,
23
+ )
24
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
25
+ from .configuration_co_encoder import CoEncoderConfig
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ _CONFIG_FOR_DOC = "CoEncoderConfig"
31
+
32
+
33
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
34
+ """
35
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
36
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
37
+ """
38
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
39
+ if n_rep == 1:
40
+ return hidden_states
41
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
42
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
43
+
44
+
45
+ @dataclass
46
+ class CoEncoderCausalLMOutputWithPast(ModelOutput):
47
+ """
48
+ Base class for CoEncoder causal language model (or autoregressive) outputs.
49
+
50
+ Args:
51
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
52
+ Language modeling loss (for next-token prediction).
53
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
54
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
55
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
56
+ Tuple of `tuple(torch.FloatTensor)` of length `config.text_config.num_hidden_layers`, with each tuple having 2 tensors of shape
57
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
58
+
59
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
60
+ `past_key_values` input) to speed up sequential decoding.
61
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
62
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
63
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
64
+
65
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
66
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
67
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
68
+ sequence_length)`.
69
+
70
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
71
+ heads.
72
+ context_hidden_states (`torch.FloatTensor`, *optional*):
73
+ A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
74
+ Hidden states produced by the context encoder and projected by the connector into the language model's embedding space.
75
+ """
76
+
77
+ loss: Optional[torch.FloatTensor] = None
78
+ logits: torch.FloatTensor = None
79
+ past_key_values: Optional[List[torch.FloatTensor]] = None
80
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
81
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
82
+ context_hidden_states: Optional[torch.FloatTensor] = None
83
+
84
+
85
+ class CoEncoderDynamicAttention(nn.Module):
86
+ """
87
+ Attention mechanism adapted from Mistral's architecture for dynamic output sizing. This attention layer produces
88
+ per-token scalar scores that are used to determine the pooling size dynamically.
89
+ """
90
+
91
+ def __init__(self, config: CoEncoderConfig):
92
+ super().__init__()
93
+
94
+ self.hidden_size = config.context_config.hidden_size
95
+ self.num_heads = config.context_config.num_attention_heads
96
+ self.head_dim = getattr(config.context_config, "head_dim", self.hidden_size // self.num_heads)
97
+ self.num_key_value_heads = config.context_config.num_key_value_heads
98
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
99
+
100
+ # Query, Key, Value, and Output Projections
101
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
102
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
103
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
104
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, 1, bias=False)
105
+
106
+ def forward(
107
+ self,
108
+ hidden_states,
109
+ output_attentions=False,
110
+ ):
111
+ # Get input dimensions
112
+ bsz, seq_len, hidden_size = hidden_states.size()
113
+
114
+ # Query, Key, Value projections
115
+ query_states = self.q_proj(hidden_states)
116
+ key_states = self.k_proj(hidden_states)
117
+ value_states = self.v_proj(hidden_states)
118
+
119
+ # Reshape and transpose to [batch_size, num_heads, seq_len, head_dim]
120
+ query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
121
+ key_states = key_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
122
+ value_states = value_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
123
+
124
+ # Repeat key and value states for multi-head attention
125
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
126
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
127
+
128
+ # Compute attention scores
129
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
130
+
131
+ # Apply softmax to get attention probabilities
132
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
133
+
134
+ # Apply attention to values
135
+ attn_output = torch.matmul(attn_weights, value_states)
136
+
137
+ # Reshape attention output
138
+ attn_output = attn_output.transpose(1, 2).contiguous()
139
+ attn_output = attn_output.reshape(bsz, seq_len, -1)
140
+
141
+ # Project to output dimension
142
+ attn_output = self.o_proj(attn_output)
143
+
144
+ if not output_attentions:
145
+ attn_weights = None
146
+
147
+ return attn_output, attn_weights
148
+
149
+
150
+ class CoEncoderDynamicWeightedAvgPool1d(nn.Module):
151
+ """
152
+ A module that dynamically determines the output size based on input
153
+ and performs weighted average pooling with separate attention mechanisms
154
+ for output size estimation and weighted pooling.
155
+ """
156
+ def __init__(self, config, output_size_min=32, output_size_max=131072):
157
+ super().__init__()
158
+ # Attention mechanism for estimating output size
159
+ self.size_estimation_attention = CoEncoderDynamicAttention(config)
160
+ # Attention mechanism for weighted pooling
161
+ self.weighted_pooling_attention = CoEncoderDynamicAttention(config)
162
+ self.output_size_min = output_size_min
163
+ self.output_size_max = (
164
+ config.context_config.max_position_embeddings if config.context_config.max_position_embeddings is not None else output_size_max
165
+ )
166
+ self.scale_param = nn.Parameter(torch.tensor(0.01))
167
+
168
+ def forward(self, hidden_states):
169
+ """
170
+ Args:
171
+ hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size)
172
+
173
+ Returns:
174
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
175
+ - pooled_output: Padded tensor of compressed sequences (batch_size, max_pooled_len, hidden_size)
176
+ - attention_mask: Binary mask indicating valid tokens (batch_size, max_pooled_len)
177
+ - dynamic_output_sizes: Dynamic output sizes for each batch (batch_size,)
178
+ """
179
+ batch_size, seq_len, hidden_size = hidden_states.size()
180
+ device = hidden_states.device
181
+
182
+ # Estimate output size using attention mechanism
183
+ # attn_output_size: (batch_size, seq_len, 1)
184
+ attn_output_size, _ = self.size_estimation_attention(hidden_states)
185
+
186
+ # Calculate dynamic output sizes for each batch item
187
+ # (batch_size, seq_len, 1) -> (batch_size, 1)
188
+ batch_attn_means = torch.sigmoid(attn_output_size).mean(dim=1)
189
+ scaled_batch_means = batch_attn_means * self.scale_param
190
+
191
+ # Calculate dynamic output sizes (batch_size,)
192
+ dynamic_output_sizes = (
193
+ scaled_batch_means * (self.output_size_max - self.output_size_min)
194
+ + self.output_size_min
195
+ ).int().squeeze(-1)
196
+
197
+ # Get the maximum output size across the batch
198
+ max_pooled_len = dynamic_output_sizes.max().item()
199
+
200
+ # Compute attention weights for weighted pooling
201
+ # attn_output_weights: (batch_size, seq_len, 1)
202
+ attn_output_weights, _ = self.weighted_pooling_attention(hidden_states)
203
+ # Normalize with sigmoid function for use as weights
204
+ # attention_weights: (batch_size, seq_len)
205
+ attention_weights = torch.sigmoid(attn_output_weights).squeeze(-1)
206
+
207
+ # Initialize output tensors
208
+ # pooled_output: (batch_size, max_pooled_len, hidden_size)
209
+ pooled_output = torch.zeros(batch_size, max_pooled_len, hidden_size, device=device)
210
+ # attention_mask: (batch_size, max_pooled_len)
211
+ attention_mask = torch.zeros(batch_size, max_pooled_len, dtype=torch.bool, device=device)
212
+
213
+ for batch_idx in range(batch_size):
214
+ output_size = dynamic_output_sizes[batch_idx].item()
215
+ item_input = hidden_states[batch_idx] # Shape: (seq_len, hidden_size)
216
+ item_weights = attention_weights[batch_idx] # Shape: (seq_len)
217
+
218
+ # Perform weighted pooling
219
+ pooled_values = []
220
+ # Split the sequence evenly
221
+ intervals = torch.linspace(0, seq_len, steps=output_size + 1).long()
222
+ for i in range(output_size):
223
+ start = intervals[i].item()
224
+ end = intervals[i + 1].item()
225
+ chunk_input = item_input[start:end] # Shape: (chunk_size, hidden_size)
226
+ chunk_weights = item_weights[start:end] # Shape: (chunk_size)
227
+ if chunk_weights.sum() == 0:
228
+ # If the sum of weights is zero, add a zero vector
229
+ pooled_value = torch.zeros(hidden_size, device=device)
230
+ else:
231
+ # Calculate weighted average
232
+ weighted_input = chunk_input * chunk_weights.unsqueeze(-1) # Shape: (chunk_size, hidden_size)
233
+ pooled_value = weighted_input.sum(dim=0) / (chunk_weights.sum() + 1e-8) # Shape: (hidden_size)
234
+ pooled_values.append(pooled_value)
235
+ # Convert the result to a tensor
236
+ pooled_values = torch.stack(pooled_values) # Shape: (output_size, hidden_size)
237
+ # Store the result
238
+ pooled_output[batch_idx, -output_size:] = pooled_values
239
+ attention_mask[batch_idx, -output_size:] = True
240
+
241
+ return pooled_output, attention_mask, dynamic_output_sizes
242
+
243
+
244
+ class CoEncoderContextLanguageConnector(nn.Module):
245
+ def __init__(self, config: CoEncoderConfig):
246
+ super().__init__()
247
+
248
+ self.dynamic_pooling = CoEncoderDynamicWeightedAvgPool1d(config)
249
+
250
+ self.linear_1 = nn.Linear(config.context_config.hidden_size, config.text_config.hidden_size, bias=True)
251
+ self.act = ACT2FN[config.projector_hidden_act]
252
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
253
+
254
+ def forward(self, context_features):
255
+ # context_features: [batch_size, seq_len, hidden_size]
256
+ # Apply dynamic adaptive average pooling with attention
257
+ pooled_output, attention_mask, dynamic_output_sizes = self.dynamic_pooling(context_features)
258
+ # pooled_output: [batch_size, max_pooled_len, hidden_size]
259
+
260
+ hidden_states = self.linear_1(pooled_output)
261
+ hidden_states = self.act(hidden_states)
262
+ hidden_states = self.linear_2(hidden_states)
263
+
264
+ return hidden_states, attention_mask
265
+
266
+
267
+ class CoEncoderContextTower(nn.Module):
268
+ def __init__(self, config: CoEncoderConfig):
269
+ super().__init__()
270
+
271
+ self.tower = AutoModelForCausalLM.from_config(
272
+ config.context_config,
273
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager"
274
+ )
275
+ self.select_layer = config.context_feature_layer
276
+
277
+ def feature_select(self, llm_outputs):
278
+ hidden_states = llm_outputs.hidden_states
279
+ return hidden_states[self.select_layer]
280
+
281
+ def forward(self, inputs):
282
+ outputs = self.tower(inputs, output_hidden_states=True)
283
+ features = self.feature_select(outputs)
284
+ return features
285
+
286
+
287
+ class CoEncoderPreTrainedModel(PreTrainedModel):
288
+ config_class = CoEncoderConfig
289
+ base_model_prefix = "model"
290
+ supports_gradient_checkpointing = True
291
+ _no_split_modules = ["CoEncoderContextLanguageConnector", "CoEncoderContextTower"]
292
+ _skip_keys_device_placement = ["past_key_values"]
293
+ _supports_flash_attn_2 = True
294
+ _supports_sdpa = True
295
+ _supports_cache_class = True
296
+ _supports_quantized_cache = True
297
+ _supports_static_cache = True
298
+
299
+ def _init_weights(self, module):
300
+ std = (
301
+ self.config.initializer_range
302
+ if hasattr(self.config, "initializer_range")
303
+ else self.config.text_config.initializer_range
304
+ )
305
+ if isinstance(module, nn.Linear):
306
+ module.weight.data.normal_(mean=0.0, std=std)
307
+ if module.bias is not None:
308
+ module.bias.data.zero_()
309
+ elif isinstance(module, nn.Embedding):
310
+ module.weight.data.normal_(mean=0.0, std=std)
311
+ if module.padding_idx is not None:
312
+ module.weight.data[module.padding_idx].zero_()
313
+
314
+
315
+ class CoEncoderForConditionalGeneration(CoEncoderPreTrainedModel):
316
+ def __init__(self, config: CoEncoderConfig):
317
+ super().__init__(config)
318
+ self.context_tower = CoEncoderContextTower(config)
319
+ self.connector = CoEncoderContextLanguageConnector(config)
320
+
321
+ self.language_model = AutoModelForCausalLM.from_config(
322
+ config.text_config,
323
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager"
324
+ )
325
+
326
+ self.vocab_size = config.text_config.vocab_size
327
+ self.ignore_index = config.ignore_index if hasattr(config, 'ignore_index') else -100
328
+ self.begin_of_context_token_id = config.begin_of_context_token_id
329
+ self.end_of_context_token_id = config.end_of_context_token_id
330
+
331
+ self.post_init()
332
+
333
+ def get_input_embeddings(self):
334
+ return self.language_model.get_input_embeddings()
335
+
336
+ def set_input_embeddings(self, value):
337
+ self.language_model.set_input_embeddings(value)
338
+
339
+ def get_output_embeddings(self):
340
+ return self.language_model.get_output_embeddings()
341
+
342
+ def set_output_embeddings(self, new_embeddings):
343
+ self.language_model.set_output_embeddings(new_embeddings)
344
+
345
+ def set_decoder(self, decoder):
346
+ self.language_model.set_decoder(decoder)
347
+
348
+ def get_decoder(self):
349
+ return self.language_model.get_decoder()
350
+
351
+ def tie_weights(self):
352
+ return self.language_model.tie_weights()
353
+
354
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
355
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
356
+ # update vocab size
357
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
358
+ self.vocab_size = model_embeds.num_embeddings
359
+ return model_embeds
360
+
361
+ def _merge_context_features(
362
+ self,
363
+ context_features,
364
+ inputs_embeds,
365
+ input_ids,
366
+ attention_mask,
367
+ position_ids=None,
368
+ labels=None,
369
+ context_attention_mask=None,
370
+ ):
371
+ batch_size, seq_length, embed_dim = inputs_embeds.shape
372
+ context_seq_len = context_features.size(1)
373
+
374
+ # Create embeddings for begin and end of context tokens
375
+ begin_context_embed = self.get_input_embeddings()(torch.tensor(self.begin_of_context_token_id, device=context_features.device))
376
+ end_context_embed = self.get_input_embeddings()(torch.tensor(self.end_of_context_token_id, device=context_features.device))
377
+
378
+ # Determine the actual lengths of context sequences (excluding padding)
379
+ if context_attention_mask is not None:
380
+ # context_attention_mask: [batch_size, context_seq_len, 1]
381
+ context_attention_mask = context_attention_mask.squeeze(-1) # [batch_size, context_seq_len]
382
+ # Sum over sequence length to get actual lengths
383
+ context_lengths = context_attention_mask.sum(dim=1).long() # [batch_size]
384
+ else:
385
+ # If no context_attention_mask is provided, assume full length
386
+ context_lengths = torch.full((batch_size,), context_seq_len, device=context_features.device, dtype=torch.long)
387
+ context_attention_mask = torch.ones(batch_size, context_seq_len, device=context_features.device, dtype=torch.long)
388
+
389
+ # Rearrange context features to include padding at the beginning
390
+ # Identify the maximum context length (excluding padding)
391
+ max_context_length = context_lengths.max().item()
392
+ # Calculate the amount of padding needed for each sample
393
+ padding_lengths = context_seq_len - context_lengths # [batch_size]
394
+
395
+ # Create new context_features with padding at the beginning
396
+ new_context_features = []
397
+ for i in range(batch_size):
398
+ padding_len = padding_lengths[i].item()
399
+ # Create padding embeddings (zeros)
400
+ padding_embed = torch.zeros(padding_len, embed_dim, device=context_features.device)
401
+ # Get actual context features (excluding padding)
402
+ actual_context = context_features[i, padding_len:context_seq_len]
403
+ # Concatenate padding, begin token, actual context, end token
404
+ sample_context = torch.cat([
405
+ padding_embed,
406
+ begin_context_embed.unsqueeze(0),
407
+ actual_context,
408
+ end_context_embed.unsqueeze(0)
409
+ ], dim=0) # [context_seq_len + 2, embed_dim]
410
+ new_context_features.append(sample_context)
411
+ # Stack to create [batch_size, new_context_seq_len, embed_dim]
412
+ context_features = torch.stack(new_context_features, dim=0)
413
+ new_context_seq_len = context_features.size(1)
414
+
415
+ # Update context_attention_mask accordingly
416
+ new_context_attention_mask = []
417
+ for i in range(batch_size):
418
+ padding_len = padding_lengths[i].item()
419
+ # Create padding mask (zeros)
420
+ padding_mask = torch.zeros(padding_len, device=context_features.device, dtype=attention_mask.dtype)
421
+ # Begin and end token masks
422
+ begin_attention = torch.ones(1, device=context_features.device, dtype=attention_mask.dtype)
423
+ end_attention = torch.ones(1, device=context_features.device, dtype=attention_mask.dtype)
424
+ # Actual context attention mask (excluding padding)
425
+ actual_mask = context_attention_mask[i, padding_len:context_seq_len]
426
+ # Concatenate masks
427
+ sample_mask = torch.cat([
428
+ padding_mask,
429
+ begin_attention,
430
+ actual_mask,
431
+ end_attention
432
+ ], dim=0) # [context_seq_len + 2]
433
+ new_context_attention_mask.append(sample_mask)
434
+ # Stack to create [batch_size, new_context_seq_len]
435
+ context_attention_mask = torch.stack(new_context_attention_mask, dim=0)
436
+
437
+ # Concatenate context features with input embeddings
438
+ new_inputs_embeds = torch.cat([context_features, inputs_embeds], dim=1) # [batch_size, total_seq_len, embed_dim]
439
+
440
+ # Concatenate attention masks
441
+ new_attention_mask = torch.cat([context_attention_mask, attention_mask], dim=1)
442
+
443
+ # Create new position_ids
444
+ total_seq_len = new_inputs_embeds.size(1)
445
+ new_position_ids = torch.arange(total_seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
446
+
447
+ # Update labels if provided
448
+ if labels is not None:
449
+ # Create ignore labels for context (including padding and special tokens)
450
+ context_labels = torch.full((batch_size, new_context_seq_len), self.ignore_index, device=labels.device, dtype=labels.dtype)
451
+ new_labels = torch.cat([context_labels, labels], dim=1)
452
+ else:
453
+ new_labels = None
454
+
455
+ return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
456
+
457
+
458
+ @replace_return_docstrings(output_type=CoEncoderCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
459
+ def forward(
460
+ self,
461
+ input_ids: torch.LongTensor = None,
462
+ context_input_ids: torch.LongTensor = None,
463
+ context_attention_mask: Optional[torch.Tensor] = None,
464
+ attention_mask: Optional[torch.Tensor] = None,
465
+ position_ids: Optional[torch.LongTensor] = None,
466
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
467
+ inputs_embeds: Optional[torch.FloatTensor] = None,
468
+ labels: Optional[torch.LongTensor] = None,
469
+ use_cache: Optional[bool] = None,
470
+ output_attentions: Optional[bool] = None,
471
+ output_hidden_states: Optional[bool] = None,
472
+ return_dict: Optional[bool] = None,
473
+ ) -> Union[Tuple, CoEncoderCausalLMOutputWithPast]:
474
+ """
475
+ Perform a forward pass through the CoEncoder model, optionally conditioning on context input.
476
+
477
+ Args:
478
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
479
+ Token IDs of the input sequence.
480
+ context_input_ids (`torch.LongTensor` of shape `(batch_size, context_sequence_length)`, *optional*):
481
+ Token IDs of the context input sequence.
482
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
483
+ Mask to avoid performing attention on padding token indices.
484
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
485
+ Indices of positions of each input sequence token.
486
+ past_key_values (`List[torch.FloatTensor]`, *optional*):
487
+ Pre-computed hidden-states (key and value tensors) that can be used to speed up sequential decoding.
488
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
489
+ Optionally, instead of passing `input_ids`, you can pass an embedded representation directly.
490
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
491
+ Labels for computing the language modeling loss.
492
+ use_cache (`bool`, *optional*):
493
+ If `True`, past key values will be used to speed up decoding.
494
+ output_attentions (`bool`, *optional*):
495
+ If `True`, return the attention tensors for each layer.
496
+ output_hidden_states (`bool`, *optional*):
497
+ If `True`, return the hidden states of all layers.
498
+ return_dict (`bool`, *optional*):
499
+ If `True`, return a `CoEncoderCausalLMOutputWithPast` instead of a plain tuple.
500
+
501
+ Returns:
502
+ `Union[Tuple, CoEncoderCausalLMOutputWithPast]`: A tuple containing various model outputs or a `CoEncoderCausalLMOutputWithPast` instance.
503
+ """
504
+
505
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
506
+ output_hidden_states = (
507
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
508
+ )
509
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
510
+
511
+ # Process context input through ContextTower
512
+ if context_input_ids is not None:
513
+ context_features = self.context_tower(context_input_ids)
514
+ context_features, context_attention_mask = self.connector(context_features)
515
+ else:
516
+ context_features = None
517
+ context_attention_mask = None
518
+
519
+ if inputs_embeds is None:
520
+ inputs_embeds = self.get_input_embeddings()(input_ids)
521
+
522
+ if context_features is not None:
523
+ inputs_embeds, attention_mask, position_ids, labels = self._merge_context_features(
524
+ context_features,
525
+ inputs_embeds,
526
+ input_ids,
527
+ attention_mask,
528
+ position_ids,
529
+ labels,
530
+ context_attention_mask=context_attention_mask,
531
+ )
532
+
533
+ outputs = self.language_model(
534
+ attention_mask=attention_mask,
535
+ position_ids=position_ids,
536
+ past_key_values=past_key_values,
537
+ inputs_embeds=inputs_embeds,
538
+ use_cache=use_cache,
539
+ output_attentions=output_attentions,
540
+ output_hidden_states=output_hidden_states,
541
+ return_dict=return_dict,
542
+ )
543
+
544
+ logits = outputs[0]
545
+
546
+ loss = None
547
+ if labels is not None:
548
+ shift_logits = logits[..., :-1, :].contiguous()
549
+ shift_labels = labels[..., 1:].contiguous()
550
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
551
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
552
+
553
+ if not return_dict:
554
+ output = (logits,) + outputs[1:]
555
+ return (loss,) + output if loss is not None else output
556
+
557
+ return CoEncoderCausalLMOutputWithPast(
558
+ loss=loss,
559
+ logits=logits,
560
+ past_key_values=outputs.past_key_values,
561
+ hidden_states=outputs.hidden_states,
562
+ attentions=outputs.attentions,
563
+ context_hidden_states=context_features,
564
+ )
565
+
566
+ def prepare_inputs_for_generation(
567
+ self,
568
+ input_ids,
569
+ past_key_values=None,
570
+ attention_mask=None,
571
+ inputs_embeds=None,
572
+ context_input_ids=None,
573
+ **kwargs
574
+ ):
575
+ if past_key_values:
576
+ input_ids = input_ids[:, -1:]
577
+
578
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
579
+ if inputs_embeds is not None and past_key_values is None:
580
+ model_inputs = {"inputs_embeds": inputs_embeds}
581
+ else:
582
+ model_inputs = {"input_ids": input_ids}
583
+
584
+ model_inputs.update(
585
+ {
586
+ "past_key_values": past_key_values,
587
+ "use_cache": kwargs.get("use_cache"),
588
+ "attention_mask": attention_mask,
589
+ "context_input_ids": context_input_ids if past_key_values is None else None,
590
+ }
591
+ )
592
+ return model_inputs
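Note (not part of the commit): a minimal sketch of how the uploaded modeling code might be exercised. It assumes the two uploaded files plus the referenced configuration_co_encoder.py live together in an importable package (called co_encoder below, a hypothetical name) and that CoEncoderConfig accepts the constructor arguments shown; the tiny Llama sub-configs and token ids are placeholders, not the checkpoints this repository actually targets.

import torch
from transformers import LlamaConfig

# Hypothetical package layout; the relative import in modeling_co_encoder.py
# requires these modules to live inside a package.
from co_encoder.configuration_co_encoder import CoEncoderConfig
from co_encoder.modeling_co_encoder import CoEncoderForConditionalGeneration

# Tiny placeholder sub-configs so the sketch runs quickly on CPU.
context_cfg = LlamaConfig(
    vocab_size=1000, hidden_size=256, intermediate_size=512,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    max_position_embeddings=2048,
)
text_cfg = LlamaConfig(
    vocab_size=1000, hidden_size=256, intermediate_size=512,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    max_position_embeddings=2048,
)

config = CoEncoderConfig(
    context_config=context_cfg,
    text_config=text_cfg,
    projector_hidden_act="gelu",     # assumed config fields used by the connector and tower above
    context_feature_layer=-2,
    begin_of_context_token_id=1,     # placeholder special-token ids
    end_of_context_token_id=2,
)
model = CoEncoderForConditionalGeneration(config).eval()

# context_input_ids is encoded by the context tower, compressed by the dynamic
# weighted pooling connector, and prepended to the main sequence in forward().
with torch.no_grad():
    out = model(
        input_ids=torch.tensor([[11, 12, 13]]),
        attention_mask=torch.ones(1, 3, dtype=torch.long),
        context_input_ids=torch.tensor([[21, 22, 23, 24]]),
    )
print(out.logits.shape)  # (1, merged_sequence_length, vocab_size)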
tokenization_co_encoder.py ADDED
@@ -0,0 +1,194 @@
1
+ # coding=utf-8
2
+ """Tokenization classes for CoEncoder"""
3
+
4
+ from typing import List, Union, Optional
5
+ from transformers import AutoTokenizer
6
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
7
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
8
+ from transformers.utils import logging
9
+ from transformers.feature_extraction_utils import BatchFeature
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class CoEncoderDualTokenizerKwargs(ProcessingKwargs, total=False):
15
+ _defaults = {
16
+ "context_kwargs": {
17
+ "padding": False,
18
+ },
19
+ "text_kwargs": {
20
+ "padding": False,
21
+ },
22
+ }
23
+
24
+
25
+ class CoEncoderDualTokenizer(ProcessorMixin):
26
+ r"""
27
+ CoEncoderDualTokenizer is the tokenizer for the CoEncoder model. It processes both context and main text.
28
+
29
+ Args:
30
+ context_tokenizer ([`PreTrainedTokenizer`]):
31
+ The tokenizer for context.
32
+ text_tokenizer ([`PreTrainedTokenizer`]):
33
+ The tokenizer for main text.
34
+ """
35
+
36
+ attributes = ["context_tokenizer", "text_tokenizer"]
37
+ context_tokenizer_class = "AutoTokenizer"
38
+ text_tokenizer_class = "AutoTokenizer"
39
+
40
+ def __init__(self, context_tokenizer=None, text_tokenizer=None):
41
+ super().__init__(context_tokenizer, text_tokenizer)
42
+
43
+ @classmethod
44
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
45
+ """
46
+ Load both context and text tokenizers from a given repository.
47
+
48
+ Args:
49
+ pretrained_model_name_or_path (str): The name or path of the Hugging Face repository.
50
+
51
+ Returns:
52
+ CoEncoderDualTokenizer: An instance of the tokenizer class.
53
+ """
54
+ # Load context_tokenizer from 'context_tokenizer' directory
55
+ context_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="context_tokenizer", **kwargs)
56
+
57
+ # Load text_tokenizer from 'text_tokenizer' directory
58
+ text_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="text_tokenizer", **kwargs)
59
+
60
+ # Return a new instance of CoEncoderDualTokenizer with both tokenizers loaded
61
+ return cls(context_tokenizer=context_tokenizer, text_tokenizer=text_tokenizer)
62
+
63
+ def __call__(
64
+ self,
65
+ context: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
66
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
67
+ return_tensors: Optional[str] = None,
68
+ **kwargs: Unpack[CoEncoderDualTokenizerKwargs]
69
+ ) -> BatchFeature:
70
+ """
71
+ Main method to prepare inputs for the CoEncoder model.
72
+
73
+ Args:
74
+ context: Context text input.
75
+ text: Main text input.
76
+ return_tensors: Type of tensors to return.
77
+
78
+ Returns:
79
+ BatchFeature: A BatchFeature object containing model inputs.
80
+ """
81
+ if context is None and text is None:
82
+ raise ValueError("You must provide either context or text.")
83
+
84
+ features = {}
85
+
86
+ if context is not None:
87
+ context_features = self.context_tokenizer(
88
+ context,
89
+ return_tensors=return_tensors,
90
+ **kwargs.get("context_kwargs", {})
91
+ )
92
+ features.update({f"context_{k}": v for k, v in context_features.items()})
93
+
94
+ if text is not None:
95
+ text_features = self.text_tokenizer(
96
+ text,
97
+ return_tensors=return_tensors,
98
+ **kwargs.get("text_kwargs", {})
99
+ )
100
+ features.update({k: v for k, v in text_features.items()})
101
+
102
+ return BatchFeature(data=features, tensor_type=return_tensors)
103
+
104
+ def pad(
105
+ self,
106
+ encoded_inputs,
107
+ padding=True,
108
+ max_length=None,
109
+ return_tensors=None,
110
+ **kwargs
111
+ ):
112
+ """
113
+ Pads the encoded inputs to the maximum length in the batch.
114
+
115
+ Args:
116
+ encoded_inputs: A list of dictionaries containing context and text features.
117
+ padding: Whether to pad sequences.
118
+ max_length: Maximum length for padding.
119
+ return_tensors: Type of tensors to return.
120
+
121
+ Returns:
122
+ A dictionary with padded sequences.
123
+ """
124
+ # Separate context and text features
125
+ context_features = []
126
+ text_features = []
127
+
128
+ for feature in encoded_inputs:
129
+ # Extract context features
130
+ context_feature = {
131
+ k[len("context_"):]: v
132
+ for k, v in feature.items()
133
+ if k.startswith("context_")
134
+ }
135
+ if context_feature:
136
+ context_features.append(context_feature)
137
+ # Extract text features
138
+ text_feature = {
139
+ k: v
140
+ for k, v in feature.items()
141
+ if not k.startswith("context_")
142
+ }
143
+ if text_feature:
144
+ text_features.append(text_feature)
145
+
146
+ # Pad context features
147
+ if context_features:
148
+ context_padded = self.context_tokenizer.pad(
149
+ context_features,
150
+ padding=padding,
151
+ max_length=max_length,
152
+ return_tensors=return_tensors,
153
+ **kwargs.get("context_kwargs", {})
154
+ )
155
+ context_padded = {f"context_{k}": v for k, v in context_padded.items()}
156
+ else:
157
+ context_padded = {}
158
+
159
+ # Pad text features
160
+ if text_features:
161
+ text_padded = self.text_tokenizer.pad(
162
+ text_features,
163
+ padding=padding,
164
+ max_length=max_length,
165
+ return_tensors=return_tensors,
166
+ **kwargs.get("text_kwargs", {})
167
+ )
168
+ text_padded = {k: v for k, v in text_padded.items()}
169
+ else:
170
+ text_padded = {}
171
+
172
+ # Combine padded features
173
+ padded_features = {**context_padded, **text_padded}
174
+
175
+ return BatchFeature(data=padded_features, tensor_type=return_tensors)
176
+
177
+ def batch_decode(self, *args, **kwargs):
178
+ """
179
+ Calls the batch_decode method of the text_tokenizer.
180
+ """
181
+ return self.text_tokenizer.batch_decode(*args, **kwargs)
182
+
183
+ def decode(self, *args, **kwargs):
184
+ """
185
+ Calls the decode method of the text_tokenizer.
186
+ """
187
+ return self.text_tokenizer.decode(*args, **kwargs)
188
+
189
+ @property
190
+ def model_input_names(self):
191
+ """
192
+ Returns the model input names.
193
+ """
194
+ return list(dict.fromkeys(self.context_tokenizer.model_input_names + self.text_tokenizer.model_input_names))
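Note (not part of the commit): a minimal usage sketch for the dual tokenizer. The GPT-2 tokenizers and the co_encoder package name are placeholders; in practice each tower would use the tokenizer of its own backbone.

from transformers import AutoTokenizer
from co_encoder.tokenization_co_encoder import CoEncoderDualTokenizer  # hypothetical package layout

dual_tok = CoEncoderDualTokenizer(
    context_tokenizer=AutoTokenizer.from_pretrained("gpt2"),  # placeholder tokenizers
    text_tokenizer=AutoTokenizer.from_pretrained("gpt2"),
)

batch = dual_tok(
    context="A long reference document that should condition generation.",
    text="Summarize the document above.",
    return_tensors="pt",
)

# Context features keep a `context_` prefix so they can be routed to the context tower;
# the main text keeps the standard `input_ids` / `attention_mask` names.
print(sorted(batch.keys()))
# e.g. ['attention_mask', 'context_attention_mask', 'context_input_ids', 'input_ids']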