aria-dev commited on
Commit
f5e1f0d
1 Parent(s): 006dc0f

support eager attention

Browse files
Files changed (3) hide show
  1. configuration_aria.py +13 -0
  2. modeling_aria.py +0 -1
  3. vision_encoder.py +1 -0
configuration_aria.py CHANGED
@@ -68,6 +68,8 @@ class AriaConfig(PretrainedConfig):
68
  self.ignore_index = ignore_index
69
  self.image_token_index = image_token_index
70
 
 
 
71
  # Convert the keys and values of projector_patch_to_query_dict to integers
72
  # This ensures consistency even if they were provided as strings
73
  self.projector_patch_to_query_dict = {
@@ -76,10 +78,21 @@ class AriaConfig(PretrainedConfig):
76
 
77
  if isinstance(vision_config, dict) and "model_type" in vision_config:
78
  vision_config = AriaVisionConfig(**vision_config)
 
 
 
 
 
 
79
 
80
  self.vision_config = vision_config
81
 
82
  if isinstance(text_config, dict) and "model_type" in text_config:
 
 
 
83
  text_config = AriaMoELMConfig(**text_config)
 
 
84
 
85
  self.text_config = text_config
 
68
  self.ignore_index = ignore_index
69
  self.image_token_index = image_token_index
70
 
71
+ attn_implementation = kwargs.pop("attn_implementation", None)
72
+
73
  # Convert the keys and values of projector_patch_to_query_dict to integers
74
  # This ensures consistency even if they were provided as strings
75
  self.projector_patch_to_query_dict = {
 
78
 
79
  if isinstance(vision_config, dict) and "model_type" in vision_config:
80
  vision_config = AriaVisionConfig(**vision_config)
81
+ vision_attn_implementation = (
82
+ "flash_attention_2"
83
+ if attn_implementation is None
84
+ else attn_implementation
85
+ )
86
+ vision_config._attn_implementation = vision_attn_implementation
87
 
88
  self.vision_config = vision_config
89
 
90
  if isinstance(text_config, dict) and "model_type" in text_config:
91
+ text_attn_implementation = (
92
+ "sdpa" if attn_implementation is None else attn_implementation
93
+ )
94
  text_config = AriaMoELMConfig(**text_config)
95
+ text_config._attn_implementation = text_attn_implementation
96
+ print(text_config._attn_implementation)
97
 
98
  self.text_config = text_config
modeling_aria.py CHANGED
@@ -133,7 +133,6 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
133
  def __init__(self, config: AriaConfig):
134
  super().__init__(config)
135
 
136
- config.vision_config._attn_implementation = config._attn_implementation
137
  self.vision_tower = AriaVisionModel(config.vision_config)
138
  self.multi_modal_projector = build_mm_projector(config)
139
  self.vocab_size = config.text_config.vocab_size
 
133
  def __init__(self, config: AriaConfig):
134
  super().__init__(config)
135
 
 
136
  self.vision_tower = AriaVisionModel(config.vision_config)
137
  self.multi_modal_projector = build_mm_projector(config)
138
  self.vocab_size = config.text_config.vocab_size
vision_encoder.py CHANGED
@@ -82,6 +82,7 @@ class AriaVisionModel(SiglipVisionModel):
82
 
83
  config_class = AriaVisionConfig
84
  main_input_name = "pixel_values"
 
85
 
86
  def __init__(self, config: AriaVisionConfig):
87
  super().__init__(config)
 
82
 
83
  config_class = AriaVisionConfig
84
  main_input_name = "pixel_values"
85
+ _supports_sdpa = False
86
 
87
  def __init__(self, config: AriaVisionConfig):
88
  super().__init__(config)