sethuiyer committed on
Commit 85f4970
Parent: 4de3c00

Delete entropic_cot.py

Files changed (1)
  1. entropic_cot.py +0 -245
entropic_cot.py DELETED
@@ -1,245 +0,0 @@
- import torch
- import torch.nn.functional as F
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from typing import List, Dict, Tuple, Optional, NamedTuple
- from enum import Enum, auto
- from dataclasses import dataclass
- import warnings
- warnings.filterwarnings("ignore", category=FutureWarning)
-
-
- class DecoderState(Enum):
-     GREEDY_UNTIL_NEWLINE = auto()
-     SELECT_AFTER_NEWLINE = auto()
-     TERMINATED = auto()
-
- class CacheState(NamedTuple):
-     past_key_values: Tuple
-     last_position: int
-
- @dataclass
- class GenerationState:
-     tokens: torch.Tensor
-     attention_mask: torch.Tensor
-     cache_state: Optional[CacheState] = None
-     entropy_diffs: List[float] = None
-     current_length: int = 0
-     _token_buffer: Optional[torch.Tensor] = None
-     _attn_buffer: Optional[torch.Tensor] = None
-
-     def __post_init__(self):
-         self.entropy_diffs = []
-         # Pre-allocate buffers for token and attention mask growth
-         max_length = self.tokens.size(1) + 1024  # reasonable buffer size
-         self._token_buffer = torch.zeros(
-             (1, max_length),
-             dtype=self.tokens.dtype,
-             device=self.tokens.device
-         )
-         self._attn_buffer = torch.ones(
-             (1, max_length),
-             dtype=self.attention_mask.dtype,
-             device=self.attention_mask.device
-         )
-         # Copy initial tokens and attention mask
-         self._token_buffer[:, :self.tokens.size(1)] = self.tokens
-         self._attn_buffer[:, :self.attention_mask.size(1)] = self.attention_mask
-
-     def extend(self, new_token: torch.Tensor):
-         """Efficient in-place extension of state"""
-         current_len = self.tokens.size(1)
-         if len(new_token.shape) == 0:
-             new_token = new_token.unsqueeze(0)
-
-         # Use pre-allocated buffers
-         self._token_buffer[:, current_len] = new_token
-         self.tokens = self._token_buffer[:, :current_len + 1]
-         self.attention_mask = self._attn_buffer[:, :current_len + 1]
-         self.current_length += 1
-
- class SpeculativeDecoder:
-     def __init__(
-         self,
-         model: AutoModelForCausalLM,
-         tokenizer: AutoTokenizer,
-         device: Optional[torch.device] = None,
-         max_new_tokens: int = 512,
-         k: int = 3,
-         use_cache: bool = True
-     ):
-         self.model = model
-         self.tokenizer = tokenizer
-         self.device = device or next(model.parameters()).device
-         self.max_new_tokens = max_new_tokens
-         self.k = k
-         self.use_cache = use_cache
-
-         # Pre-compute constants
-         self.newline_token = tokenizer.encode("\n", add_special_tokens=False)[0]
-         if tokenizer.pad_token_id is None:
-             tokenizer.pad_token_id = tokenizer.eos_token_id
-
-         # Pre-allocate reusable tensors
-         self.batch_attention_mask = torch.ones(k, 1, dtype=torch.long, device=self.device)
-
-         # Prepare model for inference
-         if hasattr(model, 'eval'):
-             model.eval()
-
-         # Enable Flash Attention if available
-         if hasattr(model, 'enable_flash_attention'):
-             try:
-                 model.enable_flash_attention()
-             except Exception as e:
-                 warnings.warn(f"Failed to enable Flash Attention: {e}")
-
-     @staticmethod
-     @torch.jit.script
-     def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
-         """JIT-compiled entropy calculation"""
-         return -torch.sum(probs * torch.log2(probs + 1e-12), dim=-1)
-
-     def set_k(self, k: int):
-         self.k = k
-         self.batch_attention_mask = torch.ones(k, 1, dtype=torch.long, device=self.device)
-
-     def prepare_inputs(self, messages: List[Dict[str, str]]) -> torch.Tensor:
-         """Efficient input preparation"""
-         if hasattr(self.tokenizer, 'chat_template'):
-             input_text = self.tokenizer.apply_chat_template(
-                 messages,
-                 tokenize=False,
-                 add_generation_prompt=True
-             )
-         else:
-             input_text = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages) + "\nassistant:"
-
-         return self.tokenizer(
-             input_text,
-             return_tensors="pt",
-             padding=False
-         ).input_ids.to(self.device)
-
-     def select_least_entropic_token(self, state: GenerationState) -> Tuple[torch.Tensor, float]:
-         """Optimized token selection with vectorized operations"""
-         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
-             # Initial logits computation
-             outputs = self.model(
-                 input_ids=state.tokens[:, -1:] if state.cache_state else state.tokens,
-                 attention_mask=state.attention_mask,
-                 past_key_values=state.cache_state.past_key_values if state.cache_state else None,
-                 use_cache=True
-             )
-
-             state.cache_state = CacheState(outputs.past_key_values, state.tokens.size(1)) if self.use_cache else None
-
-             # Efficient top-k selection
-             logits = outputs.logits[0, -1]
-             top_k_probs, top_k_indices = torch.topk(F.softmax(logits, dim=-1), self.k)
-
-             # Prepare batch inputs
-             batch_tokens = top_k_indices.unsqueeze(1)
-
-             # Efficient cache expansion
-             if state.cache_state:
-                 batch_past_kv = [
-                     (
-                         layer_past[0].expand(self.k, -1, -1, -1),
-                         layer_past[1].expand(self.k, -1, -1, -1)
-                     )
-                     for layer_past in state.cache_state.past_key_values
-                 ]
-             else:
-                 batch_past_kv = None
-
-             # Single forward pass for all candidates
-             batch_outputs = self.model(
-                 input_ids=batch_tokens,
-                 attention_mask=self.batch_attention_mask,
-                 past_key_values=batch_past_kv,
-                 use_cache=True,
-                 output_attentions=True
-             )
-
-             # Efficient attention processing
-             middle_layer = len(batch_outputs.attentions) // 2
-             batch_attn_probs = F.softmax(
-                 batch_outputs.attentions[middle_layer][:, :, -1, :],
-                 dim=-1
-             )
-
-             # Vectorized entropy calculation
-             old_entropy = self.calculate_entropy(batch_attn_probs[:, :, :-1])
-             new_entropy = self.calculate_entropy(batch_attn_probs)
-
-             # Efficient difference calculation
-             entropy_var = torch.var(
-                 torch.stack([old_entropy, new_entropy]),
-                 dim=0,
-                 keepdim=True
-             ) + 1e-6
-             diffs = ((new_entropy - old_entropy) / entropy_var).mean(dim=-1).squeeze(0)
-             min_idx = diffs.argmin()
-             return top_k_indices[min_idx].unsqueeze(0), diffs[min_idx].item()
-
-     def greedy_decode(self, state: GenerationState) -> torch.Tensor:
-         """Optimized greedy decoding"""
-         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
-             outputs = self.model(
-                 input_ids=state.tokens[:, -1:] if state.cache_state else state.tokens,
-                 attention_mask=state.attention_mask,
-                 past_key_values=state.cache_state.past_key_values if state.cache_state else None,
-                 use_cache=True
-             )
-
-             state.cache_state = CacheState(
-                 outputs.past_key_values,
-                 state.tokens.size(1)
-             ) if self.use_cache else None
-
-             return outputs.logits[0, -1].argmax()
-
-     def __call__(
-         self,
-         messages: List[Dict[str, str]]
-     ) -> Tuple[str, float]:
-         """Main decoding loop with optimized state transitions"""
-         input_ids = self.prepare_inputs(messages)
-
-         state = GenerationState(
-             tokens=input_ids,
-             attention_mask=torch.ones_like(input_ids)
-         )
-
-         current_state = DecoderState.SELECT_AFTER_NEWLINE
-
-         while current_state != DecoderState.TERMINATED and state.current_length < self.max_new_tokens:
-             if current_state == DecoderState.SELECT_AFTER_NEWLINE:
-                 next_token, entropy_diff = self.select_least_entropic_token(state)
-                 state.entropy_diffs.append(entropy_diff)
-                 current_state = DecoderState.GREEDY_UNTIL_NEWLINE
-
-             else:  # GREEDY_UNTIL_NEWLINE
-                 next_token = self.greedy_decode(state)
-
-             if next_token.item() == self.tokenizer.eos_token_id:
-                 current_state = DecoderState.TERMINATED
-             elif next_token.item() == self.newline_token:
-                 current_state = DecoderState.SELECT_AFTER_NEWLINE
-
-             state.extend(next_token)
-
-         # Efficient post-processing
-         generated_ids = state.tokens[0, len(input_ids[0]):]
-         generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
-
-         # Vectorized score calculation
-         if state.entropy_diffs:
-             avg_entropy_diff = torch.tensor(state.entropy_diffs).mean().item()
-         else:
-             avg_entropy_diff = 1.0
-
-         completion_ratio = len(generated_ids) / self.max_new_tokens
-         score = (1.0 / (avg_entropy_diff/100 + 1e-12)) * completion_ratio
-
-         return generated_text, round(score ** 0.33, 4)
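
For reference, the deleted module exposed a single entry point, SpeculativeDecoder: a two-phase decoder that, after every newline, re-scores the top-k candidate tokens by the change in attention entropy at a middle layer and picks the least entropic one, then decodes greedily until the next newline or EOS. Below is a minimal usage sketch as it would have worked before this commit; the model checkpoint and the chat messages are illustrative assumptions, not part of the commit, and the import assumes entropic_cot.py is still on the path.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from entropic_cot import SpeculativeDecoder  # the module removed by this commit

# Hypothetical small chat model; any causal LM with a chat template should work.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

decoder = SpeculativeDecoder(model, tokenizer, max_new_tokens=256, k=3)
messages = [{"role": "user", "content": "Explain why the sky is blue, step by step."}]
text, score = decoder(messages)  # __call__ returns (generated_text, rough confidence score)
print(score)
print(text)

Per the deleted code, the returned score is (1 / (avg_entropy_diff / 100)) * completion_ratio raised to the 0.33 power, so smaller average entropy differences and longer completions both push it higher.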