# Self-supervised learning: one of the embeddings acts as the target,
# the others act as the support. Works nicely.
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
class Transform(nn.Module):
    """Learns one mixing weight per (support embedding, token) pair."""

    def __init__(self, n=2, token_size=32, input_dim=2048):
        super().__init__()
        self.n = n
        self.token_size = token_size
        # One scalar weight per support embedding and token position.
        # Note: input_dim is accepted but unused by the parameterization.
        self.weight = nn.Parameter(torch.ones(self.n, self.token_size), requires_grad=True)

    def encode(self, x):
        # x: (n, token_size, input_dim); weight: (n, token_size).
        # Weighted sum over the n support embeddings -> (token_size, input_dim).
        x = torch.einsum('bij,bi->ij', x, self.weight)
        return x

    def forward(self, x):
        x = self.encode(x)
        return x
def criterion(output, target, token_sample_rate=0.25):
    # Per-token L2 distance between the target and the mixed output.
    t = target - output
    t = torch.norm(t, dim=1)
    # Average the loss over a random subset of token positions.
    s = random.sample(range(t.shape[0]), int(token_sample_rate * t.shape[0]))
    return torch.mean(t[s])
def online_train(cond, device="cuda:1", step=1000):
    old_device = cond.device
    dtype = cond.dtype
    cond = cond.clone().to(device, torch.float32)
    # cond.requires_grad=False
    # torch.set_grad_enabled(True)
    # The first embedding is the target; the rest form the support set.
    y = cond[0, :, :]
    cond = cond[1:, :, :]
    print("online training, initializing model...")
    n = cond.shape[0]
    # Size the per-token weights to the actual token count, not the default.
    model = Transform(n=n, token_size=cond.shape[1])
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)
    model.to(device)
    model.train()
    random.seed(42)
    bar = tqdm(range(step))
    for s in bar:
        optimizer.zero_grad()
        x = cond
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        bar.set_postfix(loss=loss.item())
    # Detach so the blended result carries no autograd history.
    weight = model.weight.detach()
    print(weight)
    # Blend each support embedding with the target using the learned weights.
    cond = weight[:, :, None] * cond + y[None, :, :] * (1.0 / n)
    print("online training, ending...")
    del model
    del optimizer
    # Average the blended supports back into a single embedding.
    cond = torch.mean(cond, dim=0).unsqueeze(0)
    return cond.to(old_device, dtype=dtype)
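

# Minimal usage sketch (an assumption, not part of the original module): the
# shapes below — 4 stacked embeddings of 77 tokens x 2048 dims — are
# hypothetical placeholders for whatever conditioning tensors the caller
# actually stacks. The first slice is treated as the target, the rest as
# the support set.
if __name__ == "__main__":
    # Stack of embeddings: (num_embeddings, tokens, dim).
    dummy_cond = torch.randn(4, 77, 2048, dtype=torch.float16)
    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    fused = online_train(dummy_cond, device=dev, step=100)
    # Expected: (1, 77, 2048), returned on the original device and dtype.
    print(fused.shape, fused.dtype)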