File size: 327 Bytes
d09a13e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import bitsandbytes as bnb
import torch
# Smoke test: run a single bitsandbytes Adam step on a CUDA parameter and
# confirm that the parameter's values actually changed.

# Both tensors are moved to the GPU with .cuda(); this raises if no CUDA
# device is usable.
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

# Sum of the parameter before the optimizer step, used as a cheap fingerprint.
p1 = p.data.sum().item()

adam = bnb.optim.Adam([p])

# loss = sum(a * p), so backward() fills p.grad and step() has an update
# to apply.
out = a * p
loss = out.sum()
loss.backward()
adam.step()

# Sum of the parameter after the optimizer step.
p2 = p.data.sum().item()

# Raise explicitly rather than using a bare `assert`: assert statements are
# removed under `python -O`, which would let this check pass silently.
if p1 == p2:
    raise AssertionError(
        f"parameter sum did not change after adam.step() "
        f"(before={p1!r}, after={p2!r})"
    )

print('SUCCESS!')
print('Installation was successful!')