hvlgo committed
Commit 729e0ea
1 Parent(s): 544dbb8

Update modeling_timer.py

Files changed (1)
  1. modeling_timer.py +572 -565
modeling_timer.py CHANGED
@@ -1,565 +1,572 @@
- from typing import Optional, Tuple, List, Union
- import torch
- from torch import nn
- import torch.nn.functional as F
- from transformers import PreTrainedModel, Cache, DynamicCache
- from transformers.activations import ACT2FN
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
- from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
- from .configuration_timer import TimerConfig
- from .ts_generation_mixin import TSGenerationMixin
-
-
- def rotate_half(x):
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2:]
-     return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
-     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
-
-
- class TimerPatchEmbedding(nn.Module):
-     def __init__(self, config: TimerConfig):
-         super().__init__()
-         self.input_token_len = config.input_token_len
-         self.emb = nn.Linear(config.input_token_len,
-                              config.hidden_size, bias=False)
-
-     def forward(self, hidden_state: torch.Tensor):
-         hidden_state = hidden_state.unfold(
-             dimension=-1, size=self.input_token_len, step=self.input_token_len)
-         return self.emb(hidden_state)
-
-
- class TimerPointEmbedding(nn.Module):
-     def __init__(self, config: TimerConfig):
-         super().__init__()
-         self.emb_layer = nn.Linear(
-             config.input_token_len, config.hidden_size, bias=False)
-         self.gate_layer = nn.Linear(
-             config.input_token_len, config.hidden_size, bias=False)
-         self.act_fn = ACT2FN[config.hidden_act]
-
-     def forward(self, x):
-         emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x)
-         return emb
-
-
- class TimeMoeRotaryEmbedding(torch.nn.Module):
-     def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
-         super().__init__()
-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
-                                         2, dtype=torch.int64).float().to(device) / self.dim))
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-         # Build here to make `torch.jit.trace` work.
-         self._set_cos_sin_cache(
-             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-         )
-
-     def _set_cos_sin_cache(self, seq_len, device, dtype):
-         self.max_seq_len_cached = seq_len
-         t = torch.arange(self.max_seq_len_cached, device=device,
-                          dtype=torch.int64).type_as(self.inv_freq)
-
-         freqs = torch.outer(t, self.inv_freq)
-         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-         emb = torch.cat((freqs, freqs), dim=-1)
-         self.register_buffer(
-             "cos_cached", emb.cos().to(dtype), persistent=False)
-         self.register_buffer(
-             "sin_cached", emb.sin().to(dtype), persistent=False)
-
-     def forward(self, x, seq_len=None):
-         # x: [bs, num_attention_heads, seq_len, head_size]
-         if seq_len > self.max_seq_len_cached:
-             self._set_cos_sin_cache(
-                 seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-         return (
-             self.cos_cached[:seq_len].to(dtype=x.dtype),
-             self.sin_cached[:seq_len].to(dtype=x.dtype),
-         )
-
-
- class TimerAttention(nn.Module):
-     def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None):
-         super().__init__()
-         self.layer_idx = layer_idx
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = self.hidden_size // self.num_heads
-         self.attention_dropout = config.attention_dropout
-         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
-         self.rotary_emb = TimeMoeRotaryEmbedding(
-             self.head_dim, max_position_embeddings=config.max_position_embeddings)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Cache] = None,
-         output_attentions: bool = False,
-         **kwargs,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(
-             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(
-             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(
-             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-         kv_seq_len = key_states.shape[-2]
-         if past_key_value is not None:
-             kv_seq_len += past_key_value.get_usable_length(
-                 kv_seq_len, self.layer_idx)
-         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-         query_states, key_states = apply_rotary_pos_emb(
-             query_states, key_states, cos, sin, position_ids)
-
-         if past_key_value is not None:
-             key_states, value_states = past_key_value.update(
-                 key_states, value_states, self.layer_idx)
-
-         attn_output = F.scaled_dot_product_attention(
-             query_states, key_states, value_states, attention_mask, dropout_p=self.attention_dropout)
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-         attn_output = self.o_proj(attn_output)
-
-         if not output_attentions:
-             attn_weights = None
-
-         return attn_output, attn_weights, past_key_value
-
-
- class TimerMLP(nn.Module):
-     def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
-         super().__init__()
-         self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size
-         self.gate_proj = nn.Linear(
-             self.hidden_size, self.intermediate_size, bias=False)
-         self.up_proj = nn.Linear(
-             self.hidden_size, self.intermediate_size, bias=False)
-         self.down_proj = nn.Linear(
-             self.intermediate_size, self.hidden_size, bias=False)
-         self.act_fn = ACT2FN[hidden_act]
-
-     def forward(self, hidden_state):
-         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
-
-
- class TimerDecoderLayer(nn.Module):
-     def __init__(self, config: TimerConfig, layer_idx: int):
-         super().__init__()
-         self.self_attn = TimerAttention(config, layer_idx)
-
-         self.ffn_layer = TimerMLP(
-             hidden_size=config.hidden_size,
-             intermediate_size=config.intermediate_size,
-             hidden_act=config.hidden_act,
-         )
-         self.norm1 = torch.nn.LayerNorm(config.hidden_size)
-         self.norm2 = torch.nn.LayerNorm(config.hidden_size)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
-         residual = hidden_states
-
-         # Self Attention
-         hidden_states, self_attn_weights, present_key_value = self.self_attn(
-             hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
-         )
-         hidden_states = residual + hidden_states
-         hidden_states = self.norm1(hidden_states)
-
-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.ffn_layer(hidden_states)
-         hidden_states = residual + hidden_states
-         hidden_states = self.norm2(hidden_states)
-
-         if not output_attentions:
-             self_attn_weights = None
-
-         if not use_cache:
-             present_key_value = None
-         return hidden_states, self_attn_weights, present_key_value
-
-
- class TimerPreTrainedModel(PreTrainedModel):
-     config_class = TimerConfig
-     base_model_prefix = "model"
-     supports_gradient_checkpointing = True
-     _no_split_modules = ["TimeMoeDecoderLayer"]
-     _skip_keys_device_placement = "past_key_values"
-     _supports_flash_attn_2 = True
-     _supports_sdpa = False
-     _supports_cache_class = True
-
-     def _init_weights(self, module):
-         std = self.config.initializer_range
-         if isinstance(module, torch.nn.Linear):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, torch.nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-
-
- class TimerModel(TimerPreTrainedModel):
-     def __init__(self, config: TimerConfig):
-         super().__init__(config)
-         self.embed_layer = TimerPatchEmbedding(config)
-         self.layers = nn.ModuleList(
-             [TimerDecoderLayer(config, layer_idx)
-              for layer_idx in range(config.num_hidden_layers)]
-         )
-         self.norm = torch.nn.LayerNorm(config.hidden_size)
-         self.gradient_checkpointing = False
-
-     def forward(
-         self,
-         input_ids: torch.FloatTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, MoeModelOutputWithPast]:
-         # input_ids is the input of time series, its shape is [batch_size, seq_len]
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # retrieve input_ids and inputs_embeds
-         if input_ids is not None and inputs_embeds is not None:
-             raise ValueError(
-                 "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-         elif input_ids is not None:
-             batch_size, seq_length = input_ids.shape
-         elif inputs_embeds is not None:
-             batch_size, seq_length, _ = inputs_embeds.shape
-         else:
-             raise ValueError(
-                 "You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-         if inputs_embeds is None:
-             inputs_embeds = self.embed_layer(input_ids)
-             seq_length = inputs_embeds.shape[1]
-
-         if self.gradient_checkpointing and self.training:
-             if use_cache:
-                 use_cache = False
-
-         past_key_values_length = 0
-
-         if use_cache:
-             use_legacy_cache = not isinstance(past_key_values, Cache)
-             if use_legacy_cache:
-                 past_key_values = DynamicCache.from_legacy_cache(
-                     past_key_values)
-             past_key_values_length = past_key_values.get_usable_length(
-                 seq_length)
-
-         if position_ids is None:
-             device = input_ids.device if input_ids is not None else inputs_embeds.device
-             position_ids = torch.arange(
-                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-             )
-             # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-             position_ids = position_ids.view(-1, seq_length)
-         else:
-             position_ids = position_ids.view(-1, seq_length).long()
-
-         # 4d mask is passed through the layers
-         attention_mask = _prepare_4d_causal_attention_mask(
-             attention_mask,
-             (batch_size, seq_length),
-             inputs_embeds,
-             past_key_values_length,
-             sliding_window=None,
-         )
-
-         hidden_states = inputs_embeds
-
-         # decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-         next_decoder_cache = None
-
-         for decoder_layer in self.layers:
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             if self.gradient_checkpointing and self.training:
-                 layer_outputs = self._gradient_checkpointing_func(
-                     decoder_layer.__call__,
-                     hidden_states,
-                     attention_mask,
-                     position_ids,
-                     past_key_values,
-                     output_attentions,
-                     use_cache,
-                 )
-             else:
-                 layer_outputs = decoder_layer(
-                     hidden_states,
-                     attention_mask=attention_mask,
-                     position_ids=position_ids,
-                     past_key_value=past_key_values,
-                     output_attentions=output_attentions,
-                     use_cache=use_cache,
-                 )
-
-             hidden_states = layer_outputs[0]
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-             if use_cache:
-                 next_decoder_cache = layer_outputs[2]
-
-         hidden_states = self.norm(hidden_states)
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         next_cache = None
-         if use_cache:
-             next_cache = next_decoder_cache.to_legacy_cache(
-             ) if use_legacy_cache else next_decoder_cache
-
-         if not return_dict:
-             return tuple(
-                 v
-                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
-                 if v is not None
-             )
-         return MoeModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-
- class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
-     def __init__(self, config: TimerConfig):
-         super().__init__(config)
-         self.config = config
-         self.model = TimerModel(self.config)
-         lm_head_list = []
-         self.output_token_len_map = {}
-         for i, output_token_len in enumerate(self.config.output_token_lens):
-             lm_head_list.append(
-                 nn.Linear(self.config.hidden_size, output_token_len, bias=False))
-             self.output_token_len_map[output_token_len] = i
-         self.lm_heads = nn.ModuleList(lm_head_list)
-         self.loss_function = torch.nn.MSELoss(reduction='none')
-         self.post_init()
-
-     def set_decoder(self, decoder):
-         self.model = decoder
-
-     def get_decoder(self):
-         return self.model
-
-     def forward(
-         self,
-         input_ids: torch.FloatTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.FloatTensor] = None,
-         loss_masks: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-         max_output_length: Optional[int] = None,
-     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
-
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
-         predictions = None
-
-         loss = None
-         if labels is not None:
-             ar_loss = 0.0
-             for lm_head, output_token_len in zip(self.lm_heads, self.config.output_token_lens):
-                 one_predictions = lm_head(hidden_states)
-                 one_loss = self.calc_ar_loss(
-                     one_predictions, labels, loss_masks, output_token_len)
-                 ar_loss += one_loss
-                 if predictions is None:
-                     predictions = one_predictions
-             loss = ar_loss / len(self.config.output_token_lens)
-         else:
-             if max_output_length is None:
-                 output_token_len = self.config.output_token_lens[0]
-                 max_output_length = output_token_len
-             else:
-                 output_token_len = self.config.output_token_lens[0]
-                 for h in self.config.output_token_lens[1:]:
-                     if h > max_output_length:
-                         break
-                     else:
-                         output_token_len = h
-             lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
-             predictions = lm_head(hidden_states)
-             if output_token_len > max_output_length:
-                 predictions = predictions[:, :, :max_output_length]
-         if not return_dict:
-             output = (predictions,) + outputs[1:]
-             return (loss) + output if loss is not None else output
-
-         return MoeCausalLMOutputWithPast(
-             loss=loss,
-             logits=predictions,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len):
-         seq_len = predictions.shape[1] * self.config.input_token_len
-         labels = labels[:, :seq_len -
-                         self.config.input_token_len + output_token_len]
-         shift_labels = labels.unfold(
-             dimension=-1, size=output_token_len, step=self.config.input_token_len)
-
-         # Calculate loss with mask
-         losses = self.loss_function(predictions, shift_labels).mean(dim=-1)
-         if loss_masks is not None:
-             losses = losses * loss_masks
-             loss = losses.sum() / loss_masks.sum()
-         else:
-             loss = torch.mean(losses)
-
-         return loss
-
-     def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-     ):
-         # Omit tokens covered by past_key_values
-         if past_key_values is not None:
-             if isinstance(past_key_values, Cache):
-                 cache_length = past_key_values.get_seq_length()
-                 if isinstance(past_key_values, DynamicCache):
-                     past_length = past_key_values.seen_tokens
-                 else:
-                     past_length = cache_length
-
-                 max_cache_length = past_key_values.get_max_length()
-             else:
-                 cache_length = past_length = past_key_values[0][0].shape[2]
-                 max_cache_length = None
-
-             # Keep only the unprocessed tokens:
-             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-             #     input)
-             if attention_mask is not None and attention_mask.shape[1] > (input_ids.shape[1] // self.config.input_token_len):
-                 input_ids = input_ids[:, -
-                                       (attention_mask.shape[1] - past_length):]
-             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-             #     input_ids based on the past_length.
-             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
-                 input_ids = input_ids[:, past_length *
-                                       self.config.input_token_len:]
-             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
-
-             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-             if (
-                 max_cache_length is not None
-                 and attention_mask is not None
-                 and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length
-             ):
-                 attention_mask = attention_mask[:, -max_cache_length:]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -
-                                             (input_ids.shape[1] // self.config.input_token_len):]
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-             }
-         )
-         return model_inputs
+ from typing import Optional, Tuple, List, Union
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, Cache, DynamicCache
+ from transformers.activations import ACT2FN
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
+ from .configuration_timer import TimerConfig
+ from .ts_generation_mixin import TSGenerationMixin
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class TimerPatchEmbedding(nn.Module):
+     def __init__(self, config: TimerConfig):
+         super().__init__()
+         self.input_token_len = config.input_token_len
+         self.emb = nn.Linear(config.input_token_len,
+                              config.hidden_size, bias=False)
+
+     def forward(self, hidden_state: torch.Tensor):
+         hidden_state = hidden_state.unfold(
+             dimension=-1, size=self.input_token_len, step=self.input_token_len)
+         return self.emb(hidden_state)
+
+
+ class TimerPointEmbedding(nn.Module):
+     def __init__(self, config: TimerConfig):
+         super().__init__()
+         self.emb_layer = nn.Linear(
+             config.input_token_len, config.hidden_size, bias=False)
+         self.gate_layer = nn.Linear(
+             config.input_token_len, config.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x)
+         return emb
+
+
+ class TimeMoeRotaryEmbedding(torch.nn.Module):
+     def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
+                                         2, dtype=torch.int64).float().to(device) / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Build here to make `torch.jit.trace` work.
+         self._set_cos_sin_cache(
+             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+         )
+
+     def _set_cos_sin_cache(self, seq_len, device, dtype):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device,
+                          dtype=torch.int64).type_as(self.inv_freq)
+
+         freqs = torch.outer(t, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer(
+             "cos_cached", emb.cos().to(dtype), persistent=False)
+         self.register_buffer(
+             "sin_cached", emb.sin().to(dtype), persistent=False)
+
+     def forward(self, x, seq_len=None):
+         # x: [bs, num_attention_heads, seq_len, head_size]
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(
+                 seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+         return (
+             self.cos_cached[:seq_len].to(dtype=x.dtype),
+             self.sin_cached[:seq_len].to(dtype=x.dtype),
+         )
+
+
+ class TimerAttention(nn.Module):
+     def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.attention_dropout = config.attention_dropout
+         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+         self.rotary_emb = TimeMoeRotaryEmbedding(
+             self.head_dim, max_position_embeddings=config.max_position_embeddings)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         output_attentions: bool = False,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(
+             bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value.get_usable_length(
+                 kv_seq_len, self.layer_idx)
+         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx)
+
+         attn_output = F.scaled_dot_product_attention(
+             query_states, key_states, value_states, attention_mask, dropout_p=self.attention_dropout)
+
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+
+
+ class TimerMLP(nn.Module):
+     def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.gate_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(
+             self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[hidden_act]
+
+     def forward(self, hidden_state):
+         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+ class TimerDecoderLayer(nn.Module):
+     def __init__(self, config: TimerConfig, layer_idx: int):
+         super().__init__()
+         self.self_attn = TimerAttention(config, layer_idx)
+
+         self.ffn_layer = TimerMLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+         )
+         self.norm1 = torch.nn.LayerNorm(config.hidden_size)
+         self.norm2 = torch.nn.LayerNorm(config.hidden_size)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
+         residual = hidden_states
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+         hidden_states = self.norm1(hidden_states)
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.ffn_layer(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.norm2(hidden_states)
+
+         if not output_attentions:
+             self_attn_weights = None
+
+         if not use_cache:
+             present_key_value = None
+         return hidden_states, self_attn_weights, present_key_value
+
+
+ class TimerPreTrainedModel(PreTrainedModel):
+     config_class = TimerConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["TimeMoeDecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = True
+     _supports_sdpa = False
+     _supports_cache_class = True
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, torch.nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, torch.nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+
+ class TimerModel(TimerPreTrainedModel):
+     def __init__(self, config: TimerConfig):
+         super().__init__(config)
+         self.embed_layer = TimerPatchEmbedding(config)
+         self.layers = nn.ModuleList(
+             [TimerDecoderLayer(config, layer_idx)
+              for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.norm = torch.nn.LayerNorm(config.hidden_size)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         input_ids: torch.FloatTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MoeModelOutputWithPast]:
+         # input_ids is the input of time series, its shape is [batch_size, seq_len]
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError(
+                 "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_layer(input_ids)
+             seq_length = inputs_embeds.shape[1]
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 use_cache = False
+
+         past_key_values_length = 0
+
+         if use_cache:
+             use_legacy_cache = not isinstance(past_key_values, Cache)
+             if use_legacy_cache:
+                 past_key_values = DynamicCache.from_legacy_cache(
+                     past_key_values)
+             past_key_values_length = past_key_values.get_usable_length(
+                 seq_length)
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+             )
+             # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+             position_ids = position_ids.view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         # 4d mask is passed through the layers
+         attention_mask = _prepare_4d_causal_attention_mask(
+             attention_mask,
+             (batch_size, seq_length),
+             inputs_embeds,
+             past_key_values_length,
+             sliding_window=None,
+         )
+
+         hidden_states = inputs_embeds
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         for decoder_layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     attention_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2]
+
+         hidden_states = self.norm(hidden_states)
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = None
+         if use_cache:
+             next_cache = next_decoder_cache.to_legacy_cache(
+             ) if use_legacy_cache else next_decoder_cache
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                 if v is not None
+             )
+         return MoeModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
+     def __init__(self, config: TimerConfig):
+         super().__init__(config)
+         self.config = config
+         self.model = TimerModel(self.config)
+         lm_head_list = []
+         self.output_token_len_map = {}
+         for i, output_token_len in enumerate(self.config.output_token_lens):
+             lm_head_list.append(
+                 nn.Linear(self.config.hidden_size, output_token_len, bias=False))
+             self.output_token_len_map[output_token_len] = i
+         self.lm_heads = nn.ModuleList(lm_head_list)
+         self.loss_function = torch.nn.MSELoss(reduction='none')
+         self.post_init()
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.FloatTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.FloatTensor] = None,
+         loss_masks: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         max_output_length: Optional[int] = None,
+         revin: Optional[bool] = False,
+     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if revin:
+             mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std(dim=-1, keepdim=True)
+             input_ids = (input_ids - mean) / std
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+         predictions = None
+
+         loss = None
+         if labels is not None:
+             ar_loss = 0.0
+             for lm_head, output_token_len in zip(self.lm_heads, self.config.output_token_lens):
+                 one_predictions = lm_head(hidden_states)
+                 one_loss = self.calc_ar_loss(
+                     one_predictions, labels, loss_masks, output_token_len)
+                 ar_loss += one_loss
+                 if predictions is None:
+                     predictions = one_predictions
+             loss = ar_loss / len(self.config.output_token_lens)
+         else:
+             if max_output_length is None:
+                 output_token_len = self.config.output_token_lens[0]
+                 max_output_length = output_token_len
+             else:
+                 output_token_len = self.config.output_token_lens[0]
+                 for h in self.config.output_token_lens[1:]:
+                     if h > max_output_length:
+                         break
+                     else:
+                         output_token_len = h
+             lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
+             predictions = lm_head(hidden_states)
+             if output_token_len > max_output_length:
+                 predictions = predictions[:, :, :max_output_length]
+             if revin:
+                 predictions = predictions * std + mean
+         if not return_dict:
+             output = (predictions,) + outputs[1:]
+             return (loss) + output if loss is not None else output
+
+         return MoeCausalLMOutputWithPast(
+             loss=loss,
+             logits=predictions,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len):
+         seq_len = predictions.shape[1] * self.config.input_token_len
+         labels = labels[:, :seq_len -
+                         self.config.input_token_len + output_token_len]
+         shift_labels = labels.unfold(
+             dimension=-1, size=output_token_len, step=self.config.input_token_len)
+
+         # Calculate loss with mask
+         losses = self.loss_function(predictions, shift_labels).mean(dim=-1)
+         if loss_masks is not None:
+             losses = losses * loss_masks
+             loss = losses.sum() / loss_masks.sum()
+         else:
+             loss = torch.mean(losses)
+
+         return loss
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, revin=True, **kwargs
+     ):
+         # Omit tokens covered by past_key_values
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 cache_length = past_key_values.get_seq_length()
+                 if isinstance(past_key_values, DynamicCache):
+                     past_length = past_key_values.seen_tokens
+                 else:
+                     past_length = cache_length
+
+                 max_cache_length = past_key_values.get_max_length()
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+                 max_cache_length = None
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+             #     input)
+             if attention_mask is not None and attention_mask.shape[1] > (input_ids.shape[1] // self.config.input_token_len):
+                 input_ids = input_ids[:, -
+                                       (attention_mask.shape[1] - past_length):]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < (input_ids.shape[1] // self.config.input_token_len):
+                 input_ids = input_ids[:, past_length *
+                                       self.config.input_token_len:]
+             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
+
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -
+                                             (input_ids.shape[1] // self.config.input_token_len):]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "revin": revin
+             }
+         )
+         return model_inputs
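
The functional change in this commit is the optional `revin` flag: when it is set, the input series is standardized per sample before the backbone runs, and the forecasts are mapped back to the original scale afterwards. Below is a minimal standalone sketch of that normalize/denormalize round trip; the function and tensor names are illustrative and are not part of modeling_timer.py.

import torch

def revin_normalize(series: torch.Tensor):
    # Per-sample statistics over the time dimension, mirroring the new forward() path.
    mean = series.mean(dim=-1, keepdim=True)
    std = series.std(dim=-1, keepdim=True)
    # Note: as in the commit, no epsilon is added, so a constant series would divide by zero.
    return (series - mean) / std, mean, std

def revin_denormalize(predictions: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    # Undo the standardization so predictions are on the original scale of each input series.
    return predictions * std + mean

# Usage sketch: x is [batch, context_length]; the "forecast" here is a stand-in tensor.
x = torch.randn(2, 672) * 10.0 + 5.0
x_norm, mean, std = revin_normalize(x)
forecast_norm = x_norm[:, -96:]
forecast = revin_denormalize(forecast_norm, mean, std)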