from typing import Optional

from .configs.model2path import MODEL2PATH


class MInferenceConfig:
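    """Configuration for MInference attention patching.

    Holds the chosen attention implementation (``attn_type``) plus runtime
    options, and resolves the per-model config path via ``MODEL2PATH`` when
    ``config_path`` is not given.
    """

    # Supported values for ``attn_type``.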
    ATTENTION_TYPES = [
        "minference",
        "minference_with_dense",
        "static",
        "dilated1",
        "dilated2",
        "streaming",
        "inf_llm",
        "vllm",
    ]

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: Optional[str] = None,
        config_path: Optional[str] = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: Optional[dict] = None,
        **kwargs,  # extra keyword arguments are accepted and ignored
    ):
        super().__init__()
        assert (
            attn_type in self.ATTENTION_TYPES
        ), f"The attn_type {attn_type!r} is not supported; choose one of {self.ATTENTION_TYPES}."
        self.attn_type = attn_type
        self.config_path = self.update_config_path(config_path, model_name)
        self.model_name = model_name
        self.is_search = is_search
        self.starting_layer = starting_layer
        self.kv_cache_cpu = kv_cache_cpu
        self.use_snapkv = use_snapkv
        # Copy into a fresh dict rather than keeping a shared mutable default.
        self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {}

    def update_config_path(self, config_path: Optional[str], model_name: Optional[str]) -> str:
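        """Return ``config_path`` if given, else look up ``model_name`` in ``MODEL2PATH``."""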
        if config_path is not None:
            return config_path
        assert (
            model_name in MODEL2PATH
        ), f"The model {model_name} you specified is not supported. You are welcome to add it and open a PR :)"
        return MODEL2PATH[model_name]
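
# Usage sketch (illustrative, not part of the library). It assumes the package
# exports MInferenceConfig and that the chosen model name is a key in
# MODEL2PATH; both the import path and the placeholder name are assumptions.
#
#     from minference import MInferenceConfig
#
#     config = MInferenceConfig(
#         attn_type="minference",
#         model_name="<a-model-name-listed-in-MODEL2PATH>",
#     )
#     print(config.config_path)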