# Provenance (scraped from the Hugging Face file view): uploaded by dev-slx
# in commit 9513141 ("Upload 9 files", verified); original file size 695 bytes.
# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
from prettytable import PrettyTable
def count_parameters(model):
    """Print a per-parameter table of trainable weights and return their total.

    Only parameters whose ``requires_grad`` flag is set are counted. Each one
    is listed under the name yielded by ``model.named_parameters()`` together
    with its element count (``numel()``), and a summary line follows the table.

    Args:
        model: Object exposing ``named_parameters()`` yielding ``(name, param)``
            pairs -- presumably a ``torch.nn.Module``; TODO confirm with callers.

    Returns:
        Sum of ``numel()`` over all trainable parameters (0 if there are none).
    """
    # Collect (name, size) for trainable parameters first; frozen ones are skipped.
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    summary = PrettyTable(["Modules", "Parameters"])
    for param_name, element_count in trainable:
        summary.add_row([param_name, element_count])

    grand_total = sum(element_count for _, element_count in trainable)
    print(summary)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total
def batchify(lst, n):
    """Divide a sequence into consecutive chunks of size ``n``.

    The final chunk is shorter when ``len(lst)`` is not a multiple of ``n``.

    Args:
        lst: Any sliceable sequence (list, tuple, str, ...). Chunks are
            produced by slicing ``lst``, so for built-in sequences each chunk
            has the same type as ``lst``.
        n: Chunk size; must be a positive integer.

    Returns:
        A list of chunks in their original order; ``[]`` if ``lst`` is empty.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    # Without this guard a negative n silently yields [] (the range is empty)
    # and n == 0 fails with range()'s unhelpful "arg 3 must not be zero".
    if n < 1:
        raise ValueError(f"batch size n must be >= 1, got {n}")
    return [lst[i:i + n] for i in range(0, len(lst), n)]