# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli
from torch import Tensor
from dataclasses import dataclass, field
from typing import Optional
# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    NOTE: the ``@dataclass`` decorator is required — the class body uses
    ``field(default_factory=dict)`` and bare annotated attributes, which
    only produce per-instance values (and a generated ``__init__``) when
    the class is actually a dataclass.
    """

    # Maximum sequence length the KV cache is sized for.
    max_seqlen: int
    # Maximum batch size the KV cache is sized for.
    max_batch_size: int
    # Number of tokens already processed (position offset into the cache).
    seqlen_offset: int = 0
    # Offset into the batch dimension of the cache.
    batch_size_offset: int = 0
    # Per-layer cached key/value tensors; default_factory gives each
    # instance its own dict (a shared class-level dict would leak state
    # across instances).
    key_value_memory_dict: dict = field(default_factory=dict)
    # Optional per-sample lengths tensor for ragged batches.
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Re-arm this object for a new generation run.

        Resets the sequence offset and resizes the declared capacity;
        zeroes the per-sample lengths in place if they exist. The cached
        key/value dict is intentionally kept (buffers are reused).
        """
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()
@dataclass
class RecurrentInferenceParams:
    """Inference parameters passed to blocks with recurrent mode.

    NOTE: the ``@dataclass`` decorator is required — without it,
    ``field(default_factory=dict)`` leaves a ``dataclasses.Field``
    sentinel as a class attribute instead of giving each instance its
    own dict, and no ``__init__`` is generated.
    """

    # Length of the short FIR convolution filter.
    fir_filter_length: int = 3
    # Dimensionality of the recurrent state.
    state_dim: int = 16
    # Number of tokens already processed.
    seqlen_offset: int = 0
    # Per-block cached FIR filter states (one dict per instance).
    fir_state_dict: dict = field(default_factory=dict)
    # Per-block cached recurrent states (one dict per instance).
    state_dict: dict = field(default_factory=dict)

    def reset(self):
        """Restore the default hyperparameters and rewind the offset.

        The cached state dicts are intentionally kept so their buffers
        can be reused across generation runs.
        """
        self.fir_filter_length = 3
        self.state_dim = 16
        self.seqlen_offset = 0