from ._base import *
from ..nn_modules.fused_gptj_attn import FusedGPTJAttentionForQuantizedModel

class GPTJGPTQForCausalLM(BaseGPTQForCausalLM):
    # Class name of the repeating transformer block in HF's GPT-J implementation.
    layer_type = "GPTJBlock"
    # Attribute path to the ModuleList holding those blocks on the loaded model.
    layers_block_name = "transformer.h"
    # Modules outside the repeating blocks (token embedding, final layer norm)
    # that must stay on-device while the blocks are quantized one by one.
    outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
    # Linear submodules inside each block to quantize, grouped in execution
    # order: attention Q/K/V, attention output, MLP in, MLP out.
    inside_layer_modules = [
        ["attn.k_proj", "attn.v_proj", "attn.q_proj"],
        ["attn.out_proj"],
        ["mlp.fc_in"],
        ["mlp.fc_out"],
    ]

    # Fused attention module that can replace the stock attention at inference.
    fused_attn_module_type = FusedGPTJAttentionForQuantizedModel

__all__ = ["GPTJGPTQForCausalLM"]
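
# Usage sketch (an assumption about the surrounding package, not part of this
# module): in an AutoGPTQ-style API this class is normally reached through the
# generic loader rather than instantiated directly. The model id and the
# `examples` calibration data below are hypothetical placeholders.
#
#   from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
#
#   quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
#   model = AutoGPTQForCausalLM.from_pretrained(
#       "EleutherAI/gpt-j-6b", quantize_config
#   )
#   model.quantize(examples)  # `examples`: tokenized calibration samples
#   model.save_quantized("gpt-j-6b-4bit")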