athms committed
Commit cac427a · verified · 1 Parent(s): 3e7a3bf

Delete generation_utils.py

Files changed (1)
  1. generation_utils.py +0 -149
generation_utils.py DELETED
@@ -1,149 +0,0 @@
- """
- RND1 Generation Utilities.
-
- This module provides generation utilities and mixins for RND1 models,
- including the main GenerationMixin class that integrates with HuggingFace.
- """
-
- import torch
- import torch.nn as nn
- from typing import Optional, Union, Dict, Any
- from transformers import GenerationMixin as HFGenerationMixin
- from transformers.generation import GenerationConfig
-
- from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
-
-
- class RND1GenerationMixin(HFGenerationMixin):
-     """
-     Generation mixin for RND1 models.
-
-     This mixin provides generation methods compatible with HuggingFace's
-     generation API while using RND1's diffusion-based sampling internally.
-     """
-
-     def generate(
-         self,
-         inputs: Optional[torch.LongTensor] = None,
-         generation_config: Optional[GenerationConfig] = None,
-         # RND1-specific parameters
-         prefix_ids: Optional[torch.LongTensor] = None,
-         suffix_ids: Optional[torch.LongTensor] = None,
-         infill_length: Optional[int] = None,
-         return_dict_in_generate: Optional[bool] = None,
-         **kwargs,  # Accept all kwargs to be compatible with pipelines
-     ) -> Union[torch.LongTensor, Dict[str, Any]]:
-         """
-         Generate text using RND1's diffusion-based sampling.
-
-         Follows HuggingFace's standard generate API, using diffusion sampling
-         internally. Supports both standard generation and infilling.
-
-         Args:
-             inputs: Input token IDs to use as prefix (standard HF parameter)
-             generation_config: Generation configuration object
-             prefix_ids: Alternative to inputs for infilling tasks
-             suffix_ids: Optional suffix for infilling tasks
-             infill_length: Length of infill region (for infilling)
-             return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
-             **kwargs: Additional arguments (accepted for compatibility)
-
-         Returns:
-             Generated token IDs or GenerateDecoderOnlyOutput
-         """
-         if generation_config is not None:
-             gen_config = generation_config
-             model_kwargs = kwargs.copy()
-         else:
-             # Only prepare config from kwargs if no config was provided
-             gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
-
-         device = next(self.parameters()).device
-
-         if inputs is not None:
-             prefix_ids = inputs.to(device)
-         elif prefix_ids is not None:
-             prefix_ids = prefix_ids.to(device)
-         else:
-             prefix_ids = None
-
-         if suffix_ids is not None:
-             suffix_ids = suffix_ids.to(device)
-
-         eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
-         pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
-         bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
-         mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
-
-         if infill_length is not None and prefix_ids is not None:
-             # Infilling mode: use specified infill_length
-             prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
-             suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
-             seq_len = prefix_len + infill_length + suffix_len
-         else:
-             # Standard generation mode
-             if prefix_ids is not None:
-                 prefix_len = prefix_ids.shape[1]
-                 if gen_config.max_new_tokens is not None:
-                     seq_len = prefix_len + gen_config.max_new_tokens
-                 else:
-                     seq_len = gen_config.max_length or self.config.max_position_embeddings
-             else:
-                 seq_len = gen_config.max_length or self.config.max_position_embeddings
-
-         num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
-                                       getattr(self.config, "num_diffusion_steps", 256))
-
-         temperature = float(getattr(gen_config, "temperature", 1.0))
-         top_k = getattr(gen_config, "top_k", None)
-         top_p = getattr(gen_config, "top_p", None)
-
-         greedy = getattr(gen_config, "greedy",
-                          not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
-
-         generator = model_kwargs.get("generator", None)
-         if generator is None:
-             seed = getattr(gen_config, "seed", None)
-             if seed is not None:
-                 generator = torch.Generator(device=device)
-                 generator.manual_seed(seed)
-
-         with torch.inference_mode():
-             sequences = diffusion_sample(
-                 model=self,
-                 seq_len=seq_len,
-                 num_steps=num_diffusion_steps,
-                 mask_token_id=mask_token_id,
-                 temperature=temperature,
-                 top_k=top_k,
-                 top_p=top_p,
-                 greedy=greedy,
-                 prefix_ids=prefix_ids,
-                 suffix_ids=suffix_ids,
-                 infill_length=infill_length,
-                 eos_token_id=eos_token_id,
-                 pad_token_id=pad_token_id,
-                 bos_token_id=bos_token_id,
-                 device=device,
-                 generator=generator,
-                 visualizer=model_kwargs.get("visualizer", None),  # Optional visualizer from kwargs
-             )
-
-         if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
-             from transformers.generation.utils import GenerateDecoderOnlyOutput
-             return GenerateDecoderOnlyOutput(sequences=sequences)
-
-         return sequences
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: torch.LongTensor,
-         **kwargs,
-     ) -> Dict[str, Any]:
-         """
-         Prepare inputs for generation (required by HuggingFace).
-
-         For RND1, we don't use the standard autoregressive generation,
-         so this just returns the input_ids.
-         """
-         return {"input_ids": input_ids}