# PicoAudio2 / models / common.py
# rookie9's picture
# Upload 77 files
# f582ec6 verified
from pathlib import Path
import torch
import torch.nn as nn
from utils.torch_utilities import load_pretrained_model, merge_matched_keys
import warnings
class LoadPretrainedBase(nn.Module):
def process_state_dict(
self, model_dict: dict[str, torch.Tensor],
state_dict: dict[str, torch.Tensor]
):
"""
Custom processing functions of each model that transforms `state_dict` loaded from
checkpoints to the state that can be used in `load_state_dict`.
Use `merge_mathced_keys` to update parameters with matched names and shapes by
default.
Args
model_dict:
The state dict of the current model, which is going to load pretrained parameters
state_dict:
A dictionary of parameters from a pre-trained model.
Returns:
dict[str, torch.Tensor]:
The updated state dict, where parameters with matched keys and shape are
updated with values in `state_dict`.
"""
state_dict = merge_matched_keys(model_dict, state_dict)
return state_dict
def load_pretrained(self, ckpt_path: str | Path):
load_pretrained_model(
self, ckpt_path, state_dict_process_fn=self.process_state_dict
)
class CountParamsBase(nn.Module):
    """Base module that can report how many parameter elements it holds."""

    def count_params(self):
        """
        Count the elements of all parameters and of the trainable ones.

        Returns:
            tuple[int, int]:
                `(num_params, trainable_params)`: the summed `numel()` over
                every parameter yielded by `self.parameters()`, and the same
                sum restricted to parameters with `requires_grad` set.
        """
        # Walk the parameters once, recording size and trainability together.
        sizes = [(p.numel(), p.requires_grad) for p in self.parameters()]
        total = sum(count for count, _ in sizes)
        trainable = sum(count for count, needs_grad in sizes if needs_grad)
        return total, trainable
class SaveTrainableParamsBase(nn.Module):
@property
def param_names_to_save(self):
names = []
for name, param in self.named_parameters():
if param.requires_grad:
names.append(name)
for name, _ in self.named_buffers():
names.append(name)
return names
def load_state_dict(self, state_dict, strict=True, assign=True):
print("State dict keys:", list(state_dict.keys()))
#for key in self.param_names_to_save:
# if key not in state_dict:
# raise Exception(
# f"{key} not found in either pre-trained models (e.g. BERT)"
# " or resumed checkpoints (e.g. epoch_40/model.pt)"
# )
# 兼容 PyTorch/transformers 的 assign 参数
return super().load_state_dict(state_dict, strict=strict, assign=assign)