Enhance benchmark and Cortex modules with new training utilities and improved state management. Update README with example output for Llama-3.2-1B and add training CLI for Cortex module tuning. Refactor scoring functions to reset Cortex state between examples and ensure consistent output. Modify task handling to ensure proper formatting of input data.
0de2901 | """ | |
| Training utilities for supervised Cortex adapter tuning. | |
| These helpers keep the base model frozen and optimize only the modules managed by | |
| CortexSurgeon. They intentionally mirror benchmark log-likelihood scoring so a | |
| small tuning run optimizes the same multiple-choice objective being evaluated. | |
| """ | |
| from __future__ import annotations | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from benchmark.scoring import reset_cortex_state | |

def continuation_log_likelihood(
    model,
    tokenizer,
    context: str,
    continuation: str,
    device: str,
) -> Optional[torch.Tensor]:
    """Differentiable average continuation log-likelihood."""
    ctx_ids = tokenizer.encode(context, add_special_tokens=False)
    full_ids = tokenizer.encode(context + continuation, add_special_tokens=False)
    cont_start = len(ctx_ids)
    cont_length = len(full_ids) - cont_start
    if cont_start <= 0 or cont_length <= 0:
        return None
    input_ids = torch.tensor([full_ids], device=device)
    max_len = getattr(model.config, "max_position_embeddings", 2048)
    if input_ids.shape[1] > max_len:
        # Truncate to the model's context window and drop any continuation
        # tokens that no longer fit.
        input_ids = input_ids[:, :max_len]
        cont_length = min(cont_length, max_len - cont_start)
        if cont_length <= 0:
            return None
    reset_cortex_state(model, batch_size=input_ids.shape[0])
    outputs = model(input_ids)
    logits = outputs.logits
    # Logits at position i predict the token at position i + 1, hence the shift.
    shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :]
    shift_labels = input_ids[0, cont_start : cont_start + cont_length]
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
    return token_log_probs.mean()
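
# Usage sketch (the prompt strings and device are illustrative; any causal LM
# whose config exposes `max_position_embeddings` fits the assumptions above):
#
#   score = continuation_log_likelihood(
#       model, tokenizer, context="Q: 2+2=\nA:", continuation=" 4", device="cuda"
#   )
#   # Returns a scalar tensor attached to the autograd graph, or None when the
#   # context or continuation tokenizes to zero tokens.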

def multiple_choice_loss(
    model,
    tokenizer,
    example: Dict,
    device: str,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    """
    Cross-entropy over continuation log-likelihoods.

    Returns:
        (loss, prediction). If an example cannot be scored, both are None.
    """
    scores: List[torch.Tensor] = []
    for continuation in example["continuations"]:
        score = continuation_log_likelihood(
            model, tokenizer, example["context"], continuation, device
        )
        if score is None:
            return None, None
        scores.append(score)
    # Treat the per-choice average log-likelihoods as classification logits.
    logits = torch.stack(scores).unsqueeze(0)
    gold = torch.tensor([example["gold_idx"]], device=device)
    loss = F.cross_entropy(logits, gold)
    pred = int(logits.argmax(dim=-1).item())
    return loss, pred
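
# The expected example schema, inferred from the field accesses above
# (the values shown are illustrative):
#
#   example = {
#       "context": "The capital of France is",
#       "continuations": [" Paris", " Berlin", " Rome"],
#       "gold_idx": 0,
#   }
#   loss, pred = multiple_choice_loss(model, tokenizer, example, device="cuda")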

def cortex_auxiliary_loss(model) -> torch.Tensor:
    """Collect differentiable auxiliary losses exposed by Cortex modules."""
    device = next(model.parameters()).device
    surgeon = getattr(model, "_cortex_surgeon", None)
    if surgeon is None:
        return torch.tensor(0.0, device=device)
    losses = []
    for module in surgeon.modules.values():
        get_budget_loss = getattr(module, "get_budget_loss", None)
        if get_budget_loss is not None:
            losses.append(get_budget_loss())
    if not losses:
        return torch.tensor(0.0, device=device)
    return torch.stack([loss.to(device) for loss in losses]).sum()
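
# A minimal tuning-step sketch, not the repo's actual training CLI: the
# optimizer choice, learning rate, and the 0.01 auxiliary-loss weight are
# illustrative assumptions. It follows the module docstring: freeze the base
# model, tune only the modules managed by CortexSurgeon, and optimize the same
# multiple-choice objective that the benchmark scores.
#
#   for p in model.parameters():
#       p.requires_grad_(False)                  # freeze the base model
#   cortex_params = [
#       p
#       for m in model._cortex_surgeon.modules.values()
#       for p in m.parameters()
#   ]
#   for p in cortex_params:
#       p.requires_grad_(True)                   # tune only Cortex modules
#   optimizer = torch.optim.AdamW(cortex_params, lr=1e-4)
#
#   for example in train_examples:
#       loss, _ = multiple_choice_loss(model, tokenizer, example, device)
#       if loss is None:
#           continue                             # skip unscorable examples
#       loss = loss + 0.01 * cortex_auxiliary_loss(model)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()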