import os

from .minference_configuration import MInferenceConfig
from .patch import minference_patch, minference_patch_vllm, patch_hf


class MInference:
    """Patches a model in place so that it uses one of the supported
    sparse / long-context attention implementations."""

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = {},
        **kwargs,
    ):
        super(MInference, self).__init__()
        # Bundle all patching options into a single config object.
        self.config = MInferenceConfig(
            attn_type=attn_type,
            model_name=model_name,
            config_path=config_path,
            starting_layer=starting_layer,
            kv_cache_cpu=kv_cache_cpu,
            use_snapkv=use_snapkv,
            is_search=is_search,
            attn_kwargs=attn_kwargs,
            **kwargs,
        )

    def __call__(self, model):
        # Allow the instance itself to be used as a patcher: MInference(...)(model).
        return self.patch_model(model)

    def patch_model(self, model):
        # vLLM engines carry their own configuration, so these attributes
        # are only set on Hugging Face model configs.
        if self.config.attn_type != "vllm":
            model.config.starting_layer = self.config.starting_layer
            model.config.config_path = self.config.config_path

        if self.config.attn_type == "minference":
            # Default MInference sparse attention; optionally search for the
            # best sparse pattern per head before patching.
            model.config.is_search = self.config.is_search
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "minference_with_dense":
            # MInference patching, but with dense attention enabled.
            model.config.dense = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "dilated1":
            model.config.dilated1 = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "static":
            model.config.static_pattern = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "dilated2":
            model.config.dilated2 = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "streaming":
            # Streaming attention: keep the first `n_init` tokens plus a
            # local window of `n_local` tokens; attn_kwargs can override.
            model.config.streaming = True
            model.config.streaming_kwargs = {
                "n_local": 3968,
                "n_init": 128,
                **self.config.attn_kwargs,
            }
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "streaming2":
            # Same streaming pattern, applied through the generic HF patcher.
            model = patch_hf(
                model,
                attn_type="streaming",
                attn_kwargs={"n_local": 3968, "n_init": 128, **self.config.attn_kwargs},
            )
        elif self.config.attn_type == "inf_llm":
            # InfLLM block-sparse attention with a bounded cache of
            # context blocks; all defaults can be overridden via attn_kwargs.
            model = patch_hf(
                model,
                attn_type="inf_llm",
                attn_kwargs={
                    "block_size": 128,
                    "n_init": 128,
                    "n_local": 4096,
                    "topk": 16,
                    "repr_topk": 4,
                    "max_cached_block": 32,
                    "exc_block_size": 512,
                    "base": 1000000,
                    "distance_scale": 1.0,
                    "dense_decoding": True,
                    **self.config.attn_kwargs,
                },
            )
        elif self.config.attn_type == "vllm":
            model = minference_patch_vllm(model, self.config.config_path)
        else:
            raise ValueError(
                f"The attention type {self.config.attn_type} you specified is not supported."
            )
        return model
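
For reference, a minimal usage sketch (not part of this file): load a Hugging Face causal LM, then apply MInference to patch its attention. The model name is illustrative; any model supported by the library works the same way.

import torch
from transformers import AutoModelForCausalLM
from minference import MInference

# Load a long-context model (name here is an example, not a requirement).
model = AutoModelForCausalLM.from_pretrained(
    "gradientai/Llama-3-8B-Instruct-262k",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Build the patcher for the desired attention type, then apply it;
# calling the instance is equivalent to calling patch_model(model).
minference_patch = MInference(
    attn_type="minference",
    model_name="gradientai/Llama-3-8B-Instruct-262k",
)
model = minference_patch(model)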