File size: 4,664 Bytes
45311fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# -*- coding: utf-8 -*-
# @Time    : 2023/02/18 02:07 p.m.
# @Author  : JianingWang
# @File    : parameter_freeze.py

import torch


"""
This is use for parameter fixing and unfreezing, which can be viewed as parameter-efficient settings.
"""
class ParameterFreeze():
    """Freeze/unfreeze helpers for parameter-efficient fine-tuning.

    Every method flips ``requires_grad`` on the model's parameters in place and
    returns the same model object for convenient chaining.  Parameters are
    selected by substring matches on ``named_parameters()`` names, so these
    helpers assume transformer-style naming conventions ("lm_head", "cls",
    "classifier", "attention", "intermediate", "layer.N", ...) — TODO confirm
    against the models this project actually uses.
    """

    def freeze_lm(self, model: torch.nn.Module) -> torch.nn.Module:
        """Freeze all parameters of ``model``."""
        for _name, param in model.named_parameters():
            param.requires_grad = False
        return model

    def freeze_lm_encoder(self, model: torch.nn.Module) -> torch.nn.Module:
        """Freeze everything except the cls / mlm head.

        Parameters whose name contains "lm_head" or "cls" keep their current
        ``requires_grad``; their names are printed for inspection.
        """
        for name, param in model.named_parameters():
            if "lm_head" in name or ("cls" in name):
                print(name)  # report which parameters stay trainable
                continue
            param.requires_grad = False
        return model

    def freeze_lm_finetune_bias(self, model: torch.nn.Module) -> torch.nn.Module:
        """Freeze everything except bias terms (BitFit-style tuning)."""
        for name, param in model.named_parameters():
            if "bias" in name:
                print(name)  # report which parameters stay trainable
                continue
            param.requires_grad = False
        return model

    def freeze_lm_component(self, model: torch.nn.Module, component: str) -> torch.nn.Module:
        """Freeze all parameters except the user-chosen ``component``.

        ``component`` is matched by substring, and the ``elif`` order below is
        significant (e.g. "attention_output" hits the "attention" branch, never
        "output"-only handling elsewhere).  Recognized components:
        "attention" (optionally containing "output"), "feedforward" (optionally
        containing "output" or "intermediate"), "adapter", "embedding", "bias",
        "head", "prompt_emb".  Every branch except "prompt_emb" re-unfreezes
        the classification head afterwards.  An unrecognized component leaves
        the model untouched.
        """
        if "attention" in component:
            for name, param in model.named_parameters():
                if "attention" in name:
                    if "output" in component:
                        # only the attention-output sub-part stays trainable
                        if "output" in name:
                            continue
                    else:
                        continue
                param.requires_grad = False
            model = self.unfreeze_classification_head(model)
        elif "feedforward" in component:
            for name, param in model.named_parameters():
                # "dense" outside attention ≈ the feed-forward projections
                if "dense" in name and "attention" not in name:
                    if "output" in component:
                        if "output" in name:
                            continue
                    else:
                        if "intermediate" in component:
                            if "intermediate" in name:
                                continue
                param.requires_grad = False
            model = self.unfreeze_classification_head(model)
        elif component == "adapter":
            for name, param in model.named_parameters():
                if "adapter" in name:
                    continue
                param.requires_grad = False
            model = self.unfreeze_classification_head(model)
        elif "embedding" in component:
            for name, param in model.named_parameters():
                if "embedding" in name:
                    continue
                param.requires_grad = False
            model = self.unfreeze_classification_head(model)
        elif "bias" in component:
            for name, param in model.named_parameters():
                if "bias" in name:
                    continue
                param.requires_grad = False
            model = self.unfreeze_classification_head(model)
        elif "head" in component:
            # freeze the whole model, then re-enable only the head
            for name, param in model.named_parameters():
                param.requires_grad = False
            model = self.unfreeze_classification_head(model)
        elif "prompt_emb" in component:
            # prompt tuning: only the prompt embeddings stay trainable,
            # and (unlike the branches above) the head stays frozen too
            for name, param in model.named_parameters():
                if "prompt_emb" in name:
                    continue
                param.requires_grad = False
        return model

    def unfreeze_classification_head(self, model: torch.nn.Module) -> torch.nn.Module:
        """Make the cls / lm / classifier head trainable again."""
        for name, param in model.named_parameters():
            if "lm_head" in name or ("cls" in name) or ("classifier" in name):
                param.requires_grad = True
        return model

    def freeze_lm_k_layers(self, model: torch.nn.Module, k: int, max_layer: int = 23) -> torch.nn.Module:
        """Freeze everything except the top-``k`` layers' feed-forward output denses.

        Args:
            k: number of topmost layers whose non-attention "dense"+"output"
                parameters remain trainable.
            max_layer: index of the highest layer (default 23, i.e. a 24-layer
                encoder).  Previously hard-coded; pass a different value for
                models of other depths.

        The classification head is unfrozen afterwards.  NOTE: matching is by
        substring, so "layer.2" also matches "layer.23" — this mirrors the
        original behavior; callers should pick ``k``/``max_layer`` accordingly.
        """
        keep_layers = ["layer." + str(max_layer - i) for i in range(k)]

        for name, param in model.named_parameters():
            keep = False
            for layer_tag in keep_layers:
                if layer_tag in name and "dense" in name and "attention" not in name and "output" in name:
                    print(name)  # report which parameters stay trainable
                    keep = True
            if not keep:
                param.requires_grad = False
        model = self.unfreeze_classification_head(model)
        return model

    def unfreeze_lm(self, model: torch.nn.Module) -> torch.nn.Module:
        """Make every parameter of ``model`` trainable."""
        for param in model.parameters():
            param.requires_grad = True
        return model