Update _modeling_kormo.py
Browse files- _modeling_kormo.py +1 -7
 
    	
        _modeling_kormo.py
    CHANGED
    
    | 
         @@ -1,7 +1,6 @@ 
     | 
|
| 1 | 
         
             
            from typing import Callable, List, Optional, Tuple, Union
         
     | 
| 2 | 
         | 
| 3 | 
         
             
            import torch
         
     | 
| 4 | 
         
            -
            import torch.utils.checkpoint ### ADD
         
     | 
| 5 | 
         
             
            from torch import nn
         
     | 
| 6 | 
         | 
| 7 | 
         
             
            from transformers.activations import ACT2FN
         
     | 
| 
         @@ -17,13 +16,10 @@ from transformers.modeling_outputs import ( 
     | 
|
| 17 | 
         
             
            )
         
     | 
| 18 | 
         
             
            from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
         
     | 
| 19 | 
         
             
            from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
         
     | 
| 20 | 
         
            -
            # from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
         
     | 
| 21 | 
         
             
            from transformers.processing_utils import Unpack
         
     | 
| 22 | 
         
            -
            from transformers.utils import  
     | 
| 23 | 
         
             
            from ._configuration_kormo import KORMoConfig
         
     | 
| 24 | 
         | 
| 25 | 
         
            -
            # from ._flash_attn3_doc import flash_attention_3_doc_forward
         
     | 
| 26 | 
         
            -
            # ALL_ATTENTION_FUNCTIONS._global_mapping.update({'flash_attention_3_doc': flash_attention_3_doc_forward})
         
     | 
| 27 | 
         | 
| 28 | 
         
             
            logger = logging.get_logger(__name__)
         
     | 
| 29 | 
         | 
| 
         @@ -421,8 +417,6 @@ class KORMoModel(KORMoPreTrainedModel): 
     | 
|
| 421 | 
         
             
                    )
         
     | 
| 422 | 
         | 
| 423 | 
         | 
| 424 | 
         
            -
            # class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
         
     | 
| 425 | 
         
            -
             
     | 
| 426 | 
         
             
            class KORMoForCausalLM(KORMoPreTrainedModel, GenerationMixin):
         
     | 
| 427 | 
         
             
                _tied_weights_keys = ["lm_head.weight"]
         
     | 
| 428 | 
         
             
                _tp_plan = {"lm_head": "colwise_rep"}
         
     | 
| 
         | 
|
| 1 | 
         
             
            from typing import Callable, List, Optional, Tuple, Union
         
     | 
| 2 | 
         | 
| 3 | 
         
             
            import torch
         
     | 
| 
         | 
|
| 4 | 
         
             
            from torch import nn
         
     | 
| 5 | 
         | 
| 6 | 
         
             
            from transformers.activations import ACT2FN
         
     | 
| 
         | 
|
| 16 | 
         
             
            )
         
     | 
| 17 | 
         
             
            from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
         
     | 
| 18 | 
         
             
            from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
         
     | 
| 
         | 
|
| 19 | 
         
             
            from transformers.processing_utils import Unpack
         
     | 
| 20 | 
         
            +
            from transformers.utils import can_return_tuple, logging
         
     | 
| 21 | 
         
             
            from ._configuration_kormo import KORMoConfig
         
     | 
| 22 | 
         | 
| 
         | 
|
| 
         | 
|
| 23 | 
         | 
| 24 | 
         
             
            logger = logging.get_logger(__name__)
         
     | 
| 25 | 
         | 
| 
         | 
|
| 417 | 
         
             
                    )
         
     | 
| 418 | 
         | 
| 419 | 
         | 
| 
         | 
|
| 
         | 
|
| 420 | 
         
             
            class KORMoForCausalLM(KORMoPreTrainedModel, GenerationMixin):
         
     | 
| 421 | 
         
             
                _tied_weights_keys = ["lm_head.weight"]
         
     | 
| 422 | 
         
             
                _tp_plan = {"lm_head": "colwise_rep"}
         
     |