Text Classification
Transformers
Safetensors
mistral
feature-extraction
reward_model
custom_code
text-generation-inference

Enable flash_attention_2 support since the underlying Mistral model supports it

#3
Files changed (1)
  1. modeling_eurus_rm.py +2 -0
modeling_eurus_rm.py CHANGED
@@ -5,6 +5,8 @@ from typing import Optional, List
 5
 6   class EurusRewardModel(PreTrainedModel):
 7       config_class = MistralConfig
 8  +    _supports_flash_attn_2 = True
 9  +
 10      def __init__(self, config):
 11          super().__init__(config)
 12          self.model = MistralModel(config)