Zimix committed on
Commit
a5fe767
1 Parent(s): acb4385
Files changed (4)
  1. app.py +31 -0
  2. interaction.py +146 -0
  3. requirements.txt +2 -0
  4. utils.py +654 -0
app.py ADDED
@@ -0,0 +1,31 @@
+ from fastapi import FastAPI
+ import gradio as gr
+ import uvicorn
+ import socket
+ from interaction import MindBot
+
+ CUSTOM_PATH = "/mindbot/541832"
+
+ # @app.get("/")
+ # def read_main():
+ #     return {"message": "This is your main app"}
+
+
+ mind_bot = MindBot(
+     "/cognitive_comp/songchao/mindbot_demo/checkpoint",
+     "/cognitive_comp/songchao/checkpoints/13B-c-pretrain-tokenizer",
+     if_int8=True
+ )
+
+
+ # @app.get("/api/mindbot/541832")
+ # async def chat(query, clear_history=False):
+ #     output = mind_bot.common_generate(query, clear_history)
+ #     return output
+ # host = socket.gethostbyname(socket.gethostname())
+ # print(f'demo run on {host}')
+ demo = mind_bot.new_chat_bot()
+ demo.launch(share=True)
+
+ # app = gr.mount_gradio_app(app, demo, path=CUSTOM_PATH)
+ # uvicorn.run(app, host='192.168.81.9', port=7880)
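For reference, the commented-out FastAPI path above would roughly take the following shape if it were enabled. This is only a sketch, not the commit's configuration: the `app` object, the host/port values, and the checkpoint/tokenizer paths below are placeholders.

# Sketch only: serving the Gradio UI under CUSTOM_PATH via FastAPI,
# as the commented-out lines in app.py suggest. Paths, host and port are placeholders.
from fastapi import FastAPI
import gradio as gr
import uvicorn

from interaction import MindBot

CUSTOM_PATH = "/mindbot/541832"

app = FastAPI()

@app.get("/")
def read_main():
    return {"message": "This is your main app"}

mind_bot = MindBot(
    "/path/to/checkpoint",   # placeholder, not the path used in this commit
    "/path/to/tokenizer",    # placeholder
    if_int8=True,
)
demo = mind_bot.new_chat_bot()

# Mount the Blocks app instead of calling demo.launch(share=True)
app = gr.mount_gradio_app(app, demo, path=CUSTOM_PATH)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7880)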
interaction.py ADDED
@@ -0,0 +1,146 @@
+ import os
+ import gc
+ import torch
+ import torch.nn as nn
+ import argparse
+ import gradio as gr
+
+ from transformers import AutoTokenizer, LlamaForCausalLM
+ from utils import SteamGenerationMixin
+
+
+ class MindBot(object):
+     def __init__(self, model_path, tokenizer_path, if_int8=False):
+         # self.device = torch.device("cuda")
+         # device_ids = [1, 2]
+         if if_int8:
+             self.model = SteamGenerationMixin.from_pretrained(model_path, device_map='auto', load_in_8bit=True).eval()
+         else:
+             self.model = SteamGenerationMixin.from_pretrained(model_path, device_map='auto').half().eval()
+
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+         # sp_tokens = {'additional_special_tokens': ['<human>', '<bot>']}
+         # self.tokenizer.add_special_tokens(sp_tokens)
+         self.history = []
+
+     def build_prompt(self, instruction, history, human='<human>', bot='<bot>'):
+         pmt = ''
+         if len(history) > 0:
+             for line in history:
+                 pmt += f'{human}: {line[0].strip()}\n{bot}: {line[1]}\n'
+         pmt += f'{human}: {instruction.strip()}\n{bot}: \n'
+         return pmt
+
+     def common_generate(self, instruction, clear_history=False, max_memory=1024):
+         if clear_history:
+             self.history = []
+
+         prompt = self.build_prompt(instruction, self.history)
+         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+         if input_ids.shape[1] > max_memory:
+             input_ids = input_ids[:, -max_memory:]
+
+         prompt_len = input_ids.shape[1]
+         # common (non-streaming) generation
+         generation_output = self.model.generate(
+             input_ids.cuda(),
+             max_new_tokens=1024,
+             do_sample=True,
+             top_p=0.85,
+             temperature=0.8,
+             repetition_penalty=1.,
+             eos_token_id=2,
+             bos_token_id=1,
+             pad_token_id=0
+         )
+
+         s = generation_output[0][prompt_len:]
+         output = self.tokenizer.decode(s, skip_special_tokens=True)
+         output = output.replace("Belle", "IDEA")
+         self.history.append((instruction, output))
+         print('api history: ======> \n', self.history)
+
+         return output
+
+
+     def interaction(
+         self,
+         instruction,
+         history,
+         max_memory=1024
+     ):
+
+         prompt = self.build_prompt(instruction, history)
+         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+         if input_ids.shape[1] > max_memory:
+             input_ids = input_ids[:, -max_memory:]
+
+         prompt_len = input_ids.shape[1]
+         # streaming generation
+         try:
+             tmp = history.copy()
+             output = ''
+             with torch.no_grad():
+                 for generation_output in self.model.stream_generate(
+                     input_ids.cuda(),
+                     max_new_tokens=1024,
+                     do_sample=True,
+                     top_p=0.85,
+                     temperature=0.8,
+                     repetition_penalty=1.,
+                     eos_token_id=2,
+                     bos_token_id=1,
+                     pad_token_id=0
+                 ):
+                     s = generation_output[0][prompt_len:]
+                     output = self.tokenizer.decode(s, skip_special_tokens=True)
+                     output = output.replace('\n', '<br>')
+                     tmp.append((instruction, output))
+                     yield '', tmp
+                     tmp.pop()
+                     # gc.collect()
+                     # torch.cuda.empty_cache()
+                 history.append((instruction, output))
+                 print('input -----> \n', prompt)
+                 print('output -------> \n', output)
+                 print('history: ======> \n', history)
+         except torch.cuda.OutOfMemoryError:
+             gc.collect()
+             torch.cuda.empty_cache()
+             self.model.empty_cache()
+         return "", history
+
+     def new_chat_bot(self):
+
+         with gr.Blocks(title='IDEA MindBot', css=".gradio-container {max-width: 50% !important;} .bgcolor {color: white !important; background: #FFA500 !important;}") as demo:
+             gr.Markdown("<center><h1>IDEA MindBot</h1></center>")
+             gr.Markdown("<center>This page is deployed on hardware supported by Hugging Face</center>")
+             with gr.Row():
+                 chatbot = gr.Chatbot(label='MindBot').style(height=500)
+             with gr.Row():
+                 msg = gr.Textbox(label="Input")
+             with gr.Row():
+                 with gr.Column(scale=0.5):
+                     clear = gr.Button("Clear")
+                 with gr.Column(scale=0.5):
+                     submit = gr.Button("Submit", elem_classes='bgcolor')
+
+             msg.submit(self.interaction, [msg, chatbot], [msg, chatbot])
+             clear.click(lambda: None, None, chatbot, queue=False)
+             submit.click(self.interaction, [msg, chatbot], [msg, chatbot])
+         return demo.queue(concurrency_count=5)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="/cognitive_comp/songchao/checkpoints/global_step3200-hf"
+     )
+     args = parser.parse_args()
+
+     mind_bot = MindBot(args.model_path)
+     demo = mind_bot.new_chat_bot()
+
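Two things stand out in the `__main__` block above: `MindBot.__init__` requires both `model_path` and `tokenizer_path`, so `MindBot(args.model_path)` as written raises a `TypeError`, and the demo is built but never launched. A minimal corrected sketch of that block (assuming, which this commit does not confirm, that the tokenizer can be loaded from the checkpoint directory when no separate path is given):

# Sketch: a __main__ block consistent with MindBot's signature (inside interaction.py).
# Assumption: the tokenizer lives alongside the model weights unless --tokenizer_path is given.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/cognitive_comp/songchao/checkpoints/global_step3200-hf"
    )
    parser.add_argument("--tokenizer_path", type=str, default=None)
    args = parser.parse_args()

    tokenizer_path = args.tokenizer_path or args.model_path
    mind_bot = MindBot(args.model_path, tokenizer_path)
    demo = mind_bot.new_chat_bot()
    demo.launch()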
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ torch==1.12.1
+ transformers==4.28.1
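Note that these two pins do not cover everything the new files import: app.py imports fastapi, uvicorn, and gradio, interaction.py imports gradio, and the `load_in_8bit=True` path additionally needs bitsandbytes and accelerate (and, depending on the tokenizer files, sentencepiece). Those are presumably provided by the hosting environment, since they are not pinned here.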
utils.py ADDED
@@ -0,0 +1,654 @@
+ import torch
+ from typing import Optional, Tuple, Union, List, Callable
+ from transformers.generation.logits_process import LogitsProcessor
+ from transformers.generation.beam_search import BeamSearchScorer
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
+ from transformers.generation.utils import (
+     LogitsProcessorList,
+     StoppingCriteriaList,
+     GenerationConfig,
+     GenerationMixin,
+ )
+ from transformers import LlamaForCausalLM
+ import warnings
+ import torch.distributed as dist
+ from torch import nn
+ import copy
+
+
+ class SteamGenerationMixin(LlamaForCausalLM):
+     # support for streaming generation
+     # TODO: group_beam_search
+     @torch.no_grad()
+     def stream_generate(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         prefix_allowed_tokens_fn: Optional[
+             Callable[[int, torch.Tensor], List[int]]
+         ] = None,
+         **kwargs,
+     ):
+         self._reorder_cache = self.base_model._reorder_cache
+         if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
+             synced_gpus = True
+         else:
+             synced_gpus = False
+
+         if kwargs.get("attention_mask", None) is not None:
+             # concat prompt attention mask
+             prefix_attention_mask = torch.ones(
+                 kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
+             ).to(kwargs["input_ids"].device)
+             kwargs["attention_mask"] = torch.cat(
+                 (prefix_attention_mask, kwargs["attention_mask"]), dim=1
+             )
+         if kwargs.get("position_ids", None) is not None:
+             warnings.warn(
+                 "Position ids are not supported for parameter efficient tuning. Ignoring position ids."
+             )
+             kwargs["position_ids"] = None
+         if kwargs.get("token_type_ids", None) is not None:
+             warnings.warn(
+                 "Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
+             )
+             kwargs["token_type_ids"] = None
+
+         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+         if generation_config is None:
+             generation_config = self.generation_config
+         generation_config = copy.deepcopy(generation_config)
+         model_kwargs = generation_config.update(**kwargs)
+
+         bos_token_id, eos_token_id, pad_token_id = (
+             generation_config.bos_token_id,
+             generation_config.eos_token_id,
+             generation_config.pad_token_id,
+         )
+
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+
+         has_default_max_length = (
+             kwargs.get("max_length") is None
+             and generation_config.max_length is not None
+         )
+         if has_default_max_length and generation_config.max_new_tokens is None:
+             warnings.warn(
+                 f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                 "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                 " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                 UserWarning,
+             )
+         elif generation_config.max_new_tokens is not None:
+             generation_config.max_length = (
+                 generation_config.max_new_tokens + input_ids_seq_length
+             )
+             if generation_config.min_new_tokens is not None:
+                 generation_config.min_length = (
+                     generation_config.min_new_tokens + input_ids_seq_length
+                 )
+
+         if input_ids_seq_length >= generation_config.max_length:
+             input_ids_string = (
+                 "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+             )
+             warnings.warn(
+                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                 f" {generation_config.max_length}. This can lead to unexpected behavior; consider increasing `max_new_tokens`."
+             )
+
+         # 2. Set generation parameters if not already defined
+         logits_processor = (
+             logits_processor if logits_processor is not None else LogitsProcessorList()
+         )
+         stopping_criteria = (
+             stopping_criteria
+             if stopping_criteria is not None
+             else StoppingCriteriaList()
+         )
+         # 7. determine generation mode
+         is_constraint_gen_mode = (
+             generation_config.constraints is not None or generation_config.force_words_ids is not None
+         )
+
+         is_contrastive_search_gen_mode = (
+             generation_config.top_k is not None
+             and generation_config.top_k > 1
+             and generation_config.do_sample is False
+             and generation_config.penalty_alpha is not None
+             and generation_config.penalty_alpha > 0
+         )
+
+         is_greedy_gen_mode = (
+             (generation_config.num_beams == 1)
+             and (generation_config.num_beam_groups == 1)
+             and generation_config.do_sample is False
+             and not is_constraint_gen_mode
+             and not is_contrastive_search_gen_mode
+         )
+         # beam=1 and do_sample=True
+         is_sample_gen_mode = (
+             (generation_config.num_beams == 1)
+             and (generation_config.num_beam_groups == 1)
+             and generation_config.do_sample is True
+             and not is_constraint_gen_mode
+             and not is_contrastive_search_gen_mode
+         )
+         is_beam_gen_mode = (
+             (generation_config.num_beams > 1)
+             and (generation_config.num_beam_groups == 1)
+             and generation_config.do_sample is False
+             and not is_constraint_gen_mode
+             and not is_contrastive_search_gen_mode
+         )
+         is_beam_sample_gen_mode = (
+             (generation_config.num_beams > 1)
+             and (generation_config.num_beam_groups == 1)
+             and generation_config.do_sample is True
+             and not is_constraint_gen_mode
+             and not is_contrastive_search_gen_mode
+         )
+         is_group_beam_gen_mode = (
+             (generation_config.num_beams > 1)
+             and (generation_config.num_beam_groups > 1)
+             and not is_constraint_gen_mode
+             and not is_contrastive_search_gen_mode
+         )
+         # 8. prepare distribution pre_processing samplers
+         logits_processor = self._get_logits_processor(
+             generation_config=generation_config,
+             input_ids_seq_length=input_ids_seq_length,
+             encoder_input_ids=input_ids,
+             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+             logits_processor=logits_processor,
+         )
+         # 9. prepare stopping criteria
+         stopping_criteria = self._get_stopping_criteria(
+             generation_config=generation_config, stopping_criteria=stopping_criteria
+         )
+         logits_warper = self._get_logits_warper(generation_config)
+
+         if is_greedy_gen_mode:
+             # 11. run greedy search
+             return self.greedy_search(
+                 input_ids,
+                 logits_processor,
+                 stopping_criteria,
+                 generation_config,
+                 synced_gpus,
+                 **model_kwargs,
+             )
+         elif is_sample_gen_mode:
+             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+             input_ids, model_kwargs = self._expand_inputs_for_generation(
+                 input_ids=input_ids,
+                 expand_size=generation_config.num_return_sequences,
+                 is_encoder_decoder=self.config.is_encoder_decoder,
+                 **model_kwargs,
+             )
+             return self.stream_sample(
+                 generation_config,
+                 input_ids,
+                 logits_processor,
+                 logits_warper,
+                 stopping_criteria,
+                 synced_gpus,
+                 **model_kwargs,
+             )
+         elif is_beam_gen_mode:
+             return self.beam_search(
+                 generation_config,
+                 input_ids,
+                 logits_processor,
+                 stopping_criteria,
+                 synced_gpus,
+                 **model_kwargs,
+             )
+         elif is_beam_sample_gen_mode:
+             # interleave input_ids with `num_beams` additional sequences per batch
+             return self.beam_sample(
+                 input_ids,
+                 logits_processor,
+                 logits_warper,
+                 stopping_criteria,
+                 generation_config,
+                 synced_gpus,
+                 **model_kwargs,
+             )
+         else:
+             raise NotImplementedError('stream_generate does not support this generation mode')
+
+     def stream_sample(
+         self,
+         generation_config,
+         input_ids,
+         logits_processor,
+         logits_warper,
+         stopping_criteria,
+         synced_gpus,
+         **model_kwargs,
+     ):
+         bos_token_id, eos_token_id, pad_token_id = (
+             generation_config.bos_token_id,
+             generation_config.eos_token_id,
+             generation_config.pad_token_id,
+         )
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+         # keep track of which sequences are already finished
+         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+         this_peer_finished = False  # used by synced_gpus only
+         scores = ()
+         # auto-regressive generation
+         while True:
+             if synced_gpus:
+                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+                 # The following logic allows an early break if all peers finished generating their sequence
+                 this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
+                 # send 0.0 if we finished, 1.0 otherwise
+                 dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                 # did all peers finish? the reduced sum will be 0.0 then
+                 if this_peer_finished_flag.item() == 0.0:
+                     break
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+             )
+             if synced_gpus and this_peer_finished:
+                 continue  # don't waste resources running the code we don't need
+             next_token_logits = outputs.logits[:, -1, :]
+             # pre-process distribution
+             next_token_scores = logits_processor(input_ids, next_token_logits)
+             next_token_scores = logits_warper(input_ids, next_token_scores)
+
+             # sample
+             probs = nn.functional.softmax(next_token_scores, dim=-1)
+             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+             # update generated ids, model inputs, and length for next step
+             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+             )
+             yield input_ids
+             # torch.cuda.empty_cache()
+             # if eos_token was found in one sentence, set sentence to finished
+             if eos_token_id_tensor is not None:
+                 unfinished_sequences = unfinished_sequences.mul(
+                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                 )
+
+             # stop when each sentence is finished, or if we exceed the maximum length
+             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                 if not synced_gpus:
+                     break
+                 else:
+                     this_peer_finished = True
+         return input_ids
+
+     def empty_cache(self):
+         torch.cuda.empty_cache()
+
+     def beam_sample(
+         self,
+         input_ids,
+         logits_processor,
+         logits_warper,
+         stopping_criteria,
+         generation_config,
+         synced_gpus,
+         **model_kwargs,
+     ):
+         bos_token_id, eos_token_id, pad_token_id = (
+             generation_config.bos_token_id,
+             generation_config.eos_token_id,
+             generation_config.pad_token_id,
+         )
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+         num_beams = generation_config.num_beams
+         batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
+         beam_scorer = BeamSearchScorer(
+             batch_size=batch_size,
+             num_beams=generation_config.num_beams,
+             device=input_ids.device,
+             length_penalty=generation_config.length_penalty,
+             do_early_stopping=generation_config.early_stopping,
+             num_beam_hyps_to_keep=generation_config.num_return_sequences,
+             max_length=generation_config.max_length,
+         )
+         input_ids, model_kwargs = self._expand_inputs_for_generation(
+             input_ids=input_ids,
+             expand_size=generation_config.num_beams * generation_config.num_return_sequences,
+             is_encoder_decoder=self.config.is_encoder_decoder,
+             **model_kwargs,
+         )
+         scores = ()
+         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
+         beam_scores = beam_scores.view((batch_size * num_beams,))
+
+         this_peer_finished = False  # used by synced_gpus only
+         while True:
+             if synced_gpus:
+                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+                 # The following logic allows an early break if all peers finished generating their sequence
+                 this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
+                 # send 0.0 if we finished, 1.0 otherwise
+                 dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                 # did all peers finish? the reduced sum will be 0.0 then
+                 if this_peer_finished_flag.item() == 0.0:
+                     break
+
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+             )
+
+             if synced_gpus and this_peer_finished:
+                 cur_len = cur_len + 1
+                 continue  # don't waste resources running the code we don't need
+
+             next_token_logits = outputs.logits[:, -1, :]
+
+             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
+             # cannot be generated both before and after the `nn.functional.log_softmax` operation.
+             next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
+             next_token_scores = nn.functional.log_softmax(
+                 next_token_logits, dim=-1
+             )  # (batch_size * num_beams, vocab_size)
+
+             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+             # Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
+             # (like top_p) this is indifferent, but on others (like temperature) it is not. For reference, see
+             # https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
+             next_token_scores = logits_warper(input_ids, next_token_scores)
+
+             # reshape for beam search
+             vocab_size = next_token_scores.shape[-1]
+             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
+
+             probs = nn.functional.softmax(next_token_scores, dim=-1)
+
+             next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
+             next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+
+             next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+             next_tokens = torch.gather(next_tokens, -1, _indices)
+
+             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
+             next_tokens = next_tokens % vocab_size
+
+             # stateless
+             beam_outputs = beam_scorer.process(
+                 input_ids,
+                 next_token_scores,
+                 next_tokens,
+                 next_indices,
+                 pad_token_id=pad_token_id,
+                 eos_token_id=eos_token_id,
+                 beam_indices=None,
+             )
+             beam_scores = beam_outputs["next_beam_scores"]
+             beam_next_tokens = beam_outputs["next_beam_tokens"]
+             beam_idx = beam_outputs["next_beam_indices"]
+
+             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+             yield input_ids
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+             )
+             if model_kwargs["past_key_values"] is not None:
+                 model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
+
+             # increase cur_len
+             cur_len = cur_len + 1
+
+             if beam_scorer.is_done or stopping_criteria(input_ids, scores):
+                 if not synced_gpus:
+                     break
+                 else:
+                     this_peer_finished = True
+
+         sequence_outputs = beam_scorer.finalize(
+             input_ids,
+             beam_scores,
+             next_tokens,
+             next_indices,
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             max_length=stopping_criteria.max_length,
+             beam_indices=None,
+         )
+         yield sequence_outputs["sequences"]
+
+     def greedy_search(
+         self,
+         input_ids,
+         logits_processor,
+         stopping_criteria,
+         generation_config,
+         synced_gpus,
+         **model_kwargs,
+     ):
+         # init values
+         bos_token_id, eos_token_id, pad_token_id = (
+             generation_config.bos_token_id,
+             generation_config.eos_token_id,
+             generation_config.pad_token_id,
+         )
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+         # init attention / hidden states / scores tuples
+         scores = ()
+         # keep track of which sequences are already finished
+         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+         this_peer_finished = False  # used by synced_gpus only
+         while True:
+             if synced_gpus:
+                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+                 # The following logic allows an early break if all peers finished generating their sequence
+                 this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
+                 # send 0.0 if we finished, 1.0 otherwise
+                 dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                 # did all peers finish? the reduced sum will be 0.0 then
+                 if this_peer_finished_flag.item() == 0.0:
+                     break
+
+             # prepare model inputs
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             # forward pass to get next token
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+             )
+
+             if synced_gpus and this_peer_finished:
+                 continue  # don't waste resources running the code we don't need
+
+             next_token_logits = outputs.logits[:, -1, :]
+             # pre-process distribution
+             next_tokens_scores = logits_processor(input_ids, next_token_logits)
+             # argmax
+             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+             # finished sentences should have their next token be a padding token
+             if eos_token_id is not None:
+                 if pad_token_id is None:
+                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+             # update generated ids, model inputs, and length for next step
+             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+             )
+             yield input_ids
+             # if eos_token was found in one sentence, set sentence to finished
+             if eos_token_id_tensor is not None:
+                 unfinished_sequences = unfinished_sequences.mul(
+                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                 )
+
+             # stop when each sentence is finished, or if we exceed the maximum length
+             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                 if not synced_gpus:
+                     break
+                 else:
+                     this_peer_finished = True
+         yield input_ids
+
+     def beam_search(
+         self,
+         generation_config,
+         input_ids,
+         logits_processor,
+         stopping_criteria,
+         synced_gpus,
+         **model_kwargs,
+     ):
+         # 10. go into beam search generation modes
+         # 11. prepare beam search scorer
+         bos_token_id, eos_token_id, pad_token_id = (
+             generation_config.bos_token_id,
+             generation_config.eos_token_id,
+             generation_config.pad_token_id,
+         )
+         if isinstance(eos_token_id, int):
+             eos_token_id = [eos_token_id]
+         num_beams = generation_config.num_beams
+         batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+         beam_scorer = BeamSearchScorer(
+             batch_size=batch_size,
+             num_beams=generation_config.num_beams,
+             device=input_ids.device,
+             length_penalty=generation_config.length_penalty,
+             do_early_stopping=generation_config.early_stopping,
+             num_beam_hyps_to_keep=generation_config.num_return_sequences,
+             max_length=generation_config.max_length,
+         )
+         # 12. interleave input_ids with `num_beams` additional sequences per batch
+         input_ids, model_kwargs = self._expand_inputs_for_generation(
+             input_ids=input_ids,
+             expand_size=generation_config.num_beams,
+             is_encoder_decoder=self.config.is_encoder_decoder,
+             **model_kwargs,
+         )
+         # beam_search logits
+         batch_beam_size, cur_len = input_ids.shape
+         if num_beams * batch_size != batch_beam_size:
+             raise ValueError(
+                 f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+             )
+         beam_scores = torch.zeros(
+             (batch_size, num_beams), dtype=torch.float, device=input_ids.device
+         )
+         beam_scores[:, 1:] = -1e9
+         beam_scores = beam_scores.view((batch_size * num_beams,))
+         this_peer_finished = False  # used by synced_gpus only
+         while True:
+             if synced_gpus:
+                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+                 # The following logic allows an early break if all peers finished generating their sequence
+                 this_peer_finished_flag = torch.tensor(
+                     0.0 if this_peer_finished else 1.0
+                 ).to(input_ids.device)
+                 # send 0.0 if we finished, 1.0 otherwise
+                 dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                 # did all peers finish? the reduced sum will be 0.0 then
+                 if this_peer_finished_flag.item() == 0.0:
+                     break
+
+             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             outputs = self(
+                 **model_inputs,
+                 return_dict=True,
+                 output_attentions=False,
+                 output_hidden_states=False,
+             )
+
+             if synced_gpus and this_peer_finished:
+                 cur_len = cur_len + 1
+                 continue  # don't waste resources running the code we don't need
+
+             next_token_logits = outputs.logits[:, -1, :]
+             # next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)  # hack: adjust tokens for Marian.
+             next_token_scores = nn.functional.log_softmax(
+                 next_token_logits, dim=-1
+             )  # (batch_size * num_beams, vocab_size)
+             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+             next_token_scores = next_token_scores_processed + beam_scores[
+                 :, None
+             ].expand_as(next_token_scores)
+
+             # reshape for beam search
+             vocab_size = next_token_scores.shape[-1]
+             next_token_scores = next_token_scores.view(
+                 batch_size, num_beams * vocab_size
+             )
+
+             # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
+             next_token_scores, next_tokens = torch.topk(
+                 next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
+             )
+             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
+             next_tokens = next_tokens % vocab_size
+             # stateless
+             beam_outputs = beam_scorer.process(
+                 input_ids,
+                 next_token_scores,
+                 next_tokens,
+                 next_indices,
+                 pad_token_id=pad_token_id,
+                 eos_token_id=eos_token_id,
+                 beam_indices=None,
+             )
+             beam_scores = beam_outputs["next_beam_scores"]
+             beam_next_tokens = beam_outputs["next_beam_tokens"]
+             beam_idx = beam_outputs["next_beam_indices"]
+
+             input_ids = torch.cat(
+                 [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
+             )
+             model_kwargs = self._update_model_kwargs_for_generation(
+                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+             )
+             if model_kwargs["past_key_values"] is not None:
+                 model_kwargs["past_key_values"] = self._reorder_cache(
+                     model_kwargs["past_key_values"], beam_idx
+                 )
+
+             # increase cur_len
+             cur_len = cur_len + 1
+
+             yield input_ids
+
+             if beam_scorer.is_done or stopping_criteria(input_ids, None):
+                 if not synced_gpus:
+                     break
+                 else:
+                     this_peer_finished = True
+
+         final_result = beam_scorer.finalize(
+             input_ids,
+             beam_scores,
+             next_tokens,
+             next_indices,
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             max_length=stopping_criteria.max_length,
+             beam_indices=None,
+         )
+         yield final_result["sequences"]