Andrei Panferov committed
Commit: f48478c
Parent(s): 03ea233

try except flash-attn

Files changed:
- config.json (+2 -2)
- modeling_llama_aqlm.py (+6 -3)
config.json CHANGED
@@ -24,6 +24,7 @@
   "tf_legacy_loss": false,
   "pruned_heads": {},
   "tie_word_embeddings": false,
+  "chunk_size_feed_forward": 0,
   "is_encoder_decoder": false,
   "is_decoder": false,
   "cross_attention_hidden_size": null,
@@ -46,7 +47,6 @@
   "encoder_no_repeat_ngram_size": 0,
   "bad_words_ids": null,
   "num_return_sequences": 1,
-  "chunk_size_feed_forward": 0,
   "output_scores": false,
   "return_dict_in_generate": false,
   "forced_bos_token_id": null,
@@ -77,7 +77,7 @@
   "task_specific_params": null,
   "problem_type": null,
   "_name_or_path": "",
-  "transformers_version": "4.
+  "transformers_version": "4.37.1",
   "aqlm": {
     "nbits_per_codebook": 16,
     "num_codebooks": 1,
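For reference, a checkpoint that ships this config.json together with the custom modeling_llama_aqlm.py is loaded through transformers' remote-code path. A minimal sketch, assuming a placeholder repo id and that the aqlm package is installed:

# Minimal sketch (hypothetical repo id, not from this commit).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "org/llama-aqlm-checkpoint",  # placeholder; substitute the real repo id
    trust_remote_code=True,       # required so modeling_llama_aqlm.py is used
)

trust_remote_code=True is what makes transformers execute the repo's own modeling file, which in turn imports QuantizedLinear from the aqlm package.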
modeling_llama_aqlm.py CHANGED
@@ -25,6 +25,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from aqlm import QuantizedLinear
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
@@ -53,11 +54,13 @@ from transformers.utils import (
 from transformers.utils.import_utils import is_torch_fx_available
 
 from .configuration_llama_aqlm import LlamaConfig
-from aqlm import QuantizedLinear
 
 if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    try:
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+    except:
+        pass
 
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
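The point of the guard is presumably that is_flash_attn_2_available() probes package metadata, while actually importing flash_attn can still fail at runtime (for example, an undefined-symbol error from a torch/CUDA build mismatch). A minimal, self-contained sketch of the same guarded-import pattern; the flag and helper below are illustrative names, not from the repo:

# Minimal sketch of the guarded-import pattern this commit applies.
_flash_attn_ok = False
try:
    # The import can raise even when the package is installed, e.g. an
    # ImportError or an undefined-symbol OSError from a torch/CUDA mismatch.
    from flash_attn import flash_attn_func  # noqa: F401

    _flash_attn_ok = True
except Exception:
    # Degrade gracefully: callers check the flag and fall back to a
    # non-flash attention path.
    pass


def pick_attn_implementation() -> str:
    """Illustrative helper: choose an attention backend from the probe."""
    return "flash_attention_2" if _flash_attn_ok else "eager"

The commit itself uses a bare except:, which also swallows KeyboardInterrupt and SystemExit; except Exception, as in the sketch, is the slightly narrower variant.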