# NOTE(review): removed web-scrape residue ("Spaces: Running on Zero ...")
# that was prepended to this file and is not valid Python.
import os
from .configs.model2path import MODEL2PATH
class MInferenceConfig:
    """Configuration for MInference sparse-attention patching.

    Validates the requested attention type and resolves the path to the
    per-model pattern-config file: an explicit ``config_path`` wins,
    otherwise ``model_name`` is looked up in ``MODEL2PATH``.
    """

    # Closed set of attention implementations this config accepts.
    ATTENTION_TYPES = [
        "minference",
        "minference_with_dense",
        "static",
        "dilated1",
        "dilated2",
        "streaming",
        "inf_llm",
        "vllm",
    ]

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = None,
        **kwargs,
    ):
        """Build an MInference configuration.

        Args:
            attn_type: one of ``ATTENTION_TYPES``.
            model_name: key into ``MODEL2PATH``; only consulted when
                ``config_path`` is not supplied.
            config_path: explicit path to a pattern-config file; takes
                precedence over the ``model_name`` lookup.
            starting_layer: first layer to patch (-1 presumably means all
                layers — TODO confirm against the patch code).
            kv_cache_cpu: flag stored for downstream use (KV-cache offload).
            use_snapkv: flag stored for downstream use (SnapKV).
            is_search: flag stored for downstream use (pattern search mode).
            attn_kwargs: extra keyword arguments stored for the attention
                implementation. ``None`` (the default) yields a fresh empty
                dict per instance; the previous ``dict = {}`` default was a
                single dict shared by every default-constructed config.
            **kwargs: ignored; accepted for forward compatibility.

        Raises:
            ValueError: if ``attn_type`` is unsupported, or if neither a
                ``config_path`` nor a known ``model_name`` is given.
        """
        super().__init__()
        # Raise instead of assert so validation survives `python -O`.
        if attn_type not in self.ATTENTION_TYPES:
            raise ValueError(
                f"The attention_type {attn_type} you specified is not supported."
            )
        self.attn_type = attn_type
        self.config_path = self.update_config_path(config_path, model_name)
        self.model_name = model_name
        self.is_search = is_search
        self.starting_layer = starting_layer
        self.kv_cache_cpu = kv_cache_cpu
        self.use_snapkv = use_snapkv
        # Fresh dict per instance — never reuse a mutable default.
        self.attn_kwargs = {} if attn_kwargs is None else attn_kwargs

    def update_config_path(self, config_path: str, model_name: str):
        """Return ``config_path`` if given, else resolve it from ``model_name``.

        Raises:
            ValueError: if ``config_path`` is ``None`` and ``model_name`` is
                not a key of ``MODEL2PATH``.
        """
        if config_path is not None:
            return config_path
        # Raise instead of assert so validation survives `python -O`.
        if model_name not in MODEL2PATH:
            raise ValueError(
                f"The model {model_name} you specified is not supported. You are welcome to add it and open a PR :)"
            )
        return MODEL2PATH[model_name]