# LightGTS / ts_generation_mixin.py
# Provenance: uploaded by pchen182224 ("Upload 9 files", commit c882c3e, verified)
from typing import Any, Dict, List, Optional, Union, Callable
import torch
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerationConfig, GenerateOutput
from transformers.utils import ModelOutput
class TSGenerationMixin(GenerationMixin):
    """Time-series generation mixin.

    Replaces the token-based ``GenerationMixin.generate`` with a single
    forward pass over a 3-D time-series tensor, optionally applying
    RevIN-style instance normalization before the model call and
    de-normalizing the prediction afterwards.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: Optional[bool] = True,
        patch_len: Optional[int] = None,
        stride_len: Optional[int] = None,
        max_output_length: Optional[int] = None,
        inference_patch_len: Optional[int] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
        """Generate a forecast for ``inputs`` with a single forward pass.

        Args:
            inputs: Tensor of shape ``[batch_size, seq_len, n_vars]``. Required
                in practice, despite the ``Optional`` annotation kept for
                signature compatibility with ``GenerationMixin.generate``.
            revin: If True, standardize each series over the time dimension
                before the forward pass and invert the transform on the output.
            patch_len: Forwarded to the model as ``patch_len``.
            stride_len: Forwarded to the model as ``stride``.
            max_output_length: Forwarded to the model as ``target_dim``
                (forecast horizon).
            inference_patch_len: Accepted but currently unused — NOTE(review):
                confirm whether the model forward is meant to consume it.
            All other parameters mirror ``GenerationMixin.generate`` and are
            ignored by this implementation.

        Returns:
            Tensor of shape ``[batch_size, target_dim, n_vars]``.

        Raises:
            ValueError: If ``inputs`` is missing or not 3-dimensional.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if inputs is None:
            raise ValueError('`inputs` is required and must have shape [batch_size, seq_len, n_vars]')
        if len(inputs.shape) != 3:
            raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')

        if revin:
            # Per-series standardization over the time axis (RevIN-style);
            # the 1e-5 term guards against division by zero on constant series.
            means = inputs.mean(dim=1, keepdim=True)
            stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev

        model_inputs = {
            "input": inputs,
            "patch_len": patch_len,
            "stride": stride_len,
            "target_dim": max_output_length,
        }
        outputs = self(**model_inputs)  # [batch_size, target_dim, n_vars]
        outputs = outputs["prediction"]

        if revin:
            # Map predictions back onto the original scale of each series.
            outputs = (outputs * stdev) + means
        return outputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # Generation here is a single forward pass, so there is no per-step
        # state (attention mask, cache) to advance — pass kwargs through.
        return model_kwargs