import torch


class MomentumOptim:
    """Momentum update applied element-wise to a list of tensors."""

    def __init__(self, step_size=0.01, momentum=0.9):
        self.step_size = step_size
        self.momentum = momentum
        self.m = None  # velocity, one tensor per entry in the optimized list

    def init(self):
        """Reset the velocity before a new optimization run."""
        self.m = None

    def upd_m(self, old_m, g):
        # m_t = g_t + momentum * m_{t-1}
        return g + self.momentum * old_m

    def upd(self, old_x, m):
        # x_t = x_{t-1} + step_size * m_t
        return old_x + self.step_size * m

    def __call__(self, old_xs, new_xs):
        # Treat the difference between the new and old tensors as a pseudo-gradient.
        pseudo_gs = [new_x - old_x for old_x, new_x in zip(old_xs, new_xs)]

        if self.m is None:
            self.m = pseudo_gs
        else:
            self.m = [self.upd_m(old_m, g) for old_m, g in zip(self.m, pseudo_gs)]

        updated_kv = [self.upd(old_x, m) for old_x, m in zip(old_xs, self.m)]
        return updated_kv


class AttnOptimWrapper:
    """Maintains a past_key_values cache and updates it with a chosen optimizer."""

    def __init__(self, llm, model_type, optimizer="momentum", **optimizer_args):
        self.model = llm
        self.kv = None
        self.model_type = model_type

        if optimizer == "momentum":
            # Keys and values get separate optimizer states.
            self.optim_k = MomentumOptim(**optimizer_args)
            self.optim_v = MomentumOptim(**optimizer_args)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer!r}")

    def init(self):
        # Reset both optimizer states before a fresh sequence of steps.
        self.optim_k.init()
        self.optim_v.init()

    @torch.no_grad()
    def step(self, ctx_ids):
        L = len(ctx_ids)  # ctx_ids: [L]; the same context is fed at every step

        ctx_ids = ctx_ids.unsqueeze(0)  # [1, L]
        mask = torch.ones_like(ctx_ids)  # [1, L]
        if self.kv is not None:
            # The cache holds exactly L past positions (the previous pass over
            # the same context), so the mask must cover old + new tokens.
            mask = mask.repeat(1, 2)  # [1, 2*L]

        next_kv = self.model(
            input_ids=ctx_ids,
            attention_mask=mask,
            past_key_values=self.kv,
            use_cache=True,
        ).past_key_values  # kv @ (old_ctx + new_ctx)

        cur_kv = []
        for layer_k, layer_v in next_kv:
            # layer_k / layer_v: [B, num_head, L or 2*L, head_hidden]
            # Keep only the states of the last L (newly fed) positions.
            cur_kv.append([layer_k[:, :, -L:, :], layer_v[:, :, -L:, :]])  # kv @ (new_ctx)

        if self.kv is None:
            # First pass: simply cache the fresh kv states.
            self.kv = cur_kv
        else:
            # Later passes: move the cached kv toward the fresh kv with a
            # momentum update, per layer, separately for keys and values.
            old_ks, old_vs = zip(*self.kv)
            cur_ks, cur_vs = zip(*cur_kv)

            upd_ks = self.optim_k(old_ks, cur_ks)
            upd_vs = self.optim_v(old_vs, cur_vs)
            self.kv = list(zip(upd_ks, upd_vs))

        return self.kv
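

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a
    # HuggingFace causal LM (here "gpt2", chosen arbitrarily) and a transformers
    # version that returns past_key_values in the legacy per-layer tuple format
    # that `step` iterates over.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    ctx_ids = tok("A demonstration context fed repeatedly", return_tensors="pt").input_ids[0]  # [L]

    wrapper = AttnOptimWrapper(model, model_type="gpt2", optimizer="momentum")
    wrapper.init()
    for _ in range(3):  # a few update steps over the same context
        kv = wrapper.step(ctx_ids)
    # `kv` now holds the momentum-updated past_key_values for downstream use.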