gzzyyxy's picture
Upload folder using huggingface_hub
c1a7f73 verified
import torch
from torch import nn
import torch.optim as optim
def train(
    num_steps: "int | None" = None,
    *,
    num_embeddings: int = 180,
    embedding_dim: int = 128,
    num_targets: int = 2048,
    device=None,
    log_every: int = 1000,
) -> float:
    """Fit an embedding table plus a linear head to random binary targets.

    Each of the ``num_embeddings`` rows of an ``nn.Embedding`` is mapped
    through ``nn.Linear(embedding_dim, num_targets)`` and a sigmoid, and the
    result is regressed with MSE onto a fixed tensor of random 0/1 targets of
    shape ``(num_embeddings, num_targets)``. All parameters are optimized
    with Adam (default hyper-parameters).

    Args:
        num_steps: Number of optimizer steps to run. ``None`` (the default)
            runs until interrupted, like the original script.
        num_embeddings: Number of embedding rows (and target rows).
        embedding_dim: Width of each embedding vector.
        num_targets: Number of binary targets per row.
        device: Device to train on; defaults to CUDA when available,
            otherwise CPU.
        log_every: Print the loss every ``log_every`` steps; ``0`` disables
            logging.

    Returns:
        The MSE loss computed on the last step (before that step's update).

    Raises:
        ValueError: If ``num_steps`` is given and is smaller than 1.
    """
    if num_steps is not None and num_steps < 1:
        raise ValueError(
            f"num_steps must be a positive int or None, got {num_steps!r}"
        )
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Construction order (embedding init, randint, linear init) matches the
    # original script so seeded runs start from the same state.
    embedding = nn.Embedding(num_embeddings, embedding_dim).to(device)
    gt = torch.randint(0, 2, (num_embeddings, num_targets)).to(device)
    head = nn.Linear(embedding_dim, num_targets).to(device)

    # Optimize every trainable tensor, including head.bias. Leaving the bias
    # out meant it received gradients that were never applied, and because
    # zero_grad() only clears the optimizer's own params, its .grad kept
    # accumulating across steps.
    optimizer = optim.Adam([*embedding.parameters(), *head.parameters()])

    # Loop invariants: build the criterion and the float target only once.
    criterion = nn.MSELoss()
    target = gt.float()

    step = 0
    while num_steps is None or step < num_steps:
        pred = head(embedding.weight).sigmoid()
        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1
        # .item() forces a device sync, so only call it when actually logging.
        if log_every and step % log_every == 0:
            print(f"step {step}: loss {loss.item():.6f}")
    return loss.item()


if __name__ == "__main__":
    train()