MInference / minference /models_patch.py
iofu728's picture
Feature(MInference): build demo
43a7079
raw history blame
No virus
3.45 kB
import os
from .minference_configuration import MInferenceConfig
from .patch import minference_patch, minference_patch_vllm, patch_hf
class MInference:
    """Callable that patches a model according to a chosen attention type.

    Usage: ``model = MInference(attn_type=...)(model)``. The constructor only
    builds an ``MInferenceConfig``; the actual patching happens in
    ``patch_model`` / ``__call__``, which dispatch on ``attn_type`` to one of
    ``minference_patch``, ``patch_hf`` or ``minference_patch_vllm``.
    """

    # attn_type -> boolean attribute set to True on ``model.config`` before
    # delegating to ``minference_patch``.
    _CONFIG_FLAG_TYPES = {
        "minference_with_dense": "dense",
        "dilated1": "dilated1",
        "static": "static_pattern",
        "dilated2": "dilated2",
    }

    # Default attention kwargs; user-supplied ``attn_kwargs`` override these.
    _STREAMING_DEFAULTS = {"n_local": 3968, "n_init": 128}
    _INF_LLM_DEFAULTS = {
        "block_size": 128,
        "n_init": 128,
        "n_local": 4096,
        "topk": 16,
        "repr_topk": 4,
        "max_cached_block": 32,
        "exc_block_size": 512,
        "base": 1000000,
        "distance_scale": 1.0,
        "dense_decoding": True,
    }

    _SUPPORTED_ATTN_TYPES = frozenset(
        {"minference", "streaming", "streaming2", "inf_llm", "vllm"}
        | set(_CONFIG_FLAG_TYPES)
    )

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = None,
        **kwargs,
    ):
        """Build the ``MInferenceConfig`` used later by ``patch_model``.

        Args:
            attn_type: Which patching strategy ``patch_model`` applies.
            model_name: Forwarded to ``MInferenceConfig``.
            config_path: Forwarded to ``MInferenceConfig``; for non-vllm
                types it is also copied onto ``model.config.config_path``.
            starting_layer: Copied onto ``model.config.starting_layer`` for
                non-vllm types.
            kv_cache_cpu: Forwarded to ``MInferenceConfig``.
            use_snapkv: Forwarded to ``MInferenceConfig``.
            is_search: Copied onto ``model.config.is_search`` for the
                ``"minference"`` type.
            attn_kwargs: Overrides merged on top of the per-type default
                attention kwargs. ``None`` (the default) means no overrides.
            **kwargs: Forwarded verbatim to ``MInferenceConfig``.
        """
        super().__init__()
        # A fresh dict per instance: a shared ``{}`` default would leak any
        # mutation made through one config into every other instance.
        if attn_kwargs is None:
            attn_kwargs = {}
        self.config = MInferenceConfig(
            attn_type=attn_type,
            model_name=model_name,
            config_path=config_path,
            starting_layer=starting_layer,
            kv_cache_cpu=kv_cache_cpu,
            use_snapkv=use_snapkv,
            is_search=is_search,
            attn_kwargs=attn_kwargs,
            **kwargs,
        )

    def __call__(self, model):
        """Alias for :meth:`patch_model`."""
        return self.patch_model(model)

    def _merged_attn_kwargs(self, defaults):
        """Return a new dict of ``defaults`` overridden by the configured attn_kwargs."""
        return {**defaults, **(self.config.attn_kwargs or {})}

    def patch_model(self, model):
        """Patch ``model`` according to ``self.config.attn_type``.

        For every type except ``"vllm"``, ``starting_layer`` and
        ``config_path`` are first written onto ``model.config``. The model
        returned by the selected patch helper is returned to the caller.

        Raises:
            ValueError: If ``attn_type`` is not supported. The check runs
                before ``model`` is touched, so the model is left unmodified.
        """
        attn_type = self.config.attn_type
        if attn_type not in self._SUPPORTED_ATTN_TYPES:
            raise ValueError(
                f"The attention type {attn_type} you specified is not supported."
            )

        if attn_type == "vllm":
            # The vllm path receives only the config path; model.config is not touched.
            return minference_patch_vllm(model, self.config.config_path)

        model.config.starting_layer = self.config.starting_layer
        model.config.config_path = self.config.config_path

        if attn_type == "streaming2":
            return patch_hf(
                model,
                attn_type="streaming",
                attn_kwargs=self._merged_attn_kwargs(self._STREAMING_DEFAULTS),
            )
        if attn_type == "inf_llm":
            return patch_hf(
                model,
                attn_type="inf_llm",
                attn_kwargs=self._merged_attn_kwargs(self._INF_LLM_DEFAULTS),
            )

        # All remaining types configure model.config, then go through minference_patch.
        if attn_type == "minference":
            model.config.is_search = self.config.is_search
        elif attn_type == "streaming":
            model.config.streaming = True
            model.config.streaming_kwargs = self._merged_attn_kwargs(
                self._STREAMING_DEFAULTS
            )
        else:
            setattr(model.config, self._CONFIG_FLAG_TYPES[attn_type], True)
        return minference_patch(model, self.config)