lolcats / src /model /utils.py
ariG23498's picture
ariG23498 HF staff
chore: adding lolcats configs scrc and src
ae81e0f
raw
history blame contribute delete
456 Bytes
import numpy as np
def count_parameters(model, requires_grad: bool = True) -> int:
    """
    Return the total number of scalar elements across a model's parameters.

    Args:
        model: Object exposing ``parameters()``, which yields items that have
            a ``size()`` method (and a ``requires_grad`` attribute when
            filtering is enabled).
        requires_grad: If True (default), count only parameters whose
            ``requires_grad`` attribute is truthy; otherwise count all of
            them.

    Returns:
        The count as a plain Python ``int`` (0 when nothing matches).
    """
    params = model.parameters()
    if requires_grad:
        params = (p for p in params if p.requires_grad)
    # dtype=np.int64 keeps each product integral even for a zero-dimensional
    # (empty) shape, where np.prod would otherwise return the float 1.0.
    # Converting to int per parameter means the empty case sums to a plain 0,
    # so no try/except (and no second pass over a spent iterator) is needed.
    return sum(int(np.prod(p.size(), dtype=np.int64)) for p in params)