|
import copy |
|
import re |
|
import torch |
|
import util |
|
|
|
class FineTunedModel(torch.nn.Module):
    """Wrap a model so selected submodules can be fine-tuned on separate copies.

    Every submodule of ``model`` whose qualified name (as produced by
    ``model.named_modules()``) is matched by one of the ``modules`` regular
    expressions is deep-copied. The copy is unfrozen for training and the
    original is kept so it can be restored later. Used as a context manager,
    the wrapper installs the copies into ``model`` on entry and reinstalls the
    originals on exit::

        ftm = FineTunedModel(model, ["attn2$"])
        with ftm:
            ...  # ``model`` now runs with the fine-tuned copies

    :meth:`parameters`, :meth:`state_dict` and :meth:`load_state_dict` only
    cover the copies, never the base model's own weights.
    """

    def __init__(self, model, modules, frozen_modules=None):
        """Select, copy and unfreeze the submodules to fine-tune.

        Args:
            model: The ``torch.nn.Module`` to fine-tune. It is frozen in place
                with ``util.freeze`` before any copies are made.
            modules: A regex, or an iterable of regexes. Each submodule whose
                qualified name is matched by ``re.search`` against at least one
                of them is deep-copied once, and the copy is unfrozen with
                ``util.unfreeze``.
            frozen_modules: Optional regex or iterable of regexes. Any
                submodule of a copy whose qualified name (in the same dotted
                form as ``modules`` sees) matches one of them is frozen again,
                so it stays fixed during fine-tuning. ``None`` (the default)
                freezes nothing extra.
        """
        super().__init__()

        patterns = self._as_pattern_list(modules)
        frozen_patterns = (
            [] if frozen_modules is None else self._as_pattern_list(frozen_modules)
        )

        self.model = model
        self.ft_modules = {}
        self.orig_modules = {}

        util.freeze(self.model)

        for module_name, original in model.named_modules():
            # Copy each module at most once even if several patterns match it;
            # a second pass would overwrite the first copy.
            if not any(re.search(pattern, module_name) for pattern in patterns):
                continue

            ft_module = copy.deepcopy(original)

            self.orig_modules[module_name] = original
            self.ft_modules[module_name] = ft_module

            util.unfreeze(ft_module)

            print(f"=> Finetuning {module_name}")

            self._freeze_matching(module_name, ft_module, frozen_patterns)

        # Plain dict values are not registered with nn.Module; wrapping them in
        # ModuleLists makes the copies (and originals) child modules, so
        # module-wide operations such as .to() or .train() reach them.
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    def _freeze_matching(self, module_name, ft_module, frozen_patterns):
        """Freeze each submodule of ``ft_module`` whose qualified name matches a pattern."""
        if not frozen_patterns:
            return

        for sub_name, submodule in ft_module.named_modules():
            qualified_name = self._qualified_name(module_name, sub_name)
            if any(re.search(pattern, qualified_name) for pattern in frozen_patterns):
                print(f"=> Freezing {qualified_name}")
                util.freeze(submodule)

    @staticmethod
    def _as_pattern_list(patterns):
        """Return ``patterns`` as a list, wrapping a single string in a list."""
        if isinstance(patterns, str):
            return [patterns]
        # Materialise so one-shot iterables survive repeated scans.
        return list(patterns)

    @staticmethod
    def _qualified_name(parent_name, child_name):
        """Join two module names with '.'; an empty part denotes the other one itself."""
        if not child_name:
            return parent_name
        if not parent_name:
            return child_name
        return f"{parent_name}.{child_name}"

    @staticmethod
    def _exact_name_pattern(name):
        """Return a regex that ``re.search`` matches against exactly ``name``."""
        return f"^{re.escape(name)}$"

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=None, map_location=None):
        """Build a wrapper around ``model`` and load fine-tuned weights into it.

        Args:
            model: The base model the checkpoint was trained against.
            checkpoint: A dict in the format returned by :meth:`state_dict`
                (qualified module name -> that copy's state dict), or a path
                (``str`` or ``os.PathLike``) to a file holding one.
            frozen_modules: Forwarded to the constructor.
            map_location: Forwarded to ``torch.load`` when ``checkpoint`` is a
                path, e.g. ``"cpu"`` to load GPU-saved weights on a CPU host.

        Returns:
            A new instance whose copied modules hold the checkpoint weights.

        Raises:
            KeyError: if a checkpoint key names no submodule of ``model``.
        """
        import os

        if isinstance(checkpoint, (str, os.PathLike)):
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            checkpoint = torch.load(checkpoint, map_location=map_location)

        # Each key names exactly one module. Escape it (names contain '.') and
        # anchor both ends so it cannot also select longer or look-alike names.
        modules = [cls._exact_name_pattern(key) for key in checkpoint]

        ftm = cls(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)

        return ftm

    def __enter__(self):
        """Install the fine-tuned copies into the wrapped model and return ``self``.

        Each copy is passed to ``util.set_module`` together with its qualified
        name (presumably replacing the submodule at that dotted path).
        """
        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Reinstall the original submodules; exceptions are not suppressed."""
        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self):
        """Return the parameters of the fine-tuned copies only, as a list.

        Overrides ``torch.nn.Module.parameters``. Without this override the
        result would also include the wrapped base model, which is a
        registered child. Parameters frozen through ``frozen_modules`` are
        still included.
        """
        return [
            param
            for ft_module in self.ft_modules.values()
            for param in ft_module.parameters()
        ]

    def state_dict(self):
        """Return ``{qualified module name: state dict of its fine-tuned copy}``.

        Unlike ``torch.nn.Module.state_dict`` this takes no arguments and the
        result is nested one level. It is the format :meth:`load_state_dict`
        and :meth:`from_checkpoint` expect.
        """
        return {key: module.state_dict() for key, module in self.ft_modules.items()}

    def load_state_dict(self, state_dict):
        """Load a dict produced by :meth:`state_dict` into the fine-tuned copies.

        Raises:
            KeyError: if ``state_dict`` names a module that is not being
                fine-tuned.
        """
        for key, module_state in state_dict.items():
            self.ft_modules[key].load_state_dict(module_state)