File size: 327 Bytes
d09a13e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import bitsandbytes as bnb
import torch

# Installation smoke test: take one bitsandbytes Adam step on a CUDA tensor
# and confirm that the parameter values actually changed.
p = torch.nn.Parameter(torch.rand(10, 10).cuda())
a = torch.rand(10, 10).cuda()

# Scalar fingerprint of the parameter before the optimizer runs.
p1 = p.data.sum().item()

adam = bnb.optim.Adam([p])

# Any loss that depends on p works here; for this one the gradient of the
# loss with respect to p is simply a.
out = a * p
loss = out.sum()
loss.backward()
adam.step()

# Fingerprint after the step.
p2 = p.data.sum().item()

# Raise explicitly instead of using a bare `assert`: assert statements are
# removed under `python -O`, which would let a broken install fall through
# to the success messages below.
if p1 == p2:
    raise AssertionError(
        'bitsandbytes Adam step did not change the parameter '
        '(sum before={!r}, after={!r})'.format(p1, p2)
    )
print('SUCCESS!')
print('Installation was successful!')