Commit 01bc423 by Zhiyu Wu (1 parent: aaadf66)

add attention mask; fix stop_str length (#26)
pegasus/benchmark.yaml CHANGED
@@ -3,7 +3,7 @@
 # {{ gpu }} is defined in `hosts.yaml`, and will be filled in when Pegasus
 # determines the specific node and gpu the generated job command will run on.
 - command:
-  - docker exec leaderboard{{ gpu }} python scripts/benchmark.py --input-file sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json --model-path {{ model }} --task {{ task }}
+  - docker exec leaderboard{{ gpu }} python scripts/benchmark.py --input-file sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json --model-path {{ model }} --task {{ task }} --batch-size {{ batch_size }}
   model:
   - /data/leaderboard/weights/metaai/llama-7B
   - /data/leaderboard/weights/metaai/llama-13B
@@ -31,3 +31,9 @@
   - chat-concise
   - instruct
   - instruct-concise
+  batch_size:
+  - 1
+  - 2
+  - 4
+  - 8
+  - 16
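Each list in this YAML is a sweep axis: Pegasus substitutes one combination of `{{ model }}`, `{{ task }}`, and now `{{ batch_size }}` into the command template per job, with `{{ gpu }}` resolved later from `hosts.yaml`. The sketch below illustrates the resulting fan-out, assuming a plain Cartesian product over the axes (the model and task lists are truncated to the entries visible in this diff):

```python
from itertools import product

# Sweep axes copied from the hunks above (model/task lists truncated).
models = [
    "/data/leaderboard/weights/metaai/llama-7B",
    "/data/leaderboard/weights/metaai/llama-13B",
]
tasks = ["chat-concise", "instruct", "instruct-concise"]
batch_sizes = [1, 2, 4, 8, 16]

template = (
    "docker exec leaderboard{gpu} python scripts/benchmark.py"
    " --input-file sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json"
    " --model-path {model} --task {task} --batch-size {batch_size}"
)

# One generated job command per (model, task, batch_size) combination;
# the gpu index stays a placeholder until a node/GPU is assigned.
for model, task, batch_size in product(models, tasks, batch_sizes):
    print(template.format(gpu="N", model=model, task=task, batch_size=batch_size))
```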
scripts/benchmark.py CHANGED
@@ -104,7 +104,10 @@ def run_inference(
         temperature, repetition_penalty, top_p, top_k
     )
 
-    input_ids = tokenizer(prompts, padding=True).input_ids
+    prompts_encode = tokenizer(prompts, padding=True)
+    input_ids = prompts_encode.input_ids
+    attention_masks = prompts_encode.attention_mask
+
     output_ids = [[] for _ in range(batch_size)]
 
     if model.config.is_encoder_decoder:
@@ -113,10 +116,12 @@
         max_src_len = context_len - max_new_tokens - 1
 
     input_ids = [input_id[-max_src_len:] for input_id in input_ids]
+    attention_masks = torch.as_tensor([attention_mask[-max_src_len:] for attention_mask in attention_masks], device=device)
 
     if model.config.is_encoder_decoder:
         encoder_output = model.encoder(
-            input_ids=torch.as_tensor(input_ids, device=device)
+            input_ids=torch.as_tensor(input_ids, device=device),
+            attention_mask=attention_masks
         )[0]
         start_ids = torch.as_tensor(
             [[model.generation_config.decoder_start_token_id] for _ in range(batch_size)],
@@ -126,6 +131,12 @@
 
     past_key_values = out = None
     stopped = np.array(batch_size*[False])
+
+    # stop string length
+    stop_str_length = np.zeros(batch_size, dtype=int)
+    if stop_str and isinstance(stop_str, str):
+        stop_str_length[:] = len(tokenizer(stop_str).input_ids)
+
     for i in range(max_new_tokens):
         if i == 0: # prefill
             if model.config.is_encoder_decoder:
@@ -136,7 +147,7 @@
                 )
                 logits = model.lm_head(out[0])
             else:
-                out = model(torch.as_tensor(input_ids, device=device), use_cache=True)
+                out = model(torch.as_tensor(input_ids, device=device), use_cache=True, attention_mask=attention_masks)
                 logits = out.logits
             past_key_values = out.past_key_values
         else: # decoding
@@ -157,10 +168,17 @@
                     ),
                     use_cache=True,
                     past_key_values=past_key_values,
+                    attention_mask=attention_masks,
                 )
                 logits = out.logits
             past_key_values = out.past_key_values
 
+        # update attention mask
+        attention_masks = torch.cat(
+            [attention_masks, torch.ones((batch_size, 1), device=device)],
+            dim=1
+        )
+
         if logits_processor:
             if repetition_penalty > 1.0:
                 tmp_output_ids = torch.as_tensor(output_ids, device=logits.device)
@@ -213,14 +231,15 @@
                     for each_stop in stop_str:
                         pos_array = np.char.rfind(output_np, each_stop, rfind_start)
                         find_stop = pos_array != -1
+                        # record the token length of the matched stop string for each request
+                        stop_str_length[find_stop] = len(tokenizer(each_stop).input_ids)
                 else:
                     raise ValueError("Invalid stop field type.")
 
                 stop_str_indices = np.where(find_stop & ~stopped)[0]
                 if stop_str_indices.size > 0:
                     for j in stop_str_indices:
-                        # TODO: find a elegant way to figure out the size of stop_str, here just suppose stop_str has one token
-                        result[j].response_length = i
+                        result[j].response_length = i+1-stop_str_length[j]
                         result[j].output = output[j][:pos_array[j]]
                     stopped[find_stop] = True
 
@@ -378,7 +397,7 @@ def main(
 
     for is_warmup, input_prompts in data_iter:
         # Construct the input prompt.
-        for i in range(batch_size):
+        for i in range(len(input_prompts)):
             conv = copy.deepcopy(conv_base)
             conv.append_message(conv.roles[0], input_prompts[i])
             conv.append_message(conv.roles[1], "")
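Taken together, the benchmark.py changes do two things: they thread an attention mask for the padded batch through the prefill call and every decoding step, growing it by one column per generated token, and they compute `response_length` as `i + 1 - stop_str_length[j]` so the matched stop string's own tokens are not counted. The following is a minimal, self-contained sketch of both ideas against the Hugging Face transformers API; `gpt2` and the `"###"` stop string are placeholders, and padding-side subtleties are glossed over:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model/tokenizer -- not the leaderboard weights.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = ["Summarize the plot of Hamlet.", "Hi"]
enc = tokenizer(prompts, padding=True, return_tensors="pt")
input_ids, attention_mask = enc.input_ids, enc.attention_mask

# Prefill: the mask marks which positions are real tokens vs. padding.
out = model(input_ids, attention_mask=attention_mask, use_cache=True)
past = out.past_key_values
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

# Decoding: feed one new token per request and grow the mask by one column,
# mirroring the torch.cat in the diff above.
for step in range(8):
    attention_mask = torch.cat(
        [attention_mask, torch.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)],
        dim=1,
    )
    out = model(next_token, attention_mask=attention_mask,
                past_key_values=past, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

# Stop-string accounting: if generation halted at decoding step i because the
# stop string appeared, the response length excludes the stop string's tokens.
stop_str = "###"  # hypothetical stop string
i = 7             # suppose the stop string was found at step i
response_length = i + 1 - len(tokenizer(stop_str).input_ids)
print(response_length)
```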
scripts/sort.py ADDED
@@ -0,0 +1,15 @@
+import json
+import tyro
+
+def main(data_dir:str, out_file:str) -> None:
+
+    with open(data_dir, "r") as f:
+        data = json.load(f)
+
+    sorted_data = sorted(data, key=lambda x: len(x['conversations'][0]['value']), reverse=True)
+
+    with open(out_file, "w") as f:
+        json.dump(sorted_data, f, indent=4)
+
+if __name__ == "__main__":
+    tyro.cli(main)
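sort.py assumes the ShareGPT layout, where each record carries a `conversations` list and the first turn's `value` holds the user prompt; the sort key is simply that string's character length, longest first. A tiny illustration with made-up records:

```python
# Made-up stand-ins for ShareGPT records; only the fields sort.py touches.
records = [
    {"id": "a", "conversations": [{"from": "human", "value": "short prompt"}]},
    {"id": "b", "conversations": [{"from": "human", "value": "a much longer prompt, " * 10}]},
]

# Same key as sort.py: character length of the first conversation turn, descending.
sorted_records = sorted(records, key=lambda x: len(x["conversations"][0]["value"]), reverse=True)
print([r["id"] for r in sorted_records])  # -> ['b', 'a']
```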
sharegpt/README.md CHANGED
@@ -27,6 +27,7 @@ python -m fastchat.data.sample --in sg_90k_part1_html_cleaned_lang_first.json --
 ```
 
 ## Sorted data
+We sort the requests by sequence length, placing the longest sequences first. This minimizes the amount of padding required and allows for early detection of out-of-memory errors.
 ```
 python sort.py --data-dir sg_90k_part1_html_cleaned_lang_first_sampled.json --out-file sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json
 ```
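The padding claim in the new README sentence is easy to quantify: with a fixed batch size, each batch is padded up to its longest request, so grouping similar lengths together (which the descending sort does) shrinks the total pad count, and the heaviest batch runs first so out-of-memory errors surface immediately. A small back-of-the-envelope check with made-up prompt lengths:

```python
# Each batch is padded to its longest sequence; count the wasted pad tokens.
def padding_tokens(lengths, batch_size):
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += sum(max(batch) - length for length in batch)
    return total

lengths = [1200, 80, 950, 60, 700, 40, 500, 30]  # made-up prompt lengths (tokens)
print(padding_tokens(lengths, batch_size=4))                        # arrival order: 4040 pad tokens
print(padding_tokens(sorted(lengths, reverse=True), batch_size=4))  # longest first: 1560 pad tokens
```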
sharegpt/sg_90k_part1_html_cleaned_lang_first_sampled_sorted.json CHANGED
The diff for this file is too large to render.