ca4fc4d
1
2
3
4
5
6
7
8
# Minimal usage example: build the Andromeda1Billion model and run one
# forward pass on a batch of random token ids.
import torch

from Andromeda.configs import Andromeda1Billion

# The original example hard-coded `.cuda()`, which fails on machines without
# a GPU; fall back to CPU in that case (identical behavior when CUDA exists).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Andromeda1Billion().to(device)
# Inference demo: disable training-only behavior (e.g. dropout) if the model
# has any.
model.eval()

# Random token ids drawn from [0, 256), shaped (batch=1, seq_len=1024),
# created directly on the target device instead of copied there afterwards.
x = torch.randint(0, 256, (1, 1024), device=device)

# No gradients are needed for a forward-only demo; skipping autograd avoids
# building a graph and holding activations for backward.
with torch.no_grad():
    out = model(x)

# The original example states the output is (1, 1024, 20000), i.e.
# (batch, seq_len, vocab). NOTE(review): assumes `out` is a single tensor and
# that the config's vocab size is 20000; both are not visible here, so confirm
# against Andromeda.configs.
print(out.shape)