# Self-supervised learning: the first embedding acts as the target, the remaining
# embeddings act as the support.
# works nicely
import random

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


class Transform(nn.Module):
    """Learns one scalar weight per (support embedding, token) pair and mixes the
    support embeddings into a single target-shaped output."""

    def __init__(self, n=2, token_size=32, input_dim=2048):
        super().__init__()
        self.n = n
        self.token_size = token_size
        # input_dim is unused by the einsum below; kept for API compatibility.
        self.input_dim = input_dim
        self.weight = nn.Parameter(torch.ones(self.n, self.token_size), requires_grad=True)

    def encode(self, x):
        # x: (n, token_size, hidden_dim); weight: (n, token_size)
        # -> weighted sum over the n support embeddings: (token_size, hidden_dim)
        x = torch.einsum('bij,bi->ij', x, self.weight)
        return x

    def forward(self, x):
        x = self.encode(x)
        return x


def criterion(output, target, token_sample_rate=0.25):
    # Per-token L2 distance between the mixed output and the target embedding,
    # averaged over a random subset of tokens.
    t = target - output
    t = torch.norm(t, dim=1)
    k = max(1, int(token_sample_rate * t.shape[0]))  # avoid an empty sample
    s = random.sample(range(t.shape[0]), k)
    return torch.mean(t[s])


def online_train(cond, device="cuda:1", step=1000):
    # cond: (num_embeddings, num_tokens, hidden_dim); the first embedding is the
    # target, the rest form the support set.
    old_device = cond.device
    dtype = cond.dtype
    cond = cond.detach().clone().to(device, torch.float32)
    y = cond[0, :, :]
    cond = cond[1:, :, :]
    print("online training, initializing model...")
    n = cond.shape[0]
    model = Transform(n=n, token_size=cond.shape[1])
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)
    model.to(device)
    model.train()
    random.seed(42)
    bar = tqdm(range(step))
    for _ in bar:
        optimizer.zero_grad()
        x = cond
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        bar.set_postfix(loss=loss.item())
    weight = model.weight.detach()  # drop the autograd graph before reusing the weights
    print(weight)
    # Re-weight each support embedding per token and blend in the target uniformly.
    cond = weight[:, :, None] * cond + y[None, :, :] * (1.0 / n)
    print("online training, ending...")
    del model
    del optimizer
    cond = torch.mean(cond, dim=0).unsqueeze(0)
    return cond.to(old_device, dtype=dtype)
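

# --- Hypothetical usage sketch (not part of the original script) ---
# The shapes below are illustrative assumptions inferred from the indexing in
# online_train: cond stacks several embeddings of shape (num_tokens, hidden_dim)
# along dim 0, with cond[0] as the target and the rest as the support set.
# The CPU device override and step count are also assumptions.
if __name__ == "__main__":
    num_embeddings, num_tokens, hidden_dim = 4, 32, 2048
    cond = torch.randn(num_embeddings, num_tokens, hidden_dim, dtype=torch.float16)
    fused = online_train(cond, device="cpu", step=100)
    print(fused.shape)  # expected: torch.Size([1, num_tokens, hidden_dim])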