zstanjj committed
Commit: b36d18c
Parent: 8172f3b

Upload modeling_llama.py

Files changed (1): modeling_llama.py (+128, -72)
modeling_llama.py CHANGED
@@ -17,7 +17,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+
 import bs4
+import loguru
 import math
 from typing import List, Optional, Tuple, Union
 
@@ -32,7 +35,6 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -50,6 +52,19 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+except ImportError as error:
+    loguru.logger.warning(
+        f"`flash-attention` package not found, consider installing for better performance: {error}."
+    )
+    if not _flash_supports_window_size:
+        loguru.logger.warning(
+            "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
+        )
 from .configuration_llama import LlamaConfig
 from collections import defaultdict
 from typing import List, Tuple
@@ -97,66 +112,6 @@ class TokenIdNode(Node):
         self.input_ids = kwargs.get('input_ids', [])
         self.prob = kwargs.get('prob', np.float32(0.0))
 
-
-def split_tree(soup: bs4.BeautifulSoup, max_node_words=0) -> List[Tuple[bs4.element.Tag, List[str], bool]]:
-    word_count = len(soup.get_text().split())
-    if word_count > max_node_words:
-        possible_trees = [(soup, [])]
-        target_trees = []  # [(tag, path, is_leaf)]
-        # split the entire dom tree into subtrees, until the length of the subtree is less than max_node_words words
-        # find all possible trees
-        while True:
-            if len(possible_trees) == 0:
-                break
-            tree = possible_trees.pop(0)
-            tag_children = defaultdict(int)
-            bare_word_count = 0
-            # count child tags
-            for child in tree[0].contents:
-                if isinstance(child, bs4.element.Tag):
-                    tag_children[child.name] += 1
-            _tag_children = {k: 0 for k in tag_children.keys()}
-
-            # check if the tree can be split
-            for child in tree[0].contents:
-                if isinstance(child, bs4.element.Tag):
-                    # change child tag with duplicate names
-                    if tag_children[child.name] > 1:
-                        new_name = f"{child.name}{_tag_children[child.name]}"
-                        new_tree = (child, tree[1] + [new_name])
-                        _tag_children[child.name] += 1
-                        child.name = new_name
-                    else:
-                        new_tree = (child, tree[1] + [child.name])
-                    word_count = len(child.get_text().split())
-                    # add node with more than max_node_words words, and recursion depth is less than 64
-                    if word_count > max_node_words and len(new_tree[1]) < 64:
-                        possible_trees.append(new_tree)
-                    else:
-                        target_trees.append((new_tree[0], new_tree[1], True))
-                else:
-                    bare_word_count += len(str(child).split())
-
-            # add leaf node
-            if len(tag_children) == 0:
-                target_trees.append((tree[0], tree[1], True))
-            # add node with more than max_node_words bare words
-            elif bare_word_count > max_node_words:
-                target_trees.append((tree[0], tree[1], False))
-    else:
-        soup_children = [c for c in soup.contents if isinstance(c, bs4.element.Tag)]
-        if len(soup_children) == 1:
-            target_trees = [(soup_children[0], [soup_children[0].name], True)]
-        else:
-            # add an html tag to wrap all children
-            new_soup = bs4.BeautifulSoup("", 'html.parser')
-            new_tag = new_soup.new_tag("html")
-            new_soup.append(new_tag)
-            for child in soup_children:
-                new_tag.append(child)
-            target_trees = [(new_tag, ["html"], True)]
-    return target_trees
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
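
Note: this commit removes split_tree from the modeling file, so the block split now has to be computed by the caller. A minimal usage sketch of the removed helper (the example HTML and the tiny max_node_words value are illustrative, and split_tree must still be importable from wherever it now lives, e.g. a copy of the block above):

    import bs4

    html = "<html><body><p>first paragraph here</p><p>second paragraph here</p></body></html>"
    soup = bs4.BeautifulSoup(html, "html.parser")
    # max_node_words=2 forces a split; duplicate <p> tags get renamed p0, p1 in place.
    for tag, path, is_leaf in split_tree(soup, max_node_words=2):
        print(path, is_leaf, len(tag.get_text().split()))
    # -> ['html', 'body', 'p0'] True 3
    #    ['html', 'body', 'p1'] True 3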
@@ -517,6 +472,107 @@ class LlamaFlashAttention2(LlamaAttention):
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            if not use_sliding_windows:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            if not use_sliding_windows:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+        return attn_output
+
     def forward(
         self,
         hidden_states: torch.Tensor,
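
Note: when no padding mask is present, the method above falls through to flash_attn_func directly. A minimal shape sketch of that call (illustrative sizes; requires the flash-attn package and a CUDA device):

    import torch
    from flash_attn import flash_attn_func

    bsz, q_len, n_heads, head_dim = 2, 16, 8, 64
    # flash-attn expects (batch, seq_len, n_heads, head_dim) tensors in fp16/bf16 on GPU.
    q = torch.randn(bsz, q_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)
    # dropout_p is passed positionally, matching the call in _flash_attention_forward.
    out = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True)
    assert out.shape == q.shape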
@@ -600,17 +656,16 @@ class LlamaFlashAttention2(LlamaAttention):
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
-        attn_output = _flash_attention_forward(
+
+        attn_output = self._flash_attention_forward(
             query_states,
             key_states,
             value_states,
             attention_mask,
             q_len,
-            position_ids=position_ids,
-            dropout=dropout_rate,
-            sliding_window=getattr(self, "sliding_window", None),
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-            is_causal=self.is_causal,
+            dropout_rate,
+            None,
+            getattr(self, "sliding_window", None),
         )
 
         attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -1752,6 +1807,7 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
                              tokenizer,
                              query: List[str],
                              htmls: List[List[str]],
+                             block_tree: List[Tuple],
                              **kwargs):
         max_seq_length = kwargs.pop("max_seq_length", 131072)
         def apply_html_tree_template(query, htmls):
@@ -1787,11 +1843,11 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
                 soup.append(bs4.BeautifulSoup(html, 'html.parser'))
 
             token_id_paths = []
-            html_chunk_paths = split_tree(soup, max_node_words=self.max_node_words)
-            is_leaf = [p[2] for p in html_chunk_paths]
-            html_chunk_paths = [p[1] for p in html_chunk_paths]
+            _block_tree = block_tree[idx]
+            is_leaf = [p[2] for p in _block_tree]
+            _block_tree = [p[1] for p in _block_tree]
 
-            for path in html_chunk_paths:
+            for path in _block_tree:
                 path_str = "<" + "><".join(path) + ">"
                 token_ids = tokenizer.encode(path_str, add_special_tokens=False)
                 token_id_paths.append(token_ids)
@@ -1849,7 +1905,7 @@ class LlamaForHTMLTreeGeneration(LlamaPreTrainedModel):
 
             res_html_refs.append({
                 "html": str(soup),
-                "paths": html_chunk_paths,
+                "paths": _block_tree,
                 "is_leaf": is_leaf,
                 "path_token_ids": token_id_paths,
                 "node_tree": list(TokenDotExporter(root, nodenamefunc=nodenamefunc))
 