File size: 4,073 Bytes
2fb63bf
6f4449d
1d4c041
3d1d657
6f4449d
 
 
 
 
2fb63bf
 
 
 
 
 
6f4449d
 
 
 
 
 
 
2fb63bf
 
 
 
 
 
4537742
2fb63bf
6f4449d
2fb63bf
 
4537742
 
2fb63bf
 
4537742
 
2fb63bf
 
4537742
2fb63bf
6f4449d
2fb63bf
 
4537742
 
2fb63bf
4537742
3d1d657
2fb63bf
3d1d657
 
 
 
 
 
 
 
 
 
2fb63bf
4537742
2fb63bf
 
 
 
 
6f4449d
2fb63bf
6f4449d
 
2fb63bf
 
 
4537742
2fb63bf
4537742
2fb63bf
3d1d657
 
 
4537742
 
2fb63bf
 
 
8f33f3e
 
2fb63bf
1d4c041
2fb63bf
1d4c041
 
 
2fb63bf
 
 
 
3d1d657
 
 
2fb63bf
1d4c041
2fb63bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from collections import defaultdict
from dataclasses import dataclass, field
import time
from typing import Callable, Protocol, Self

@dataclass
class Expansion:
    """One candidate next token together with the cost of choosing it.

    NOTE(review): elsewhere in this module costs are *added* to a series'
    budget and expansion stops once that total would drop below zero, so
    costs appear to be non-positive (e.g. log-probabilities) -- confirm.
    """

    token: int  # id of the proposed next token
    cost: float  # cost of appending ``token`` (see sign note above)

@dataclass
class Series:
    """A sequence under construction: initial tokens plus the expansions chosen so far.

    Series derived from the same original series keep the same ``id``.
    """

    id: int  # identifier shared by an original series and all series branched from it
    tokens: list[int]  # initial tokens, never modified by this module
    budget: float  # starting budget; expansion costs are added to it
    expansions: list[Expansion] = field(default_factory=list)  # chosen expansions, oldest first

    def get_all_tokens(self) -> list[int]:
        """Return a new list: the initial tokens followed by each expanded token."""
        generated = [step.token for step in self.expansions]
        return self.tokens + generated

    def get_remaining_budget(self) -> float:
        """Return the starting budget plus the summed cost of every expansion so far."""
        spent = sum(step.cost for step in self.expansions)
        return self.budget + spent

@dataclass
class Batch:
    """A group of series handed to a BatchExpander in a single call."""

    items: list[Series]  # the series to expand

@dataclass
class TokenCandidates:
    """The candidate expansions proposed for one series."""

    series: Series  # the series the candidates would extend
    expansions: list[Expansion]  # proposed candidates; empty means none were proposed

@dataclass
class BatchCandidates:
    """Candidates for every series of a batch, as returned by ``BatchExpander.expand``."""

    items: list[TokenCandidates]  # one entry per expanded series

class BatchExpander(Protocol):
    """The fundamental operation: propose candidate expansions for a batch.

    Declared as a structural (duck-typed) interface so it can be implemented
    either with an LLM or, for testing, with a list of hard-coded sequences.
    """

    def expand(self, batch: Batch) -> BatchCandidates:
        """Return candidate expansions for the series in ``batch``.

        An item whose ``expansions`` list is empty is treated by ``expand``
        (the module-level function) as a finished series.
        """
        ...

@dataclass
class CompletedSequence:
    """An original series together with every completed expansion path grown from it."""

    series: Series  # the original, unexpanded series
    expansions: list[list[Expansion]]  # one list of expansions per completed branch

@dataclass
class CompletedBatch:
    """Result of a full expansion run: one CompletedSequence per original series."""

    items: list[CompletedSequence]  # results, one per original series

def compute_new_series(result: TokenCandidates, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
    """Branch one series into a child for every accepted candidate expansion.

    A candidate is accepted when ``stopping_criterion(parent, candidate)`` is
    False. Each accepted candidate yields a new Series with the parent's
    ``id``, ``budget`` and (the same, uncopied) ``tokens`` list, and with the
    candidate appended to a fresh copy of the parent's expansions.

    Returns:
        ``(children, completed)``: the new series, and ``[parent]`` when no
        candidate was accepted (otherwise an empty list).
    """
    parent = result.series
    children = [
        Series(
            id=parent.id,
            tokens=parent.tokens,
            expansions=parent.expansions + [candidate],
            budget=parent.budget,
        )
        for candidate in result.expansions
        if not stopping_criterion(parent, candidate)
    ]
    # A series with no surviving children is a leaf of the search.
    completed = [] if children else [parent]
    return children, completed

def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> CompletedBatch:
    """Group finished series back under the original series they grew from.

    Args:
        original_series: the starting series; their ``id`` values must be unique.
        expanded_series: finished series, matched to an original by ``id``.
            Series with no expansions contribute nothing; series whose id
            matches no original are ignored.

    Returns:
        A CompletedBatch with one CompletedSequence per original series, in
        input order, holding the ``expansions`` list of every matching
        finished series.

    Raises:
        ValueError: if ``original_series`` contains duplicate ids. (An
            ``assert`` would be stripped under ``python -O``, letting duplicates
            silently collapse into one dict entry.)
    """
    original_series_by_id = {s.id: s for s in original_series}
    if len(original_series_by_id) != len(original_series):
        raise ValueError("series ids in original_series must be unique")

    paths_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for finished in expanded_series:
        if finished.expansions:
            paths_by_id[finished.id].append(finished.expansions)

    return CompletedBatch(
        items=[
            # .get avoids inserting empty entries into the defaultdict on lookup.
            CompletedSequence(series=series, expansions=paths_by_id.get(series_id, []))
            for series_id, series in original_series_by_id.items()
        ]
    )

def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
    """Return True when applying ``expansion`` would leave ``series`` with a negative budget.

    A True result means the expansion is rejected by ``compute_new_series``.
    """
    budget_after = series.get_remaining_budget() + expansion.cost
    return budget_after < 0

def expand(
    batch: Batch,
    expander: BatchExpander,
    completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion,
    *,
    log: Callable[[str], None] = print,
) -> CompletedBatch:
    """Repeatedly expand every series in ``batch`` until none can grow further.

    A compound operation built generically on top of a BatchExpander. Each
    round asks ``expander`` for candidates. A series with no candidates is
    finished. Otherwise every candidate not rejected by ``completion_criterion``
    becomes a new branched series for the next round, and a series whose
    candidates are all rejected is finished.

    Args:
        batch: the original series; their ``id`` values must be unique.
        expander: produces candidate expansions for a batch.
        completion_criterion: returns True to reject an expansion of a series.
        log: sink for progress/timing messages. Defaults to ``print``, which
            matches the previous behaviour; pass ``lambda _msg: None`` to
            silence it.

    Returns:
        A CompletedBatch with one entry per original series, listing every
        non-empty completed expansion path for it.

    NOTE(review): termination relies on ``completion_criterion`` eventually
    rejecting all candidates, or on the expander returning none; otherwise
    this loops forever.
    """
    completed_series: list[Series] = []
    current_batch = batch
    while current_batch.items:
        log(f"Expanding {len(current_batch.items)} series")
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with wall-clock adjustments and is unsuitable for measuring durations.
        start_time = time.perf_counter()
        expanded = expander.expand(current_batch)
        log(f"Expanded, took {time.perf_counter() - start_time} seconds")
        log("Computing new batch")
        start_time = time.perf_counter()
        next_items: list[Series] = []
        for item in expanded.items:
            if not item.expansions:
                # The expander proposed nothing: this series is finished.
                completed_series.append(item.series)
                continue
            new_series, completed = compute_new_series(item, completion_criterion)
            completed_series.extend(completed)
            next_items.extend(new_series)
        current_batch = Batch(items=next_items)
        log(f"Computed, took {time.perf_counter() - start_time} seconds")
    return compute_expansions(batch.items, completed_series)