myduy committed on
Commit 14c9e5e · verified · 1 Parent(s): 7385b17

Upload generation_utils.py

Files changed (1)
  1. generation_utils.py +464 -0
generation_utils.py ADDED
@@ -0,0 +1,464 @@
+ # coding=utf-8
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import warnings
+ import copy
+ from dataclasses import dataclass
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.distributions as dists
+ from torch.nn import functional as F
+ from transformers import __version__
+ from transformers.generation.configuration_utils import GenerationConfig
+ from transformers.utils import (
+     ModelOutput,
+     is_torchdynamo_compiling,
+     logging,
+ )
+
+ logger = logging.get_logger(__name__)
+
+
+ def top_p_logits(logits, top_p=None):
+     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+     sorted_indices_to_remove = cumulative_probs > top_p
+     # Shift the indices to the right to keep the first token above the threshold
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = 0
+
+     mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
+     mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+     logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
+     return logits
+
+ def top_k_logits(logits, top_k=None):
+     top_k = min(top_k, logits.size(-1))  # Safety check
+     # Remove all tokens with a probability less than the last token of the top-k
+     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+     logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
+     return logits
+
+
+ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
+
+     if temperature > 0:
+         logits = logits / temperature
+     if top_p is not None and top_p < 1:
+         logits = top_p_logits(logits, top_p)
+     if top_k is not None:
+         logits = top_k_logits(logits, top_k)
+     probs = torch.softmax(logits, dim=-1)
+
+     if temperature > 0:
+         try:
+             x0 = dists.Categorical(probs=probs).sample()
+             confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
+         except Exception:
+             # fall back to greedy selection if categorical sampling fails
+             confidence, x0 = probs.max(dim=-1)
+     else:
+         confidence, x0 = probs.max(dim=-1)
+
+     if margin_confidence:
+         sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
+         # Extract top1 and top2 probabilities
+         top1_probs = sorted_probs[:, 0]
+         top2_probs = sorted_probs[:, 1]
+         # Calculate confidence as top1 - top2
+         confidence = top1_probs - top2_probs
+
+     if neg_entropy:
+         epsilon = 1e-10
+         log_probs = torch.log(probs + epsilon)
+         confidence = torch.sum(probs * log_probs, dim=-1)
+
+     return confidence, x0
+
+
+ @dataclass
+ class DreamModelOutput(ModelOutput):
+     sequences: torch.LongTensor = None
+     history: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class DreamGenerationConfig(GenerationConfig):
+     def __init__(self, **kwargs):
+         self.temperature: float = kwargs.pop("temperature", 0.0)
+         self.top_p: Optional[float] = kwargs.pop("top_p", None)
+         self.top_k: Optional[int] = kwargs.pop("top_k", None)
+         self.max_length = kwargs.pop("max_length", 20)
+         self.max_new_tokens = kwargs.pop("max_new_tokens", None)
+         # diffusion specific params
+         self.eps: float = kwargs.pop("eps", 1e-3)
+         self.steps: int = kwargs.pop("steps", 512)
+         self.alg: str = kwargs.pop("alg", 'origin')
+         self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
+
+         # Parameters that define the output variables of `generate`
+         self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
+         self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
+         self.output_history: bool = kwargs.pop("output_history", False)
+
+         # Special tokens that can be used at generation time
+         self.mask_token_id = kwargs.pop("mask_token_id", None)
+         self.pad_token_id = kwargs.pop("pad_token_id", None)
+         self.bos_token_id = kwargs.pop("bos_token_id", None)
+         self.eos_token_id = kwargs.pop("eos_token_id", None)
+
+         # Wild card
+         self.generation_kwargs = kwargs.pop("generation_kwargs", {})
+
+         # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
+         # interface.
+         self._from_model_config = kwargs.pop("_from_model_config", False)
+         self._commit_hash = kwargs.pop("_commit_hash", None)
+         self.transformers_version = kwargs.pop("transformers_version", __version__)
+
+         # Additional attributes without default values
+         if not self._from_model_config:
+             # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
+             # model's default configuration file
+             for key, value in kwargs.items():
+                 try:
+                     setattr(self, key, value)
+                 except AttributeError as err:
+                     logger.error(f"Can't set {key} with value {value} for {self}")
+                     raise err
+
+         # Validate the values of the attributes
+         self.validate(is_init=True)
+
+     def validate(self, is_init=False):
+         pass
+
+ class DreamGenerationMixin:
+     @staticmethod
+     def _expand_inputs_for_generation(
+         expand_size: int = 1,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None
+     ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+         """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+         # Do not call torch.repeat_interleave if expand_size is 1 because it clones
+         # the input tensor and thus requires more memory although no change is applied
+         if expand_size == 1:
+             return input_ids, attention_mask
+         if input_ids is not None:
+             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+         if attention_mask is not None:
+             attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
+         return input_ids, attention_mask
+
+     def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
+         """Performs validation related to the resulting generated length"""
+
+         # Can't throw warnings/exceptions during compilation
+         if is_torchdynamo_compiling():
+             return
+
+         # 1. Max length warnings related to poor parameterization
+         if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
+             # 20 is the default max_length of the generation config
+             warnings.warn(
+                 f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
+                 "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
+                 "generation.",
+                 UserWarning,
+             )
+         if input_ids_length >= generation_config.max_length:
+             input_ids_string = "input_ids"
+             raise ValueError(
+                 f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
+                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                 " increasing `max_length` or, better yet, setting `max_new_tokens`."
+             )
+
+     def _prepare_generated_length(
+         self,
+         generation_config,
+         has_default_max_length,
+         input_ids_length,
+     ):
+         """Prepares max and min length in generation configs to avoid clashes between similar attributes"""
+
+         if generation_config.max_new_tokens is not None:
+             if not has_default_max_length and generation_config.max_length is not None:
+                 logger.warning(
+                     f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                     f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                     "Please refer to the documentation for more information. "
+                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                 )
+             generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+
+         elif has_default_max_length:
+             if generation_config.max_length == DreamGenerationConfig().max_length:
+                 generation_config.max_length = generation_config.max_length + input_ids_length
+                 max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                 if max_position_embeddings is not None:
+                     generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
+
+         return generation_config
+
+     def _prepare_generation_config(
+         self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
+     ) -> DreamGenerationConfig:
+         """
+         Prepares the base generation config, then applies any generation configuration options from kwargs. This
+         function handles retrocompatibility with respect to configuration files.
+         """
+         # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+         using_model_generation_config = False
+         if generation_config is None:
+             generation_config = DreamGenerationConfig.from_model_config(self.config)
+             using_model_generation_config = True
+
+         # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
+         # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
+         # exception will be raised in `_validate_model_kwargs`
+         if not is_torchdynamo_compiling():
+             generation_config = copy.deepcopy(generation_config)
+             _kwargs = generation_config.update(**kwargs)
+             # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
+             if not using_model_generation_config:
+                 if generation_config.bos_token_id is None:
+                     generation_config.bos_token_id = self.generation_config.bos_token_id
+                 if generation_config.eos_token_id is None:
+                     generation_config.eos_token_id = self.generation_config.eos_token_id
+                 if generation_config.pad_token_id is None:
+                     generation_config.pad_token_id = self.generation_config.pad_token_id
+                 if generation_config.mask_token_id is None:
+                     generation_config.mask_token_id = self.generation_config.mask_token_id
+
+         return generation_config
+
+     def _prepare_special_tokens(
+         self,
+         generation_config: DreamGenerationConfig,
+         device: Optional[Union[torch.device, str]] = None,
+     ):
+         """
+         Prepares the special tokens for generation, overwriting the generation config with their processed versions
+         converted to tensor.
+
+         Note that `generation_config` is changed in place and stops being serializable after this method is called.
+         That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
+         function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
+         """
+
+         # Convert special tokens to tensors
+         def _tensor_or_none(token, device=None):
+             if token is None:
+                 return token
+
+             device = device if device is not None else self.device
+             if isinstance(token, torch.Tensor):
+                 return token.to(device)
+             return torch.tensor(token, device=device, dtype=torch.long)
+
+         bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
+         eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
+         pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
+         mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
+
+         # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
+         if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
+             eos_token_tensor = eos_token_tensor.unsqueeze(0)
+
+         # Set pad token if unset (and there are conditions to do so)
+         if pad_token_tensor is None and eos_token_tensor is not None:
+             pad_token_tensor = eos_token_tensor[0]
+             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
+
+         # Update generation config with the updated special tokens tensors
+         # NOTE: this must be written into a different attribute name than the one holding the original special tokens
+         # (in their non-tensor form), in order to enable end-to-end compilation. See
+         # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
+         generation_config._bos_token_tensor = bos_token_tensor
+         generation_config._eos_token_tensor = eos_token_tensor
+         generation_config._pad_token_tensor = pad_token_tensor
+         generation_config._mask_token_tensor = mask_token_tensor
+
+     @torch.no_grad()
+     def diffusion_generate(
+         self,
+         inputs: Optional[torch.Tensor] = None,
+         generation_config: Optional[DreamGenerationConfig] = None,
+         **kwargs,
+     ) -> Union[DreamModelOutput, torch.LongTensor]:
+         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+         generation_config = self._prepare_generation_config(generation_config, **kwargs)
+         generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
+         generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)
+
+         # 2. Define model inputs
+         assert inputs is not None
+         input_ids = inputs
+         device = input_ids.device
+         attention_mask = kwargs.pop("attention_mask", None)
+         self._prepare_special_tokens(generation_config, device=device)
+
+         # 3. Prepare `max_length`.
+         input_ids_length = input_ids.shape[-1]
+         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+         generation_config = self._prepare_generated_length(
+             generation_config=generation_config,
+             has_default_max_length=has_default_max_length,
+             input_ids_length=input_ids_length,
+         )
+
+         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+         # 4. Check input_ids
+         if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
+             warnings.warn(
+                 "You are calling .generate() with the `input_ids` being on a device type different"
+                 f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
+                 f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
+                 " Please make sure that you have put `input_ids` to the"
+                 f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
+                 " running `.generate()`.",
+                 UserWarning,
+             )
+         if (
+             hasattr(generation_config, "pad_token_id") and
+             torch.any(input_ids == generation_config.pad_token_id) and
+             attention_mask is None
+         ):
+             warnings.warn(
+                 "Padding was detected but no attention mask is passed here. For correct "
+                 "generation results, please set `attention_mask` when batch-padding inputs.",
+                 UserWarning,
+             )
+
+         input_ids, attention_mask = self._expand_inputs_for_generation(
+             expand_size=generation_config.num_return_sequences,
+             input_ids=input_ids,
+             attention_mask=attention_mask
+         )
+
+         result = self._sample(
+             input_ids,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             generation_tokens_hook_func=generation_tokens_hook_func,
+             generation_logits_hook_func=generation_logits_hook_func
+         )
+         return result
+
+     def _sample(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.LongTensor],
+         generation_config: DreamGenerationConfig,
+         generation_tokens_hook_func,
+         generation_logits_hook_func
+     ) -> Union[DreamModelOutput, torch.LongTensor]:
+         # init values
+         output_history = generation_config.output_history
+         return_dict_in_generate = generation_config.return_dict_in_generate
+         max_length = generation_config.max_length
+         mask_token_id = generation_config.mask_token_id
+         steps = generation_config.steps
+         eps = generation_config.eps
+         alg = generation_config.alg
+         alg_temp = generation_config.alg_temp
+         temperature = generation_config.temperature
+         top_p = generation_config.top_p
+         top_k = generation_config.top_k
+
+         histories = [] if (return_dict_in_generate and output_history) else None
+
+         # pad input_ids to max_length
+         x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
+
+         if attention_mask is not None and torch.any(attention_mask == 0.0):
+             # we do not mask the [MASK] tokens so value = 1.0
+             attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
+             tok_idx = attention_mask.long().cumsum(-1) - 1
+             tok_idx.masked_fill_(attention_mask == 0, 1)
+             # attention_mask is of shape [B, N]
+             # broadcast to [B, 1, N, N]
+             attention_mask = torch.logical_and(
+                 attention_mask.unsqueeze(1).unsqueeze(-2),
+                 attention_mask.unsqueeze(1).unsqueeze(-1),
+             )
+         else:
+             tok_idx = None
+             attention_mask = "full"
+
+         timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
+
+         # this allows user-defined token control of the intermediate steps
+         x = generation_tokens_hook_func(None, x, None)
+         for i in range(steps):
+             mask_index = (x == mask_token_id)
+             logits = self(x, attention_mask, tok_idx).logits
+             # shift logits one position right so that logits[:, i] holds the prediction for x[:, i]
+             logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
+
+             # this allows user-defined logits control of the intermediate steps
+             logits = generation_logits_hook_func(i, x, logits)
+
+             mask_logits = logits[mask_index]
+             t = timesteps[i]
+             s = timesteps[i + 1]
+
+             if alg == 'origin':
+                 p_transfer = 1 - s / t if i < steps - 1 else 1
+                 x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
+                 transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
+                 _, x0[transfer_index_t_s] = sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
+                 x[mask_index] = x0.clone()
+             else:
+                 if alg == 'maskgit_plus':
+                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
+                 elif alg == 'topk_margin':
+                     confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
+                 elif alg == 'entropy':
+                     confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
+                 else:
+                     raise RuntimeError(f"Unknown alg: {alg}")
+                 num_mask_token = mask_index.sum() / mask_index.shape[0]
+                 number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
+                 full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
+                 full_confidence[mask_index] = confidence
+                 if number_transfer_tokens > 0:
+                     if alg_temp is None or alg_temp == 0:
+                         _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
+                     else:
+                         full_confidence = full_confidence / alg_temp
+                         full_confidence = F.softmax(full_confidence, dim=-1)
+                         transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
+                     x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
+                     x_[mask_index] = x0.clone()
+                     row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
+                     x[row_indices, transfer_index] = x_[row_indices, transfer_index]
+
+             # this allows user-defined token control of the intermediate steps
+             x = generation_tokens_hook_func(i, x, logits)
+
+             if histories is not None:
+                 histories.append(x.clone())
+
+         if return_dict_in_generate:
+             return DreamModelOutput(
+                 sequences=x,
+                 history=histories,
+             )
+         else:
+             return x
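
For reference, a minimal usage sketch of the `diffusion_generate` entry point defined in this file. It assumes a Dream-style checkpoint that ships this module as custom code and whose model class mixes in `DreamGenerationMixin`; the checkpoint id, prompt, and hyperparameter values below are illustrative and not part of this commit.

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint id; any Dream-style repo that bundles this generation_utils.py should work.
model_name = "Dream-org/Dream-v0-Instruct-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True).eval()

inputs = tokenizer("Write a short poem about diffusion models.", return_tensors="pt")

# diffusion_generate pads the prompt with mask tokens up to max_length, then denoises
# them over `steps` iterations; `alg` selects how tokens are committed at each step
# ('origin', 'maskgit_plus', 'topk_margin', or 'entropy').
out = model.diffusion_generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=128,
    steps=128,
    temperature=0.2,
    top_p=0.95,
    alg="entropy",
    alg_temp=0.0,
    return_dict_in_generate=True,
)
print(tokenizer.decode(out.sequences[0, inputs.input_ids.shape[1]:].tolist(), skip_special_tokens=True))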