File size: 490 Bytes
c1a7f73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from torch import nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(num_steps=None, num_embeddings=180, embedding_dim=128,
          num_labels=2048, log_every=100, device=device):
    """Fit an embedding table plus a linear head to fixed random 0/1 targets.

    Every row of the embedding table is projected by a linear ``head`` and
    squashed with a sigmoid. The result is regressed (MSE) onto a random
    binary target matrix of shape ``(num_embeddings, num_labels)`` using Adam.

    Args:
        num_steps: Number of optimizer steps to run. ``None`` (the default)
            loops forever, matching the original script's ``while True``.
        num_embeddings: Number of rows in the embedding table (and targets).
        embedding_dim: Width of each embedding vector.
        num_labels: Number of binary targets per row (output width of head).
        log_every: Print the loss every this many steps; 0/None disables it.
        device: Device on which the modules and targets are created
            (defaults to the module-level ``device``).

    Returns:
        The loss computed at the final step (before that step's parameter
        update) as a Python float, or ``None`` if no step ran
        (``num_steps <= 0``).
    """
    embedding = nn.Embedding(num_embeddings, embedding_dim).to(device)
    # Create the targets directly on the device and cast to float once,
    # instead of calling ``gt.float()`` on every iteration.
    gt = torch.randint(0, 2, (num_embeddings, num_labels), device=device).float()
    head = nn.Linear(embedding_dim, num_labels).to(device)
    # The original passed only ``head.weight``, so the randomly initialised
    # ``head.bias`` was never updated. Optimise every parameter.
    optimizer = optim.Adam([*embedding.parameters(), *head.parameters()])
    criterion = nn.MSELoss()  # stateless; build once rather than per step

    loss = None
    step = 0
    while num_steps is None or step < num_steps:
        # Feeding the whole weight matrix is equivalent to looking up every
        # index 0..num_embeddings-1.
        pred = head(embedding.weight).sigmoid()
        loss = criterion(pred, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
        if log_every and step % log_every == 0:
            print(f"step {step}: loss {loss.item():.6f}")
    return None if loss is None else loss.item()


if __name__ == "__main__":
    # Like the original script, train until interrupted (Ctrl-C).
    train()