`FlashAttention is not installed` error on Windows 11
I am using Windows 11 and successfully installed flash-attn, as shown in the picture below. But I still get this `RuntimeError: FlashAttention is not installed` error. Does that mean Windows is not supported if I want to use flash-attention?
It seems like you also need to install other dependencies (namely, triton).
If you look at the rotary.py file, you can see that the `RuntimeError: FlashAttention is not installed` exception is raised when `from flash_attn.ops.triton.rotary import apply_rotary` fails.
That import requires both flash-attention and triton.
So I guess you should also install triton by running `pip install triton`.
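If you want to check this quickly, you can try the exact import that rotary.py attempts. This is just a diagnostic sketch run in the same environment, not part of the model code:

# Quick diagnostic: the model's rotary.py raises the
# "FlashAttention is not installed" RuntimeError whenever this import
# fails, so trying it directly shows whether flash-attn and triton are
# actually usable in the current environment.
try:
    from flash_attn.ops.triton.rotary import apply_rotary  # noqa: F401
    print("flash-attn + triton rotary import OK")
except Exception as exc:  # broken triton/CUDA installs can raise more than ImportError
    print(f"import failed: {exc}")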
The error is as below:
"name": "RuntimeError",
"message": "FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 2
1 print(len(chunks))
----> 2 chunks_embeddings = embedder.encode(chunks, convert_to_tensor=True, batch_size=1)
3 # Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity
4 top_k = min(3, len(chunks))
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\sentence_transformers\\SentenceTransformer.py:623, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, **kwargs)
620 features.update(extra_features)
622 with torch.no_grad():
--> 623 out_features = self.forward(features, **kwargs)
624 if self.device.type == \"hpu\":
625 out_features = copy.deepcopy(out_features)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\sentence_transformers\\SentenceTransformer.py:690, in SentenceTransformer.forward(self, input, **kwargs)
688 module_kwarg_keys = self.module_kwargs.get(module_name, [])
689 module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
--> 690 input = module(input, **module_kwargs)
691 return input
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\jina-embeddings-v3\\30996fea06f69ecd8382ee4f11e29acaf6b5405e\\custom_st.py:143, in Transformer.forward(self, features, task)
139 lora_arguments = (
140 {\"adapter_mask\": adapter_mask} if adapter_mask is not None else {}
141 )
142 features.pop('prompt_length', None)
--> 143 output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
144 output_tokens = output_states[0]
145 features.update({\"token_embeddings\": output_tokens, \"attention_mask\": features[\"attention_mask\"]})
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\modeling_lora.py:370, in XLMRobertaLoRA.forward(self, *args, **kwargs)
369 def forward(self, *args, **kwargs):
--> 370 return self.roberta(*args, **kwargs)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\modeling_xlm_roberta.py:709, in XLMRobertaModel.forward(self, input_ids, position_ids, token_type_ids, attention_mask, masked_tokens_mask, return_dict, **kwargs)
706 else:
707 subset_mask = None
--> 709 sequence_output = self.encoder(
710 hidden_states,
711 key_padding_mask=attention_mask,
712 subset_mask=subset_mask,
713 adapter_mask=adapter_mask,
714 )
716 if masked_tokens_mask is None:
717 pooled_output = (
718 self.pooler(sequence_output, adapter_mask=adapter_mask)
719 if self.pooler is not None
720 else None
721 )
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\modeling_xlm_roberta.py:241, in XLMRobertaEncoder.forward(self, hidden_states, key_padding_mask, subset_mask, adapter_mask)
234 hidden_states = torch.utils.checkpoint.checkpoint(
235 layer,
236 hidden_states,
237 use_reentrant=self.use_reentrant,
238 mixer_kwargs=mixer_kwargs,
239 )
240 else:
--> 241 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
242 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
243 else:
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\block.py:201, in Block.forward(self, hidden_states, residual, mixer_subset, mixer_kwargs)
199 else:
200 assert residual is None
--> 201 mixer_out = self.mixer(
202 hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
203 )
204 if self.return_residual: # mixer out is actually a pair here
205 mixer_out, hidden_states = mixer_out
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\mha.py:732, in MHA.forward(self, x, x_kv, key_padding_mask, cu_seqlens, max_seqlen, mixer_subset, inference_params, adapter_mask, **kwargs)
725 if (
726 inference_params is None
727 or inference_params.seqlen_offset == 0
728 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
729 or not self.use_flash_attn
730 ):
731 if self.rotary_emb_dim > 0:
--> 732 qkv = self.rotary_emb(
733 qkv,
734 seqlen_offset=seqlen_offset,
735 cu_seqlens=cu_seqlens,
736 max_seqlen=rotary_max_seqlen,
737 )
738 if inference_params is None:
739 if not self.checkpointing:
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:604, in RotaryEmbedding.forward(self, qkv, kv, seqlen_offset, cu_seqlens, max_seqlen)
602 if kv is None:
603 if self.scale is None:
--> 604 return apply_rotary_emb_qkv_(
605 qkv,
606 self._cos_cached,
607 self._sin_cached,
608 interleaved=self.interleaved,
609 seqlen_offsets=seqlen_offset,
610 cu_seqlens=cu_seqlens,
611 max_seqlen=max_seqlen,
612 use_flash_attn=self.use_flash_attn,
613 )
614 else:
615 return apply_rotary_emb_qkv_(
616 qkv,
617 self._cos_cached,
(...)
625 use_flash_attn=self.use_flash_attn,
626 )
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:327, in apply_rotary_emb_qkv_(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
297 def apply_rotary_emb_qkv_(
298 qkv,
299 cos,
(...)
307 use_flash_attn=True,
308 ):
309 \"\"\"
310 Arguments:
311 qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
(...)
325 Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
326 \"\"\"
--> 327 return ApplyRotaryEmbQKV_.apply(
328 qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
329 )
File d:\\conda\\win\\envs\\prod\\Lib\\site-packages\\torch\\autograd\\function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 \"In order to use an autograd.Function with functorch transforms \"
580 \"(vmap, grad, jvp, jacrev, ...), it must override the setup_context \"
581 \"staticmethod. For more details, please see \"
582 \"https://pytorch.org/docs/main/notes/extending.func.html\"
583 )
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:186, in ApplyRotaryEmbQKV_.forward(ctx, qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
184 qk = rearrange(qkv[..., :2, :, :], \"... t h d -> ... (t h) d\")
185 # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
--> 186 apply_rotary(
187 qk,
188 cos,
189 sin,
190 seqlen_offsets=seqlen_offsets,
191 interleaved=interleaved,
192 inplace=True,
193 cu_seqlens=cu_seqlens,
194 max_seqlen=max_seqlen,
195 )
196 else:
197 q_rot = apply_rotary_emb_torch(
198 qkv[:, :, 0],
199 cos,
200 sin,
201 interleaved=interleaved,
202 )
File ~\\.cache\\huggingface\\modules\\transformers_modules\\jinaai\\xlm-roberta-flash-implementation\\9dc60336f6b2df56c4f094dd287ca49fb7b93342\\rotary.py:18, in apply_rotary(*args, **kwargs)
17 def apply_rotary(*args, **kwargs):
---> 18 raise RuntimeError(
19 \"FlashAttention is not installed. To proceed with training, please install FlashAttention. \"
20 \"For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.\"
21 )
RuntimeError: FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.
Hi @ocean11, as far as I know flash-attention is not fully supported on Windows, which is why you're seeing this issue. You can disable it by setting `use_flash_attn=False` when loading the model.
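For what it's worth, when loading the model directly with transformers the flag can be passed to from_pretrained. A minimal sketch, assuming the custom jina remote code accepts use_flash_attn as a config override (which the error message suggests):

# Minimal sketch: disable FlashAttention when loading directly with
# transformers. Assumes use_flash_attn is accepted as a config override
# by the custom jina code; trust_remote_code is required for it.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    use_flash_attn=False,
)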
How do I set this in sentence-transformers? There doesn't seem to be a place to pass `use_flash_attn=False`.
Previously, when I didn't install flash-attention, the model just printed a few warning lines about not using it.
import torch
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)

chunks = [...]
queries = [...]

# Use "convert_to_tensor=True" to keep the tensors on GPU (if available)
chunks_embeddings = embedder.encode(chunks, convert_to_tensor=True, batch_size=1)
# Find the closest 3 sentences of the corpus for each query sentence based on cosine similarity
top_k = min(3, len(chunks))
for query in queries:
    query_embedding = embedder.encode(query, convert_to_tensor=True)
    # We use cosine-similarity and torch.topk to find the top_k highest scores
    similarity_scores = embedder.similarity(query_embedding, chunks_embeddings)[0]
    scores, indices = torch.topk(similarity_scores, k=top_k)
    print("\n\033[91m--- Query:", query, "---\033[0m")
    print("Top 3 most similar sentences in corpus:")
    ...
For SentenceTransformers you can set it like this:
embedder = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True, model_kwargs={'use_flash_attn': False})
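As far as I understand, model_kwargs is forwarded to the underlying transformers from_pretrained call, which is where the custom jina code reads use_flash_attn. A minimal self-contained sketch of the full flow with the flag disabled (the example sentences here are made up for illustration):

import torch
from sentence_transformers import SentenceTransformer

# model_kwargs is passed through to the transformers from_pretrained call,
# so use_flash_attn reaches the custom jina implementation.
embedder = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
    model_kwargs={"use_flash_attn": False},
)

chunks = ["FlashAttention requires a CUDA build.", "Triton compiles GPU kernels."]
query = "What does FlashAttention require?"

chunks_embeddings = embedder.encode(chunks, convert_to_tensor=True)
query_embedding = embedder.encode(query, convert_to_tensor=True)

# Cosine similarity followed by top-k over the corpus.
similarity_scores = embedder.similarity(query_embedding, chunks_embeddings)[0]
scores, indices = torch.topk(similarity_scores, k=min(3, len(chunks)))
print(scores, indices)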
Problem solved as suggested. Thanks a lot!