from transformers import LlamaConfig as OrigLlamaConfig


class LlamaConfig(OrigLlamaConfig):
    # Distinct model_type so this config is not confused with the stock "llama".
    model_type = "llama_aqlm"

    def __init__(
        self,
        nbits_per_codebook: int = 16,
        num_codebooks: int = 1,
        out_group_size: int = 1,
        in_group_size: int = 8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Bundle the AQLM quantization hyperparameters under a single attribute
        # so they get serialized into config.json alongside the base Llama fields.
        self.aqlm = {
            "nbits_per_codebook": nbits_per_codebook,
            "num_codebooks": num_codebooks,
            "out_group_size": out_group_size,
            "in_group_size": in_group_size,
        }
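
For context, here is a minimal usage sketch. The `AutoConfig.register` call and the chosen hyperparameter values are illustrative assumptions, not part of the original snippet; adjust them to however your package wires things up.

from transformers import AutoConfig

# Hypothetical registration: map the custom model_type to this config class
# so AutoConfig.from_pretrained can resolve "llama_aqlm" checkpoints.
AutoConfig.register("llama_aqlm", LlamaConfig)

# Override a couple of AQLM hyperparameters; everything else falls through
# **kwargs to the stock LlamaConfig defaults.
config = LlamaConfig(num_codebooks=2, in_group_size=16)
print(config.model_type)  # "llama_aqlm"
print(config.aqlm)        # {"nbits_per_codebook": 16, "num_codebooks": 2, ...}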