| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						from torch import Tensor | 
					
					
						
						| 
							 | 
						from dataclasses import dataclass, field | 
					
					
						
						| 
							 | 
						from typing import Optional | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						@dataclass | 
					
					
						
						| 
							 | 
						class InferenceParams: | 
					
					
						
						| 
							 | 
						    """Inference parameters that are passed to the main model in order | 
					
					
						
						| 
							 | 
						    to efficienly calculate and store the context during inference.""" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    max_seqlen: int | 
					
					
						
						| 
							 | 
						    max_batch_size: int | 
					
					
						
						| 
							 | 
						    seqlen_offset: int = 0 | 
					
					
						
						| 
							 | 
						    batch_size_offset: int = 0 | 
					
					
						
						| 
							 | 
						    key_value_memory_dict: dict = field(default_factory=dict) | 
					
					
						
						| 
							 | 
						    lengths_per_sample: Optional[Tensor] = None | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def reset(self, max_seqlen, max_batch_size): | 
					
					
						
						| 
							 | 
						        self.max_seqlen = max_seqlen | 
					
					
						
						| 
							 | 
						        self.max_batch_size = max_batch_size | 
					
					
						
						| 
							 | 
						        self.seqlen_offset = 0 | 
					
					
						
						| 
							 | 
						        if self.lengths_per_sample is not None: | 
					
					
						
						| 
							 | 
						            self.lengths_per_sample.zero_() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						@dataclass | 
					
					
						
						| 
							 | 
						class RecurrentInferenceParams: | 
					
					
						
						| 
							 | 
						    """Inference parameters passed to blocks with recurrent mode.""" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    fir_filter_length: int = 3 | 
					
					
						
						| 
							 | 
						    state_dim: int = 16 | 
					
					
						
						| 
							 | 
						    seqlen_offset: int = 0 | 
					
					
						
						| 
							 | 
						    fir_state_dict: dict = field(default_factory=dict) | 
					
					
						
						| 
							 | 
						    state_dict: dict = field(default_factory=dict) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def reset(self): | 
					
					
						
						| 
							 | 
						        self.fir_filter_length = 3 | 
					
					
						
						| 
							 | 
						        self.state_dim = 16 | 
					
					
						
						| 
							 | 
						        self.seqlen_offset = 0 | 
					
					
						
						| 
							 | 
						
 |