Team Finetuner committed on
Commit 344bcbc
1 Parent(s): 5ee2c37

chore: update from afe81ca705ca1a5bd6b7d90548fcac068850b2af

Files changed (2)
  1. configuration_bert.py +1 -7
  2. modeling_bert.py +0 -3
configuration_bert.py CHANGED
@@ -84,14 +84,10 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler (`str`, *optional*, defaults to `None`):
             The function to use for pooling the last layer embeddings to get the sentence embeddings.
             Should be one of `None`, `"mean"`.
-        with_flash (`bool`, *optional*, defaults to `False`):
-            Whether to use triton flash attention. Only works for `triton==2.0.0.dev20230208`.
-            This argument will be deprecated in the future. Use `attention_implementation` instead.
-        attn_implementation (`str`, *optional*, defaults to `None`):
+        attn_implementation (`str`, *optional*, defaults to `"torch"`):
             The implementation of the self-attention layer. Can be one of:
             - `None` for the original implementation,
             - `torch` for the PyTorch SDPA implementation,
-            - `triton` for the Triton Flash implementation. Only works for `triton==2.0.0.dev20230208`
 
     Examples:
 
@@ -132,7 +128,6 @@ class JinaBertConfig(PretrainedConfig):
         classifier_dropout=None,
         feed_forward_type="original",
         emb_pooler=None,
-        with_flash=False,
         attn_implementation='torch',
         **kwargs,
     ):
@@ -156,7 +151,6 @@ class JinaBertConfig(PretrainedConfig):
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
         self.attn_implementation = attn_implementation
-        self.with_flash = with_flash
 
 class JinaBertOnnxConfig(OnnxConfig):
     @property
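With this change, callers select the attention backend through the `attn_implementation` argument instead of the removed `with_flash` flag. A minimal sketch of the new usage, assuming the config is loaded from a Jina BERT checkpoint via `transformers` (the repository id below is illustrative, not confirmed by this commit):

import torch  # noqa: F401 (only needed once a model is actually run)
from transformers import AutoConfig, AutoModel

# Load the custom JinaBertConfig shipped with the model repo;
# trust_remote_code pulls in configuration_bert.py / modeling_bert.py.
config = AutoConfig.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",  # illustrative checkpoint id
    trust_remote_code=True,
)

# Previously: config.with_flash = True  (triton flash attention)
# Now: pick the backend explicitly; 'torch' uses PyTorch SDPA, None the original path.
config.attn_implementation = "torch"

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",  # illustrative checkpoint id
    config=config,
    trust_remote_code=True,
)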
modeling_bert.py CHANGED
@@ -273,9 +273,6 @@ class JinaBertSelfAttention(nn.Module):
         )
 
         self.attn_implementation = config.attn_implementation
-        if config.with_flash:
-            self.attn_implementation = 'triton'
-
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
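For orientation only: the `self.attn_implementation` attribute set above is the value the attention forward pass can branch on. The helper below is a hypothetical sketch of that kind of dispatch, not the repository's actual `JinaBertSelfAttention.forward`; it assumes `'torch'` maps to `torch.nn.functional.scaled_dot_product_attention` and `None` to the explicit softmax attention.

import torch
import torch.nn.functional as F

def attend(query, key, value, attn_implementation, attention_mask=None):
    # Hypothetical dispatch helper; names and structure are assumptions.
    if attn_implementation == 'torch':
        # PyTorch SDPA backend: uses fused attention kernels where available.
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
    # attn_implementation is None: the explicit "original" attention computation.
    scores = torch.matmul(query, key.transpose(-1, -2)) / (query.size(-1) ** 0.5)
    if attention_mask is not None:
        scores = scores + attention_mask
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value)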