"""Minimal usage example for the guided-decoding-adapter.

Run from the carnot repo root:

    JAX_PLATFORMS=cpu python exports/guided-decoding-adapter/example.py
"""

import os

# Default JAX_PLATFORMS to "cpu" unless the caller already set it. This is
# done before the remaining imports so the variable is already in place when
# they are loaded (presumably carnot initialises JAX on import -- confirm).
os.environ.setdefault("JAX_PLATFORMS", "cpu")

from unittest.mock import MagicMock

import torch

from carnot.inference.guided_decoding import GuidedDecoder

# Build the decoder from the exported adapter directory. The path is relative,
# which is why the script has to be started from the carnot repo root.
decoder = GuidedDecoder.from_pretrained("exports/guided-decoding-adapter")


# Call counter for _forward. A one-element list is used so the function can
# update it in place without a ``global`` statement.
step = [0]


def _forward(input_ids):
    """Return a mock model output with a single peaked logit.

    While the ``step`` counter is below 3 the peak sits on token 0; from then
    on it sits on token 1, the id this example configures as EOS on the mock
    tokenizer. Each call increments ``step[0]`` by one.

    Args:
        input_ids: Object whose ``shape[1]`` gives the sequence length
            (assumed to be a ``(batch, seq)`` tensor -- only ``shape[1]`` is
            read here).

    Returns:
        A ``MagicMock`` whose ``logits`` attribute is a float tensor of shape
        ``(1, seq_len, 10)`` that is zero everywhere except for a value of
        10.0 at the last position on the peaked token.
    """
    seq_len = input_ids.shape[1]
    logits = torch.zeros(1, seq_len, 10)
    peak_token = 1 if step[0] >= 3 else 0
    logits[0, -1, peak_token] = 10.0
    step[0] += 1

    out = MagicMock()
    out.logits = logits
    return out


# Stand-in for a real model: calling it runs _forward.
model = MagicMock()
model.side_effect = _forward
# Use side_effect rather than return_value so each parameters() call gets a
# fresh iterator; a single shared iterator would be exhausted after first use.
model.parameters = MagicMock(
    side_effect=lambda *args, **kwargs: iter([torch.zeros(1)])
)


# Stand-in tokenizer with just the attributes this example sets up.
tokenizer = MagicMock()
tokenizer.eos_token_id = 1
# encode ignores its arguments and always returns the same three prompt ids.
tokenizer.encode = MagicMock(return_value=torch.tensor([[2, 3, 4]]))
# decode maps the EOS id to the empty string and every other id to "A".
# ``ids`` must expose ``.item()`` (e.g. a one-element tensor); extra keyword
# arguments are accepted and ignored.
tokenizer.decode = MagicMock(
    side_effect=lambda ids, **kw: (
        "" if ids.item() == tokenizer.eos_token_id else "A"
    )
)


# Run guided generation against the mock model/tokenizer and print the fields
# of the returned result.
result = decoder.generate(model, tokenizer, "What is 47 + 28?")
print("Generated:", result.text)
print(
    f"Tokens: {result.tokens_generated} Checks: {result.energy_checks} "
    f"Mean penalty: {result.mean_penalty:.3f} Latency: {result.latency_seconds*1000:.1f}ms"
)