File size: 1,993 Bytes
583c1c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Self-supervised learning: one of the embeddings acts as the target, the others as the support.
# works nicely

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
from tqdm import tqdm
import random

class Transform(nn.Module):
    """Learnable per-(embedding, token) weighted sum over a stack of embeddings.

    Holds a weight matrix of shape ``(n, token_size)``, initialised to ones.
    For an input ``x`` whose einsum subscripts require shape
    ``(n, token_size, d)``, the forward pass returns the ``(token_size, d)``
    tensor ``out[t, :] = sum_k weight[k, t] * x[k, t, :]``.

    Args:
        n: number of stacked embeddings being mixed.
        token_size: number of tokens (second axis of ``x``).
        input_dim: accepted for interface compatibility; currently unused.
    """

    def __init__(self, n=2, token_size=32, input_dim=2048):
        super().__init__()
        self.n = n
        self.token_size = token_size
        # One scalar per (embedding, token). Starting at ones makes the
        # initial output the plain sum over the n embeddings.
        init = torch.ones(n, token_size)
        self.weight = nn.Parameter(init, requires_grad=True)

    def encode(self, x):
        """Collapse the leading (embedding) axis of ``x`` with per-token weights."""
        return torch.einsum("nte,nt->te", x, self.weight)

    def forward(self, x):
        return self.encode(x)
    
def criterion(output, target, token_sample_rate=0.25):
    """Mean L2 distance between ``output`` and ``target`` over a random subset of rows.

    ``target - output`` is reduced with an L2 norm along dim 1, giving one
    distance per row along dim 0. A random subset of those rows is drawn
    without replacement using the global ``random`` module (so it follows
    any ``random.seed`` the caller set), and their distances are averaged.

    Args:
        output: prediction tensor, broadcastable against ``target``;
            presumably ``(tokens, dim)`` — TODO confirm with callers.
        target: reference tensor.
        token_sample_rate: fraction of rows to sample. The sample size is
            clamped to at least one row and at most all rows.

    Returns:
        0-dim tensor: mean distance over the sampled rows.
    """
    dist = torch.norm(target - output, dim=1)
    num_rows = dist.shape[0]
    # int() floors, so small inputs (e.g. < 4 rows at rate 0.25) would sample
    # zero rows; torch.mean of an empty tensor is NaN and would poison
    # training. Also cap at num_rows so rates > 1 do not make
    # random.sample raise.
    k = min(num_rows, max(1, int(token_sample_rate * num_rows)))
    idx = random.sample(range(num_rows), k)
    return torch.mean(dist[idx])

def online_train(cond, device="cuda:1",step=1000):
    """Fit per-token weights so the support embeddings reconstruct ``cond[0]``, then blend.

    ``cond[0]`` is the target ``y`` and ``cond[1:]`` are the ``n`` support
    embeddings. A ``Transform`` is trained for ``step`` AdamW steps so that
    ``sum_k w[k, t] * cond[1 + k, t, :]`` matches ``y`` under ``criterion``.
    The learned weights are then applied and the result is averaged over
    the supports::

        out = mean_k(w[k, :, None] * cond[1 + k] + y / n)

    Args:
        cond: stack of embeddings; the indexing requires a 3-D tensor,
            presumably ``(num_embeddings, tokens, dim)`` — TODO confirm.
        device: device the optimisation runs on.
        step: number of optimisation steps.

    Returns:
        Tensor of shape ``(1, tokens, dim)`` on ``cond``'s original device and
        dtype, detached from autograd. If ``cond`` holds fewer than two
        embeddings there is nothing to fit against, and an unchanged copy of
        ``cond`` is returned.
    """
    old_device = cond.device
    dtype = cond.dtype
    if cond.shape[0] < 2:
        # With no supports, n would be 0 and the blend below divides by zero.
        return cond.clone()
    cond = cond.clone().to(device, torch.float32)

    y = cond[0, :, :]
    cond = cond[1:, :, :]

    print("online training, initializing model...")
    n = cond.shape[0]
    # Size the weights from the data; Transform's default token_size (32)
    # would make the einsum fail for any other token count.
    model = Transform(n=n, token_size=cond.shape[1]).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)
    model.train()

    # criterion() samples rows via the global `random` module: seed it for
    # reproducible training, but restore the caller's RNG state afterwards.
    rng_state = random.getstate()
    random.seed(42)
    try:
        # The caller may be inside torch.no_grad() (e.g. an inference loop);
        # backward() needs grad mode enabled.
        with torch.enable_grad():
            bar = tqdm(range(step))
            for _ in bar:
                optimizer.zero_grad()
                output = model(cond)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
                bar.set_postfix(loss=loss.item())
    finally:
        random.setstate(rng_state)

    # Detach so the result carries no autograd graph (and no reference) back
    # to the model parameters.
    weight = model.weight.detach()
    print(weight)
    # NOTE(review): if the fit is exact (sum_k w_k * cond_k == y), the mean
    # below equals 2 * y / n, i.e. y only when n == 2 — confirm this scaling
    # is intended.
    blended = weight[:, :, None] * cond + y[None, :, :] * (1.0 / n)

    print("online training, ending...")
    del model
    del optimizer

    return blended.mean(dim=0, keepdim=True).to(old_device, dtype=dtype)