import os

from .minference_configuration import MInferenceConfig
from .patch import minference_patch, minference_patch_vllm, patch_hf
class MInference:
    """Patches a model's attention with the long-context method selected by ``attn_type``."""

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = {},
        **kwargs,
    ):
        super(MInference, self).__init__()
        # All options are collected into a single MInferenceConfig object.
        self.config = MInferenceConfig(
            attn_type=attn_type,
            model_name=model_name,
            config_path=config_path,
            starting_layer=starting_layer,
            kv_cache_cpu=kv_cache_cpu,
            use_snapkv=use_snapkv,
            is_search=is_search,
            attn_kwargs=attn_kwargs,
            **kwargs,
        )
    def __call__(self, model):
        return self.patch_model(model)
    def patch_model(self, model):
        # Non-vLLM backends carry the patch settings on the Hugging Face model
        # config; the vLLM path is handled separately below.
        if self.config.attn_type != "vllm":
            model.config.starting_layer = self.config.starting_layer
            model.config.config_path = self.config.config_path

        # Dispatch to the patch routine that matches the requested attention type.
        if self.config.attn_type == "minference":
            model.config.is_search = self.config.is_search
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "minference_with_dense":
            model.config.dense = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "dilated1":
            model.config.dilated1 = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "static":
            model.config.static_pattern = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "dilated2":
            model.config.dilated2 = True
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "streaming":
            model.config.streaming = True
            model.config.streaming_kwargs = {
                "n_local": 3968,
                "n_init": 128,
                **self.config.attn_kwargs,
            }
            model = minference_patch(model, self.config)
        elif self.config.attn_type == "streaming2":
            model = patch_hf(
                model,
                attn_type="streaming",
                attn_kwargs={"n_local": 3968, "n_init": 128, **self.config.attn_kwargs},
            )
        elif self.config.attn_type == "inf_llm":
            model = patch_hf(
                model,
                attn_type="inf_llm",
                attn_kwargs={
                    "block_size": 128,
                    "n_init": 128,
                    "n_local": 4096,
                    "topk": 16,
                    "repr_topk": 4,
                    "max_cached_block": 32,
                    "exc_block_size": 512,
                    "base": 1000000,
                    "distance_scale": 1.0,
                    "dense_decoding": True,
                    **self.config.attn_kwargs,
                },
            )
        elif self.config.attn_type == "vllm":
            model = minference_patch_vllm(model, self.config.config_path)
        else:
            raise ValueError(
                f"The attention type {self.config.attn_type} you specified is not supported."
            )
        return model
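
# Example usage (a minimal sketch, not part of this module): wrapping a
# Hugging Face causal LM with MInference. The model name below is only an
# illustration; any model with a supported attention config can be patched
# the same way.
#
#     from transformers import AutoModelForCausalLM
#     from minference import MInference
#
#     model_name = "gradientai/Llama-3-8B-Instruct-262k"  # illustrative
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name, torch_dtype="auto", device_map="auto"
#     )
#     minference = MInference(attn_type="minference", model_name=model_name)
#     model = minference(model)  # equivalent to minference.patch_model(model)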