File size: 490 Bytes
c1a7f73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
from torch import nn
import torch.optim as optim
def train(
    num_items: int = 180,
    embed_dim: int = 128,
    num_targets: int = 2048,
    steps: "int | None" = None,
    log_every: int = 0,
    seed: "int | None" = None,
    device: "torch.device | str | None" = None,
) -> float:
    """Fit a learnable embedding table plus linear head to random 0/1 targets.

    The full ``nn.Embedding`` weight matrix (``num_items`` x ``embed_dim``) is
    passed straight through an ``nn.Linear`` head and a sigmoid. The resulting
    ``num_items`` x ``num_targets`` predictions are regressed with MSE onto a
    fixed random binary matrix, using Adam with its default hyperparameters.

    Args:
        num_items: Number of embedding rows (and target rows).
        embed_dim: Width of each embedding vector.
        num_targets: Number of binary targets per row.
        steps: Number of optimizer steps to run. ``None`` loops forever,
            matching the original script.
        log_every: Print the current loss every this many steps. A value
            ``<= 0`` disables logging.
        seed: If given, ``torch.manual_seed(seed)`` is called first, so the
            targets and the initial weights are reproducible.
        device: Device to train on. Defaults to CUDA when available, else CPU.

    Returns:
        The MSE loss computed in the last iteration, measured before that
        iteration's parameter update. Returns ``nan`` if no step was run.
        This is only reached when ``steps`` is not ``None``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    if seed is not None:
        torch.manual_seed(seed)

    embedding = nn.Embedding(num_items, embed_dim).to(device)
    # Sampled on the CPU generator and then moved, as in the original script,
    # so a given seed yields the same targets on every device. Converted to
    # float once here instead of on every iteration.
    target = torch.randint(0, 2, (num_items, num_targets)).to(device).float()
    head = nn.Linear(embed_dim, num_targets).to(device)
    # Optimize every parameter. The original list named only head.weight,
    # which silently left head.bias frozen at its random initialization.
    optimizer = optim.Adam([*embedding.parameters(), *head.parameters()])
    # The loss module is stateless, so build it once outside the loop.
    # NOTE: with 0/1 targets, BCEWithLogitsLoss is the usual choice; MSE on
    # sigmoid outputs is kept here to preserve the original experiment.
    criterion = nn.MSELoss()

    loss = None
    step = 0
    while steps is None or step < steps:
        pred = head(embedding.weight).sigmoid()
        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
        # .item() forces a host/device sync, so read the loss only when logging.
        if log_every > 0 and step % log_every == 0:
            print(f"step {step}: loss {loss.item():.6f}")
    return float("nan") if loss is None else loss.item()


if __name__ == "__main__":
    # Same run-until-interrupted behaviour as before, but with visible progress.
    train(log_every=1000)
|