Florian valade committed on
Commit bda5ea2 · 1 Parent(s): 950b367

Refactor to use the HF Hub and a better design

Files changed (5)
  1. .gitignore +1 -1
  2. app.py +108 -60
  3. requirements.txt +2 -1
  4. src/BranchyModel.py +0 -469
  5. src/utils.py +0 -57
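
The main change: the app no longer ships its own BranchyModel implementation (src/BranchyModel.py and src/utils.py are deleted) and instead pulls the packaged model straight from the Hugging Face Hub. A minimal sketch of the new loading path, mirroring load_model in the updated app.py (the repo ids and trust_remote_code flag are taken from the diff below):

from transformers import AutoModelForCausalLM, AutoTokenizer

# The custom branchy architecture now lives in the Hub repo, hence trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained("valcore/Branchy-Phi-2", trust_remote_code=True)
model.eval()  # inference only; the exit heads are already trained
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-2")
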
.gitignore CHANGED
@@ -1 +1 @@
1
- model/*
 
1
+ __pycache__
app.py CHANGED
@@ -1,75 +1,123 @@
1
  # Save this as app.py and run with `streamlit run app.py`
 
2
  import streamlit as st
3
  import torch
4
  import pandas as pd
5
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
-
8
- from src.utils import generate_next_token, breaking_ties
9
- from src.BranchyModel import BranchyModel
10
 
11
  st.title("Multi-Head LLM Demo")
12
-
13
- def add_and_run(token, head):
14
- # Update pd with Head and mean of previous heads and actual head
15
- head_list = st.session_state["computation_pd"]["Head"].to_list() + [head]
16
- mean = sum(head_list) / len(head_list)
17
- st.session_state["computation_pd"] = pd.concat([st.session_state["computation_pd"], pd.DataFrame({"Head": [head], "Mean": [mean], "Base model consumption": [st.session_state['head_number']]})], ignore_index=True)
18
-
19
- st.session_state['current_sentence'] += token
20
- _, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
21
-
22
- def reset():
23
- st.session_state['computation_pd'] = pd.DataFrame(columns=["Head", "Mean", "Base model consumption"])
24
- st.session_state['current_sentence'] = "The climate in"
25
- _, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
26
 
27
  @st.cache_resource
28
- def load_model(model_path):
29
-
30
- model_str = "susnato/phi-1_5_dev"
31
- model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
32
- tokenizer = AutoTokenizer.from_pretrained(model_str)
33
-
34
- branch_locations = list(range(0, 23, 5))
35
- model = BranchyModel(branch_locations= branch_locations, model= model).to("cuda:1")
36
-
37
- # Load the specific model
38
- model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
39
-
40
  return model, tokenizer
41
 
 
 
42
 
43
  if "model" not in st.session_state or "tokenizer" not in st.session_state:
44
  print("Loading model...")
45
- st.session_state.model, st.session_state.tokenizer = load_model("model/model.bin")
46
- st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
47
- print(f"Head number: {st.session_state['head_number']}")
48
- # Session state to store the current sentence
49
- if 'current_sentence' not in st.session_state:
50
- reset()
51
-
52
- # Create a container to hold the buttons
53
- cols = st.columns(len(st.session_state.head_tokens)) # Create a column for each token
54
-
55
- # Iterate through each head token and create a button in a separate column
56
- for i, (col, token) in enumerate(zip(cols, st.session_state.head_tokens)):
57
- col.button(f"{st.session_state['head_tokens'][i]}",
58
- key=f"head_{i}",
59
- use_container_width=True,
60
- on_click=add_and_run,
61
- args=(st.session_state['head_tokens'][i], i))
62
-
63
-
64
- # Display the current sentence
65
- st.markdown(f"{st.session_state['current_sentence']}")
66
-
67
- # Reset button to start over
68
- st.button('Reset', on_click=reset)
69
-
70
- if 'computation_pd' in st.session_state:
71
- st.line_chart(st.session_state['computation_pd'])
72
- # get last element from a pd
73
- saved_budget = 100 - ((st.session_state["computation_pd"]["Mean"].iloc[-1] * 100) / st.session_state["computation_pd"]["Base model consumption"].iloc[-1])
74
- st.markdown(f"You saved **{saved_budget:.2f}%** of the base model consumption.")
75
- #st.write(st.session_state['computation_pd'])
 
1
  # Save this as app.py and run with `streamlit run app.py`
2
+ import time
3
  import streamlit as st
4
  import torch
5
  import pandas as pd
6
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from annotated_text import annotated_text
 
10
 
11
  st.title("Multi-Head LLM Demo")
12
+ st.markdown("""This is a demo of a multi-head language model with early exit capabilities.
13
+ The model is based on the Phi-2 architecture and is available here: https://huggingface.co/valcore/Branchy-Phi-2.
14
+ \nThe model has four heads, each of which can exit early based on a threshold. The graph shows the depth of early exit for each token (deeper meaning faster) and the time taken to generate each token.
15
+ Early-exited tokens are annotated with the depth of the early exit (a float of at most 1, 1 being the deepest).
16
+ """)
17
+
18
+ def annotated_to_normal(text):
19
+ result = ""
20
+ for elem in text:
21
+ if isinstance(elem, tuple):
22
+ result += elem[0]
23
+ else:
24
+ result += elem
25
+ return result
26
+
27
+ def generate_next_token():
28
+ print(f"Generating next token from {st.session_state.messages}")
29
+ inputs = ""
30
+ for message in st.session_state.messages:
31
+ inputs += message["role"] + ": " + annotated_to_normal(message["content"]) + "\n"
32
+ inputs += "Assistant:"
33
+ print(f"Inputs: {inputs}")
34
+ inputs = st.session_state.tokenizer.encode(inputs, return_tensors="pt")
35
+ for i in range(50):
36
+ start = time.time()
37
+ outputs = st.session_state.model(inputs)
38
+ stop = time.time()
39
+ next_token_logits = outputs.logits[:, -1, :].squeeze()
40
+ next_token_probs = torch.softmax(next_token_logits, dim=-1)
41
+ next_token_id = torch.argmax(next_token_probs, dim=-1)
42
+ if next_token_id == 50256:  # 50256 is the <|endoftext|> (EOS) token id of the Phi-2 tokenizer
43
+ break
44
+ print(inputs.shape, next_token_id.shape)
45
+ inputs = torch.cat([inputs, next_token_id.unsqueeze(0).unsqueeze(-1)], dim=-1)
46
+ next_token = st.session_state.tokenizer.decode(next_token_id)
47
+ time_taken = (stop - start) * 1000  # convert to milliseconds to match the chart label
48
+ branch_locations = st.session_state.model.config.branch_locations
49
+ print(outputs.head_indices)
50
+ if outputs.head_indices in branch_locations:
51
+ print(sorted(branch_locations, reverse=True))
52
+ early_exit = (branch_locations.index(outputs.head_indices) + 1) / len(branch_locations)
53
+ else:
54
+ early_exit = 0
55
+ # Add data to dataframe
56
+ new_row = pd.DataFrame({"Time taken (in ms)": [time_taken], "Early exit depth": [early_exit]})
57
+ st.session_state.data = pd.concat([st.session_state.data, new_row], ignore_index=True)
58
+ yield next_token, early_exit
59
 
60
  @st.cache_resource
61
+ def load_model(model_str, tokenizer_str):
62
+ model = AutoModelForCausalLM.from_pretrained(model_str, trust_remote_code=True)
63
+ model.eval()
64
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)
 
65
  return model, tokenizer
66
 
67
+ model_str = "valcore/Branchy-Phi-2"
68
+ tokenizer_str = "microsoft/Phi-2"
69
 
70
  if "model" not in st.session_state or "tokenizer" not in st.session_state:
71
  print("Loading model...")
72
+ st.session_state.model, st.session_state.tokenizer = load_model(model_str, tokenizer_str)
73
+
74
+ # Initialize chat history and dataframe
75
+ if "messages" not in st.session_state:
76
+ st.session_state.messages = []
77
+ st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
78
+
79
+ col1, col2 = st.columns([1, 4])
80
+
81
+ with col1:
82
+ early_exit = st.checkbox("Early exit", value=False)
83
+ if early_exit:
84
+ st.session_state.model.head_thresholds = [2.506962537765503, 2.656052589416504, 1.924393653869629, 1.4434680938720703]
85
+ else:
86
+ st.session_state.model.head_thresholds = [10., 10., 10., 10.]
87
+ clear_session = st.button("Clear session")
88
+ if clear_session:
89
+ print("Clearing session")
90
+ st.session_state.messages = []
91
+ st.session_state.data = pd.DataFrame(columns=["Time taken (in ms)", "Early exit depth"])
92
+
93
+ with col2:
94
+ # Display chat messages from history on app rerun
95
+ for message in st.session_state.messages:
96
+ with st.chat_message(message["role"]):
97
+ annotated_text(message["content"])
98
+
99
+ prompt = st.chat_input("What is up?")
100
+ # React to user input
101
+ if prompt:
102
+ # Display user message in chat message container
103
+ with st.chat_message("User"):
104
+ st.markdown(prompt)
105
+ # Add user message to chat history
106
+ st.session_state.messages.append({"role": "User", "content": prompt})
107
+
108
+ # Display assistant response in chat message container
109
+ with st.chat_message("Assistant"):
110
+ response = []
111
+ with st.spinner('Running inference...'):
112
+ for next_token, early_exit in generate_next_token():
113
+ if early_exit > 0.0:
114
+ response.append(tuple((next_token, str(early_exit))))
115
+ else:
116
+ response.append(next_token)
117
+ print(response)
118
+ annotated_text(response)
119
+
120
+ # Add assistant response to chat history
121
+ st.session_state.messages.append({"role": "Assistant", "content": response})
122
+ st.line_chart(st.session_state.data, x=None, y=["Time taken (in ms)", "Early exit depth"])
123
+ print(st.session_state.messages)
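
For readers following the chart and annotation logic above: the depth value attached to each token comes from mapping the head that produced it onto the list of branch locations. A small self-contained sketch of that mapping, assuming (as in generate_next_token above) that outputs.head_indices is the layer index the token exited at and config.branch_locations lists the branch layers; the example locations below are illustrative, not the actual Branchy-Phi-2 configuration:

def exit_depth(head_index, branch_locations):
    # Fraction in (0, 1]: later branches give values closer to 1; 0 means the token
    # ran through the full model and was produced by the final lm_head.
    if head_index in branch_locations:
        return (branch_locations.index(head_index) + 1) / len(branch_locations)
    return 0.0

print(exit_depth(10, [5, 10, 15, 20]))  # 0.5 -> shown as ("token", "0.5") by annotated_text
print(exit_depth(31, [5, 10, 15, 20]))  # 0.0 -> rendered as plain text
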
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  streamlit==1.31.0
2
  torch==2.0.1
3
  pandas==2.0.3
4
- transformers==4.36.0
 
 
1
  streamlit==1.31.0
2
  torch==2.0.1
3
  pandas==2.0.3
4
+ transformers==4.36.0
5
+ st-annotated-text
src/BranchyModel.py DELETED
@@ -1,469 +0,0 @@
1
- from typing import Dict, List, Optional, Tuple
2
- from dataclasses import dataclass
3
-
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
- from transformers import PreTrainedModel
8
- from transformers.cache_utils import Cache, DynamicCache
9
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
10
- from transformers.utils import ModelOutput
11
-
12
-
13
- @dataclass
14
- class CausalBranchyLLMOutputWithPast(ModelOutput):
15
- loss: Optional[torch.Tensor] = None
16
- lm_loss: Optional[torch.Tensor] = None
17
- head_loss: Optional[torch.Tensor] = None
18
- logits: torch.Tensor = None
19
- head_outputs: Optional[torch.Tensor] = None
20
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
21
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
22
- attentions: Optional[Tuple[torch.FloatTensor]] = None
23
-
24
- class Branch(nn.Module):
25
- def __init__(self, config):
26
- super().__init__()
27
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
28
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
29
-
30
- def forward(self, x):
31
- x = self.layernorm(x)
32
- x = self.lm_head(x)
33
- return x
34
-
35
- class BranchyModel(PreTrainedModel):
36
- """
37
- This class is a wrapper for transformer models with added functionality for branchy networks.
38
- It uses BranchyConfig to initialize a model and later will be extended to add branches.
39
-
40
- Args:
41
- branch_locations (List[int]): The locations of the branches in the model.
42
- starts indexing from 0. Branch 0 is after layer 0.
43
- model (PreTrainedModel): The underlying transformer model to wrap.
44
-
45
- Returns:
46
- A model instance with the given configuration.
47
- """
48
-
49
- def __init__(self, branch_locations, model, loss_type="kl_div", penality_weight=None):
50
- super().__init__(model.config)
51
- # Initialize the base transformer model
52
- self.model = model
53
- self.branch_locations = branch_locations
54
- self.loss_type = loss_type
55
- self.penality_weight = penality_weight
56
- if self.loss_type == "penalized_cross_entropy":
57
- assert self.penality_weight is not None, "penality_weight must be provided for penalized_cross_entropy loss"
58
- # Get details on layering inside the model
59
- if hasattr(self.model.config, "n_layer") or hasattr(
60
- self.model.config, "num_hidden_layers"
61
- ): # If there is no n_layer in the config, there might be ways to get it from the model itself
62
- self.num_layers = (
63
- self.model.config.n_layer
64
- if hasattr(self.model.config, "n_layer")
65
- else self.model.config.num_hidden_layers
66
- )
67
- else:
68
- raise ValueError("cannot find n_layer in config")
69
- # if no branch locations are specified, branch at every layer
70
- if self.branch_locations is None:
71
- self.branch_locations = list(range(self.num_layers - 1))
72
-
73
- assert self.num_layers > 0, "The number of layers must be greater than 0"
74
- assert (
75
- len(self.branch_locations) < self.num_layers
76
- ), "The number of branches must be less than the number of layers"
77
- assert all(
78
- [0 <= i < self.num_layers for i in self.branch_locations]
79
- ), "The branch locations must be between 0 and num_layers"
80
-
81
-
82
- # Make sure the base model is frozen
83
- for param in self.model.parameters():
84
- param.requires_grad = False
85
-
86
- # Instantiate heads. Default: heads are copies of the lm_head
87
- self.model.heads = torch.nn.ModuleList(
88
- [
89
- Branch(self.model.config) for _ in range(len(self.branch_locations))
90
- ]
91
- )
92
-
93
- # initialize heads
94
- for head in self.model.heads:
95
- head.apply(self.model._init_weights)
96
- # Make them trainable
97
- for param in head.parameters():
98
- param.requires_grad = True
99
-
100
- self.post_init()
101
-
102
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
103
- def prepare_inputs_for_generation(
104
- self,
105
- input_ids,
106
- past_key_values=None,
107
- attention_mask=None,
108
- inputs_embeds=None,
109
- **kwargs,
110
- ):
111
- if past_key_values is not None:
112
- if isinstance(past_key_values, Cache):
113
- cache_length = past_key_values.get_seq_length()
114
- past_length = past_key_values.seen_tokens
115
- max_cache_length = past_key_values.get_max_length()
116
- else:
117
- cache_length = past_length = past_key_values[0][0].shape[2]
118
- max_cache_length = None
119
-
120
- # Keep only the unprocessed tokens:
121
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
122
- # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
123
- # input)
124
- if (
125
- attention_mask is not None
126
- and attention_mask.shape[1] > input_ids.shape[1]
127
- ):
128
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
129
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
130
- # input_ids based on the past_length.
131
- elif past_length < input_ids.shape[1]:
132
- input_ids = input_ids[:, past_length:]
133
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
134
-
135
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
136
- if (
137
- max_cache_length is not None
138
- and attention_mask is not None
139
- and cache_length + input_ids.shape[1] > max_cache_length
140
- ):
141
- attention_mask = attention_mask[:, -max_cache_length:]
142
-
143
- position_ids = kwargs.get("position_ids", None)
144
- if attention_mask is not None and position_ids is None:
145
- # create position_ids on the fly for batch generation
146
- position_ids = attention_mask.long().cumsum(-1) - 1
147
- position_ids.masked_fill_(attention_mask == 0, 1)
148
- if past_key_values:
149
- position_ids = position_ids[:, -input_ids.shape[1] :]
150
-
151
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
152
- if inputs_embeds is not None and past_key_values is None:
153
- model_inputs = {"inputs_embeds": inputs_embeds}
154
- else:
155
- model_inputs = {"input_ids": input_ids}
156
-
157
- model_inputs.update(
158
- {
159
- "position_ids": position_ids,
160
- "past_key_values": past_key_values,
161
- "use_cache": kwargs.get("use_cache"),
162
- "attention_mask": attention_mask,
163
- "fixed_output_head": kwargs.get("fixed_output_head", None),
164
- }
165
- )
166
- return model_inputs
167
-
168
- def compute_self_supervision_loss(
169
- self,
170
- aux_logits: torch.Tensor,
171
- lm_logits: torch.Tensor,
172
- return_per_head: bool = False,
173
- ) -> Dict[str, torch.Tensor]:
174
- last_aux_logits = aux_logits[..., -1, :]
175
- last_lm_logits = lm_logits[..., -1, :]
176
-
177
- repeated_last_lm_logits = last_lm_logits.repeat(
178
- last_aux_logits.shape[0], 1, 1, 1
179
- )
180
- losses = []
181
- # Can be useful to have detailed loss per head for comparison of performance
182
- if return_per_head:
183
- for head_logit in last_aux_logits:
184
- if self.loss_type == "kl_div":
185
- losses.append(
186
- nn.KLDivLoss(reduction="batchmean")(
187
- F.log_softmax(head_logit, dim=-1),
188
- F.softmax(last_lm_logits, dim=-1),
189
- )
190
- )
191
- elif self.loss_type == "cross_entropy":
192
- losses.append(
193
- nn.CrossEntropyLoss(reduction="mean")(
194
- head_logit, torch.argmax(last_lm_logits, dim=-1)
195
- )
196
- )
197
- elif self.loss_type == "penalized_cross_entropy":
198
- ce_loss = nn.CrossEntropyLoss(reduction="mean")(
199
- head_logit, torch.argmax(last_lm_logits, dim=-1)
200
- )
201
- probas = F.softmax(head_logit, dim=-1)
202
- entropy = torch.mean(-torch.sum(probas * torch.log(probas + 1e-8), dim=-1))
203
- #losses.append(ce_loss - self.penality_weight * (1.0 / (1.0 + entropy)))
204
- losses.append(ce_loss - self.penality_weight * entropy)
205
- else:
206
- raise ValueError(
207
- "The loss type must be either kl_div or cross_entropy"
208
- )
209
- loss = torch.stack(losses, dim=0).mean(dim=-1)
210
- else:
211
- # Compute the KL divergence between the last auxiliary head and the last LM head
212
- if self.loss_type == "kl_div":
213
- loss = nn.KLDivLoss(reduction="batchmean")(
214
- F.log_softmax(last_aux_logits.view(-1, self.config.vocab_size), dim=-1),
215
- F.softmax(
216
- repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
217
- ),
218
- )
219
- elif self.loss_type == "cross_entropy":
220
- loss = nn.CrossEntropyLoss(reduction="mean")(
221
- last_aux_logits.view(-1, self.config.vocab_size),
222
- torch.argmax(
223
- repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
224
- ),
225
- )
226
- elif self.loss_type == "penalized_cross_entropy":
227
- ce_loss = nn.CrossEntropyLoss(reduction="mean")(
228
- last_aux_logits.view(-1, self.config.vocab_size),
229
- torch.argmax(
230
- repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
231
- ),
232
- )
233
- probas = F.softmax(
234
- last_aux_logits.view(-1, self.config.vocab_size), dim=-1
235
- )
236
- entropy = torch.mean(-torch.sum(probas * torch.log(probas + 1e-8), dim=-1))
237
- loss = ce_loss + self.penality_weight * entropy
238
- else:
239
- raise ValueError(
240
- "The loss type must be either kl_div or cross_entropy"
241
- )
242
- if return_per_head:
243
- return {"loss": loss, "aux_loss": torch.stack(losses)}
244
- else:
245
- return {"loss": loss, "aux_loss": None}
246
-
247
- def forward(
248
- self,
249
- input_ids: torch.LongTensor = None,
250
- attention_mask: Optional[torch.Tensor] = None,
251
- position_ids: Optional[torch.LongTensor] = None,
252
- past_key_values: Optional[List[torch.FloatTensor]] = None,
253
- inputs_embeds: Optional[torch.FloatTensor] = None,
254
- labels: Optional[torch.LongTensor] = None,
255
- use_cache: Optional[bool] = None,
256
- output_attentions: Optional[bool] = None,
257
- output_hidden_states: Optional[bool] = None,
258
- return_dict: Optional[bool] = None,
259
- self_supervision: Optional[bool] = None,
260
- fixed_output_head: Optional[int] = None,
261
- ):
262
- output_attentions = (
263
- output_attentions
264
- if output_attentions is not None
265
- else self.config.output_attentions
266
- )
267
- return_dict = (
268
- return_dict if return_dict is not None else self.config.use_return_dict
269
- )
270
- use_cache = use_cache if use_cache is not None else self.config.use_cache
271
-
272
- if self_supervision:
273
- output_hidden_states = True
274
- return self.forward_for_training(
275
- input_ids=input_ids,
276
- attention_mask=attention_mask,
277
- position_ids=position_ids,
278
- past_key_values=past_key_values,
279
- inputs_embeds=inputs_embeds,
280
- labels=labels,
281
- use_cache=use_cache,
282
- output_attentions=output_attentions,
283
- output_hidden_states=output_hidden_states,
284
- return_dict=return_dict,
285
- )
286
- else:
287
- return self.forward_for_inference(
288
- input_ids=input_ids,
289
- attention_mask=attention_mask,
290
- position_ids=position_ids,
291
- past_key_values=past_key_values,
292
- inputs_embeds=inputs_embeds,
293
- use_cache=use_cache,
294
- return_dict=return_dict,
295
- fixed_output_head=fixed_output_head,
296
- )
297
-
298
- def forward_for_inference(
299
- self,
300
- input_ids: torch.LongTensor = None,
301
- attention_mask: Optional[torch.Tensor] = None,
302
- position_ids: Optional[torch.LongTensor] = None,
303
- past_key_values: Optional[List[torch.FloatTensor]] = None,
304
- inputs_embeds: Optional[torch.FloatTensor] = None,
305
- use_cache: Optional[bool] = None,
306
- return_dict: Optional[bool] = None,
307
- fixed_output_head: Optional[int] = None,
308
- ):
309
- if fixed_output_head not in self.branch_locations and fixed_output_head is not None and fixed_output_head != -1:
310
- raise ValueError(
311
- "The fixed output head must be one of the branch locations"
312
- )
313
- # retrieve input_ids and inputs_embeds
314
- if input_ids is not None and inputs_embeds is not None:
315
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
316
- elif input_ids is not None:
317
- batch_size, seq_length = input_ids.shape
318
- elif inputs_embeds is not None:
319
- batch_size, seq_length, _ = inputs_embeds.shape
320
- else:
321
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
322
-
323
- past_key_values_length = 0
324
-
325
- if use_cache:
326
- use_legacy_cache = not isinstance(past_key_values, Cache)
327
- if use_legacy_cache:
328
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
329
- past_key_values_length = past_key_values.get_usable_length(seq_length)
330
-
331
- if position_ids is None:
332
- device = input_ids.device if input_ids is not None else inputs_embeds.device
333
- position_ids = torch.arange(
334
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
335
- )
336
- position_ids = position_ids.unsqueeze(0)
337
-
338
- if inputs_embeds is None:
339
- inputs_embeds = self.model.model.embed_tokens(input_ids)
340
-
341
- inputs_embeds = self.model.model.embed_dropout(inputs_embeds)
342
-
343
- # Attention mask.
344
- if self.model.model._use_flash_attention_2:
345
- # 2d mask is passed through the layers
346
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
347
- else:
348
- # 4d mask is passed through the layers
349
- attention_mask = _prepare_4d_causal_attention_mask(
350
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
351
- )
352
- all_head_logits = []
353
- hidden_states = inputs_embeds
354
- is_early_exited = False
355
- for layer_idx, decoder_layer in enumerate(self.model.model.layers):
356
- layer_outputs = decoder_layer(
357
- hidden_states,
358
- attention_mask=attention_mask,
359
- position_ids=position_ids,
360
- past_key_value=past_key_values,
361
- use_cache=use_cache,
362
- )
363
-
364
- hidden_states = layer_outputs[0]
365
-
366
- if use_cache:
367
- next_decoder_cache = layer_outputs[1]
368
-
369
- if fixed_output_head is not None and layer_idx == fixed_output_head:
370
- # find postion of layer idx in branch_locations
371
- branch_idx = self.branch_locations.index(layer_idx)
372
- logits = self.model.heads[branch_idx](hidden_states)
373
- is_early_exited = True
374
- break
375
- elif fixed_output_head == -1 and layer_idx in self.branch_locations:
376
- # -1 means output all heads
377
- branch_idx = self.branch_locations.index(layer_idx)
378
- logits = self.model.heads[branch_idx](hidden_states)
379
- all_head_logits.append(logits)
380
-
381
- if not is_early_exited:
382
- hidden_states = self.model.model.final_layernorm(hidden_states)
383
- logits = self.model.lm_head(hidden_states)
384
- if fixed_output_head == -1:
385
- all_head_logits.append(logits)
386
- all_head_logits = torch.stack(all_head_logits, dim=0)
387
- next_cache = None
388
- if use_cache:
389
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
390
- if not return_dict:
391
- return tuple(v for v in [logits, next_cache] if v is not None)
392
-
393
- return CausalBranchyLLMOutputWithPast(
394
- logits=logits,
395
- head_outputs=all_head_logits,
396
- past_key_values=next_cache,
397
- )
398
-
399
- def forward_for_training(
400
- self,
401
- input_ids: torch.LongTensor = None,
402
- attention_mask: Optional[torch.Tensor] = None,
403
- position_ids: Optional[torch.LongTensor] = None,
404
- past_key_values: Optional[List[torch.FloatTensor]] = None,
405
- inputs_embeds: Optional[torch.FloatTensor] = None,
406
- labels: Optional[torch.LongTensor] = None,
407
- use_cache: Optional[bool] = None,
408
- output_attentions: Optional[bool] = None,
409
- output_hidden_states: Optional[bool] = None,
410
- return_dict: Optional[bool] = None,
411
- ):
412
-
413
- if not output_hidden_states:
414
- raise ValueError("output_hidden_states must be True for BranchyLLM")
415
- if labels is not None:
416
- raise NotImplementedError("BranchyLLM only supports self-supervision")
417
- outputs = self.model(
418
- input_ids=input_ids,
419
- attention_mask=attention_mask,
420
- position_ids=position_ids,
421
- past_key_values=past_key_values,
422
- inputs_embeds=inputs_embeds,
423
- use_cache=use_cache,
424
- output_attentions=output_attentions,
425
- output_hidden_states=output_hidden_states,
426
- return_dict=return_dict,
427
- )
428
- if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
429
- raise ValueError("The model must return hidden states")
430
- hidden_states = outputs.hidden_states
431
-
432
-
433
- heads_logits = []
434
- for i, branch in enumerate(self.branch_locations):
435
- heads_logits.append(
436
- self.model.heads[i](
437
- hidden_states[branch]
438
- )
439
- )
440
- lm_logits = self.model.lm_head(hidden_states[-1])
441
-
442
- heads_logits = torch.stack(heads_logits, dim=0).float()
443
- lm_logits = lm_logits.float()
444
- logits = torch.cat([heads_logits, lm_logits.unsqueeze(0)], dim=0)
445
-
446
- loss = None
447
- lm_loss = None
448
- aux_loss = None
449
-
450
- losses = self.compute_self_supervision_loss(
451
- heads_logits, lm_logits, return_per_head=True
452
- )
453
- loss = losses["loss"]
454
- if losses["aux_loss"] is not None:
455
- aux_loss = losses["aux_loss"]
456
-
457
- if not return_dict:
458
- output = (logits,) + outputs[1:]
459
- return ((loss, aux_loss, lm_loss) + output) if loss is not None else output
460
-
461
- return CausalBranchyLLMOutputWithPast(
462
- loss=loss,
463
- lm_loss=lm_loss,
464
- head_loss=aux_loss,
465
- logits=logits,
466
- past_key_values=outputs.past_key_values,
467
- hidden_states=outputs.hidden_states,
468
- attentions=outputs.attentions,
469
- )
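
For reference, the deleted wrapper above was instantiated in the old app.py roughly like this (reconstructed from the removed lines of app.py; the old code pinned everything to cuda:1, which is noted in a comment rather than reproduced):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.BranchyModel import BranchyModel

base = AutoModelForCausalLM.from_pretrained("susnato/phi-1_5_dev")
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")

# One trainable exit head after every fifth layer; the base model itself stays frozen.
model = BranchyModel(branch_locations=list(range(0, 23, 5)), model=base)
model.load_state_dict(torch.load("model/model.bin", map_location="cpu"))  # old code used map_location="cuda:1"
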
src/utils.py DELETED
@@ -1,57 +0,0 @@
1
- import torch
2
-
3
- def generate_next_token(model, tokenizer, input, method='greedy'):
4
- """
5
- Generate the next token of a sequence using the given model and tokenizer.
6
- Specific for multi branched models.
7
- Only output token from last head.
8
-
9
- Args:
10
- model (torch.nn.Module): The model to use for generation.
11
- tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for generation.
12
- input (str): The input text to generate from.
13
-
14
- Returns:
15
- token (str): The next token in the sequence.
16
- logits (torch.Tensor): The logits of the next token. of shape[Head, vocab_size]
17
- new_sequence (str): The new sequence after adding the next token.
18
- """
19
- device = model.device
20
- input_ids = tokenizer.encode(input, return_tensors="pt").to(device)
21
- model.eval()
22
- logits = model(input_ids, fixed_output_head=-1).head_outputs[..., -1, :].squeeze(1) # squeeze batch dimension as it is 1 new shape is (head_count, vocab_size)
23
- if logits.numel() == 0:
24
- raise ValueError("Model does not have head_outputs")
25
- if method == 'greedy':
26
- head_tokens = torch.argmax(logits, dim=-1)
27
- elif method == 'sample':
28
- head_tokens = torch.multinomial(torch.nn.functional.softmax(logits, dim=-1), num_samples=1)
29
- elif method == 'top_k':
30
- k = 5
31
- top_k = torch.topk(logits, k, dim=-1)
32
- top_k_logits, top_k_indices = top_k.values, top_k.indices
33
- top_k_probs = torch.nn.functional.softmax(top_k_logits, dim=-1)
34
- head_tokens = top_k_indices[torch.arange(top_k_probs.shape[0]), torch.multinomial(top_k_probs, num_samples=1).squeeze()]
35
- elif method == 'top_p':
36
- # logits is of shape [batch, vocab_size]
37
- p = 0.9
38
- probs = torch.nn.functional.softmax(logits, dim=-1)
39
- sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
40
- cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
41
- sorted_indices_to_remove = cumulative_probs > p
42
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
43
- sorted_indices_to_remove[..., 0] = 0
44
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
45
- tmp_logits = logits.clone()
46
- for i in range(logits.shape[0]):
47
- tmp_logits[i, indices_to_remove[i]] = float('-inf')
48
- head_tokens = torch.multinomial(torch.nn.functional.softmax(tmp_logits, dim=-1), num_samples=1).squeeze()
49
- else:
50
- raise ValueError(f"Unknown method: {method}")
51
- head_tokens = tokenizer.batch_decode(head_tokens) # Treat head dim as batch dim
52
- new_sequence = input + head_tokens[-1]
53
- return head_tokens[-1], logits, new_sequence, head_tokens
54
-
55
-
56
- def breaking_ties(tensor):
57
- return torch.sub(torch.topk(tensor, 2, dim=-1).values[..., 0], torch.topk(tensor, 2, dim=-1).values[..., 1]).squeeze()
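
The removed breaking_ties helper computed the margin between the two highest logits, a common confidence score for deciding whether an exit head is sure enough to stop early; presumably the per-head thresholds set in the new app.py play the same role inside the Hub model, though that internal logic is not shown in this diff. A self-contained sketch of the margin computation:

import torch

def breaking_ties(logits: torch.Tensor) -> torch.Tensor:
    # Difference between the top-2 logit values: a large margin means a confident head.
    top2 = torch.topk(logits, 2, dim=-1).values
    return (top2[..., 0] - top2[..., 1]).squeeze()

print(breaking_ties(torch.tensor([10.0, 1.0, 0.5])))  # tensor(9.) -- peaked, confident
print(breaking_ties(torch.tensor([2.0, 1.9, 0.5])))   # tensor(0.1000) -- ambiguous
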