Upload modeling_llama.py

modeling_llama.py  CHANGED  (+128 -72)
@@ -17,7 +17,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+
 import bs4
+import loguru
 import math
 from typing import List, Optional, Tuple, Union
 
@@ -32,7 +35,6 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -50,6 +52,19 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+except ImportError as error:
+    loguru.logger.warning(
+        f"`flash-attention` package not found, consider installing for better performance: {error}."
+    )
+    if not _flash_supports_window_size:
+        loguru.logger.warning(
+            "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
+        )
 from .configuration_llama import LlamaConfig
 from collections import defaultdict
 from typing import List, Tuple
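The hunk above guards the optional flash-attn dependency and records whether the installed build accepts a `window_size` argument. A quick standalone check in the same spirit (a sketch, not part of the uploaded file):

```python
# Sketch: verify that the installed flash-attn build exposes the `window_size`
# parameter that the sliding-window code path in this file relies on.
import inspect

try:
    from flash_attn import flash_attn_func

    supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
    print(f"flash-attn found, window_size supported: {supports_window_size}")
except ImportError:
    print("flash-attn not installed; use attn_implementation='eager' instead")
```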
@@ -97,66 +112,6 @@ class TokenIdNode(Node):
         self.input_ids = kwargs.get('input_ids', [])
         self.prob = kwargs.get('prob', np.float32(0.0))
 
-
-def split_tree(soup: bs4.BeautifulSoup, max_node_words=0) -> List[Tuple[bs4.element.Tag, List[str], bool]]:
-    word_count = len(soup.get_text().split())
-    if word_count > max_node_words:
-        possible_trees = [(soup, [])]
-        target_trees = []  # [(tag, path, is_leaf)]
-        # split the entire dom tee into subtrees, until the length of the subtree is less than max_node_words words
-        # find all possible trees
-        while True:
-            if len(possible_trees) == 0:
-                break
-            tree = possible_trees.pop(0)
-            tag_children = defaultdict(int)
-            bare_word_count = 0
-            # count child tags
-            for child in tree[0].contents:
-                if isinstance(child, bs4.element.Tag):
-                    tag_children[child.name] += 1
-            _tag_children = {k: 0 for k in tag_children.keys()}
-
-            # check if the tree can be split
-            for child in tree[0].contents:
-                if isinstance(child, bs4.element.Tag):
-                    # change child tag with duplicate names
-                    if tag_children[child.name] > 1:
-                        new_name = f"{child.name}{_tag_children[child.name]}"
-                        new_tree = (child, tree[1] + [new_name])
-                        _tag_children[child.name] += 1
-                        child.name = new_name
-                    else:
-                        new_tree = (child, tree[1] + [child.name])
-                    word_count = len(child.get_text().split())
-                    # add node with more than max_node_words words, and recursion depth is less than 64
-                    if word_count > max_node_words and len(new_tree[1]) < 64:
-                        possible_trees.append(new_tree)
-                    else:
-                        target_trees.append((new_tree[0], new_tree[1], True))
-                else:
-                    bare_word_count += len(str(child).split())
-
-            # add leaf node
-            if len(tag_children) == 0:
-                target_trees.append((tree[0], tree[1], True))
-            # add node with more than max_node_words bare words
-            elif bare_word_count > max_node_words:
-                target_trees.append((tree[0], tree[1], False))
-    else:
-        soup_children = [c for c in soup.contents if isinstance(c, bs4.element.Tag)]
-        if len(soup_children) == 1:
-            target_trees = [(soup_children[0], [soup_children[0].name], True)]
-        else:
-            # add an html tag to wrap all children
-            new_soup = bs4.BeautifulSoup("", 'html.parser')
-            new_tag = new_soup.new_tag("html")
-            new_soup.append(new_tag)
-            for child in soup_children:
-                new_tag.append(child)
-            target_trees = [(new_tag, ["html"], True)]
-    return target_trees
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
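With `split_tree` removed from the modeling file, the DOM-splitting step now happens outside the model; its result is supplied through the `block_tree` argument added further down. The structure the model expects is the same one `split_tree` used to return, `List[Tuple[bs4.element.Tag, List[str], bool]]`. A minimal hand-built example of that shape (illustrative only, not part of the upload):

```python
# Illustrative only: a per-document block tree has the same shape the removed
# split_tree() returned: a list of (tag, tag-name path, is_leaf) tuples.
import bs4

soup = bs4.BeautifulSoup("<html><body><p>hello world</p></body></html>", "html.parser")
p_tag = soup.find("p")

block_tree_for_doc = [
    (p_tag, ["html", "body", "p"], True),  # (bs4 Tag, path from the root, is_leaf flag)
]
```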
@@ -517,6 +472,107 @@ class LlamaFlashAttention2(LlamaAttention):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            if not use_sliding_windows:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            if not use_sliding_windows:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+        return attn_output
+
     def forward(
         self,
         hidden_states: torch.Tensor,
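The varlen branch of the method added above describes the unpadded batch with cumulative sequence lengths. A small illustration of that layout (a sketch of the data format only, not the model's actual `_upad_input` helper):

```python
# Sketch: the cu_seqlens layout consumed by flash_attn_varlen_func above.
# Three unpadded sequences of lengths 3, 5 and 2 are concatenated along one
# axis and described by their cumulative offsets [0, 3, 8, 10].
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seq_lens.max())
print(cu_seqlens.tolist(), max_seqlen_in_batch)  # [0, 3, 8, 10] 5
```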
@@ -600,17 +656,16 @@ class LlamaFlashAttention2(LlamaAttention):
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
-        attn_output = _flash_attention_forward(
+
+        attn_output = self._flash_attention_forward(
             query_states,
             key_states,
             value_states,
             attention_mask,
             q_len,
-            position_ids=position_ids,
-            dropout=dropout_rate,
-            sliding_window=getattr(self, "sliding_window", None),
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-            is_causal=self.is_causal,
+            dropout_rate,
+            None,
+            getattr(self, "sliding_window", None),
         )
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
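The replacement call passes its last three arguments positionally; against the signature added above they bind to `dropout`, `softmax_scale`, and `use_sliding_windows`. The equivalent keyword form of that same call, for readability (a sketch, not part of the upload):

```python
# Keyword-argument form of the call in this hunk (sketch only). Note that
# getattr(self, "sliding_window", None) binds to use_sliding_windows, so any
# non-None window size enables the sliding-window branch of the helper.
attn_output = self._flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    q_len,
    dropout=dropout_rate,
    softmax_scale=None,
    use_sliding_windows=getattr(self, "sliding_window", None),
)
```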
@@ -1752,6 +1807,7 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
             tokenizer,
             query: List[str],
             htmls: List[List[str]],
+            block_tree: List[Tuple],
             **kwargs):
         max_seq_length = kwargs.pop("max_seq_length", 131072)
         def apply_html_tree_template(query, htmls):
@@ -1787,11 +1843,11 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
                 soup.append(bs4.BeautifulSoup(html, 'html.parser'))
 
             token_id_paths = []
-            paths = split_tree(soup)
-            is_leaf = [p[2] for p in paths]
-            paths = [p[1] for p in paths]
+            _block_tree = block_tree[idx]
+            is_leaf = [p[2] for p in _block_tree]
+            _block_tree = [p[1] for p in _block_tree]
 
-            for path in paths:
+            for path in _block_tree:
                 path_str = "<" + "><".join(path) + ">"
                 token_ids = tokenizer.encode(path_str, add_special_tokens=False)
                 token_id_paths.append(token_ids)
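For reference, the loop above turns each tag-name path into a pseudo-tag string before tokenization; a tiny worked example:

```python
# Worked example of the path-string encoding used in the loop above.
path = ["html", "body", "p"]
path_str = "<" + "><".join(path) + ">"
assert path_str == "<html><body><p>"
```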
@@ -1849,7 +1905,7 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
 
             res_html_refs.append({
                 "html": str(soup),
-                "paths": paths,
+                "paths": _block_tree,
                 "is_leaf": is_leaf,
                 "path_token_ids": token_id_paths,
                 "node_tree": list(TokenDotExporter(root, nodenamefunc=nodenamefunc))
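Taken together, the last three hunks change the calling convention of `LlamaForHTMLTreeGeneration`: the caller now supplies one precomputed block tree per input document instead of having the model split the DOM itself. A hedged usage sketch follows; the entry-point name `generate_html_tree`, the loaded `model`/`tokenizer` objects, and the example inputs are assumptions, since the diff only shows the parameter list `(tokenizer, query, htmls, block_tree, **kwargs)`:

```python
# Hedged sketch of the new calling convention. `model` and `tokenizer` are
# assumed to be an already-loaded LlamaForHTMLTreeGeneration and its tokenizer,
# and the method name `generate_html_tree` is an assumption.
import bs4

html = "<html><body><p>The Bellagio opened in 1998.</p></body></html>"
soup = bs4.BeautifulSoup(html, "html.parser")

# One block tree per query/document: (tag, tag-name path, is_leaf) tuples,
# the structure previously produced inside the model by split_tree().
block_tree = [[(soup.find("p"), ["html", "body", "p"], True)]]

outputs = model.generate_html_tree(
    tokenizer,
    ["when did the bellagio open?"],
    [[html]],
    block_tree,
    max_seq_length=131072,
)
```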