Andrei Panferov committed
Commit f48478c
Parent: 03ea233

try except flash-attn

Files changed (2):
  1. config.json +2 -2
  2. modeling_llama_aqlm.py +6 -3
config.json CHANGED
@@ -24,6 +24,7 @@
   "tf_legacy_loss": false,
   "pruned_heads": {},
   "tie_word_embeddings": false,
+  "chunk_size_feed_forward": 0,
   "is_encoder_decoder": false,
   "is_decoder": false,
   "cross_attention_hidden_size": null,
@@ -46,7 +47,6 @@
   "encoder_no_repeat_ngram_size": 0,
   "bad_words_ids": null,
   "num_return_sequences": 1,
-  "chunk_size_feed_forward": 0,
   "output_scores": false,
   "return_dict_in_generate": false,
   "forced_bos_token_id": null,
@@ -77,7 +77,7 @@
   "task_specific_params": null,
   "problem_type": null,
   "_name_or_path": "",
-  "transformers_version": "4.36.2",
+  "transformers_version": "4.37.1",
   "aqlm": {
     "nbits_per_codebook": 16,
     "num_codebooks": 1,
modeling_llama_aqlm.py CHANGED
@@ -25,6 +25,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from aqlm import QuantizedLinear
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
@@ -53,11 +54,13 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available
 
 from .configuration_llama_aqlm import LlamaConfig
-from aqlm import QuantizedLinear
 
 if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    try:
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    except:
+        pass
 
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
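
For context, the second hunk guards the flash-attn imports so that a flash-attn installation that is present but broken (for example, a wheel built against a different torch/CUDA) no longer prevents the model code from being imported. A minimal standalone sketch of the same guarded-import pattern follows; the `_flash_attn_available` flag is illustrative and not part of the original file, and the bare `except:` of the commit is narrowed here to `ImportError`:

    # Sketch of the guarded flash-attn import; names marked below are illustrative.
    from transformers.utils import is_flash_attn_2_available

    _flash_attn_available = False  # illustrative flag, not in the original file
    if is_flash_attn_2_available():
        try:
            # flash-attn can be installed yet still fail at import time,
            # so the import itself is wrapped as well.
            from flash_attn import flash_attn_func, flash_attn_varlen_func
            from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
            _flash_attn_available = True
        except ImportError:
            # Fall back silently; downstream attention code then uses the non-flash path.
            pass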