File size: 4,073 Bytes
2fb63bf 6f4449d 1d4c041 3d1d657 6f4449d 2fb63bf 6f4449d 2fb63bf 4537742 2fb63bf 6f4449d 2fb63bf 4537742 2fb63bf 4537742 2fb63bf 4537742 2fb63bf 6f4449d 2fb63bf 4537742 2fb63bf 4537742 3d1d657 2fb63bf 3d1d657 2fb63bf 4537742 2fb63bf 6f4449d 2fb63bf 6f4449d 2fb63bf 4537742 2fb63bf 4537742 2fb63bf 3d1d657 4537742 2fb63bf 8f33f3e 2fb63bf 1d4c041 2fb63bf 1d4c041 2fb63bf 3d1d657 2fb63bf 1d4c041 2fb63bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from collections import defaultdict
from dataclasses import dataclass, field
import time
from typing import Callable, Protocol, Self
@dataclass()
class Expansion:
    """One candidate continuation step: a single token and the cost of taking it."""

    # Token appended to a series when this expansion is taken.
    token: int
    # Added to the series' budget by Series.get_remaining_budget; presumably
    # non-positive (e.g. a log-probability) — TODO confirm against the expander.
    cost: float
@dataclass()
class Series:
    """A partially generated sequence: fixed prefix tokens plus the expansions taken so far.

    While the search branches, several Series can share one ``id``; the id links each
    branch back to the original series it was derived from.
    """

    # Identifier shared by every branch derived from the same original series.
    id: int
    # Prefix tokens that the expansions extend.
    tokens: list[int]
    # Starting budget; get_remaining_budget adds every expansion cost to it.
    budget: float
    # Expansions taken so far, in the order they were taken.
    expansions: list[Expansion] = field(default_factory=list)

    def get_all_tokens(self) -> list[int]:
        """Return a new list: the prefix tokens followed by each expanded token."""
        expanded_tokens = [step.token for step in self.expansions]
        return self.tokens + expanded_tokens

    def get_remaining_budget(self) -> float:
        """Return the starting budget plus the sum of all expansion costs so far."""
        total_cost = sum(step.cost for step in self.expansions)
        return self.budget + total_cost
@dataclass()
class Batch:
    """A group of series handed to a BatchExpander in a single call."""

    # The series to expand together.
    items: list[Series]
@dataclass()
class TokenCandidates:
    """The candidate next steps proposed for one series."""

    # The series the candidates would extend.
    series: Series
    # Candidate expansions; each one that is not stopped becomes its own branch.
    # May be empty, in which case the series is treated as finished.
    expansions: list[Expansion]
@dataclass()
class BatchCandidates:
    """Expander output for a whole batch."""

    # Presumably one entry per series of the input Batch — TODO confirm with expanders.
    items: list[TokenCandidates]
class BatchExpander(Protocol):
    """Structural interface for one expansion step over a batch.

    This is the fundamental operation of the search. It can be implemented with an
    LLM or, for testing, with a list of hardcoded sequences.
    """

    def expand(self, batch: Batch) -> BatchCandidates:
        """Propose candidate next expansions for the series in ``batch``."""
        ...
@dataclass()
class CompletedSequence:
    """Final result for one original series: the series and its finished branches."""

    # The original series, as supplied in the initial batch.
    series: Series
    # One list of expansions per finished branch (compute_expansions skips
    # branches that took no expansions).
    expansions: list[list[Expansion]]
@dataclass()
class CompletedBatch:
    """Search result for a whole batch."""

    # One CompletedSequence per original series.
    items: list[CompletedSequence]
def compute_new_series(
    result: TokenCandidates,
    stopping_criterion: Callable[[Series, Expansion], bool],
) -> tuple[list[Series], list[Series]]:
    """Fork ``result.series`` once per candidate that the stopping criterion lets through.

    Args:
        result: A series together with its candidate expansions.
        stopping_criterion: Returns True when taking ``expansion`` from ``series``
            should end that branch; such candidates are discarded.

    Returns:
        ``(continuing, completed)``. ``continuing`` holds one new Series per surviving
        candidate: same id, tokens and budget as the parent, with the candidate
        appended to a copy of its expansions. ``completed`` is ``[result.series]``
        when no candidate survived, and empty otherwise.
    """
    parent = result.series
    continuing = [
        Series(
            id=parent.id,
            tokens=parent.tokens,  # the prefix list is shared with the parent, not copied
            expansions=parent.expansions + [candidate],
            budget=parent.budget,
        )
        for candidate in result.expansions
        if not stopping_criterion(parent, candidate)
    ]
    # The parent only counts as finished once every one of its candidates was stopped.
    completed = [] if continuing else [parent]
    return continuing, completed
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> CompletedBatch:
    """Group finished branches back under the original series they were derived from.

    Args:
        original_series: The initial series; their ids must be unique.
        expanded_series: Finished branches. Each shares its id with the original
            series it came from. Branches with no expansions are skipped, and
            branches whose id is not among ``original_series`` are ignored.

    Returns:
        A CompletedBatch with one CompletedSequence per original series, in input
        order. Each holds the expansion lists of that series' branches, in the order
        they appear in ``expanded_series``.

    Raises:
        ValueError: If two series in ``original_series`` share an id.
    """
    original_by_id = {series.id: series for series in original_series}
    # An explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, and duplicates would then silently collapse to the last series.
    if len(original_by_id) != len(original_series):
        from collections import Counter

        counts = Counter(series.id for series in original_series)
        duplicated = [series_id for series_id, count in counts.items() if count > 1]
        raise ValueError(f"ids in original_series must be unique; duplicated ids: {duplicated}")

    branches_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for branch in expanded_series:
        if branch.expansions:
            branches_by_id[branch.id].append(branch.expansions)

    # For unique ids, dict insertion order equals the input order of original_series.
    return CompletedBatch(
        items=[
            CompletedSequence(series=series, expansions=branches_by_id.get(series_id, []))
            for series_id, series in original_by_id.items()
        ]
    )
def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
    """Return True when taking ``expansion`` would leave ``series`` with a negative budget.

    The candidate's cost is added to the series' remaining budget. The branch stops
    once that total is strictly below zero.
    """
    budget_after_step = series.get_remaining_budget() + expansion.cost
    return budget_after_step < 0
def expand(batch: Batch, expander: BatchExpander, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> CompletedBatch:
    """Repeatedly expand ``batch`` until every branch has stopped, then group the results.

    A compound operation implemented generically on top of a BatchExpander. Each
    round sends all still-running branches to ``expander``. Every branch then forks
    once per candidate that ``completion_criterion`` does not stop. A branch is
    finished when it receives no candidates or when all of its candidates are stopped.

    Args:
        batch: The initial series; their ids are expected to be unique
            (compute_expansions checks this).
        expander: Produces candidate expansions for a batch.
        completion_criterion: Returns True when a candidate should end its branch.

    Returns:
        One CompletedSequence per series in ``batch``, holding the expansion lists of
        its finished branches.

    Note:
        The loop ends only when every branch stops. An expander/criterion pair that
        never stops a branch will loop forever.
    """
    finished: list[Series] = []
    current_batch = batch
    while current_batch.items:
        print(f"Expanding {len(current_batch.items)} series")
        # perf_counter is monotonic; time.time() can jump if the wall clock is adjusted.
        start = time.perf_counter()
        expanded = expander.expand(current_batch)
        print(f"Expanded, took {time.perf_counter() - start} seconds")

        print("Computing new batch")
        start = time.perf_counter()
        next_items: list[Series] = []
        for item in expanded.items:
            # compute_new_series reports a candidate-less series as completed itself,
            # so an empty ``item.expansions`` needs no special case here.
            continuing, completed = compute_new_series(item, completion_criterion)
            finished.extend(completed)
            next_items.extend(continuing)
        current_batch = Batch(items=next_items)
        print(f"Computed, took {time.perf_counter() - start} seconds")
    return compute_expansions(batch.items, finished)
|