import logging
import re
from typing import List, Tuple, Type

import torch

from allennlp.common.from_params import Params, T
from allennlp.training.optimizers import Optimizer

logger = logging.getLogger('optim')


@Optimizer.register('transformer')
class TransformerOptimizer:
""" |
|
Wrapper for AllenNLP optimizer. |
|
This is used to fine-tune the pretrained transformer with some layers fixed and different learning rate. |
|
When some layers are fixed, the wrapper will set the `require_grad` flag as False, which could save |
|
training time and optimize memory usage. |
|
Plz contact Guanghui Qin for bugs. |
|
Params: |
|
base: base optimizer. |
|
embeddings_lr: learning rate for embedding layer. Set as 0.0 to fix it. |
|
encoder_lr: learning rate for encoder layer. Set as 0.0 to fix it. |
|
pooler_lr: learning rate for pooler layer. Set as 0.0 to fix it. |
|
layer_fix: the number of encoder layers that should be fixed. |
|
|
|
Example json config: |
|
|
|
1. No-op. Do nothing (why do you use me?) |
|
optimizer: { |
|
type: "transformer", |
|
base: { |
|
type: "adam", |
|
lr: 0.001 |
|
} |
|
} |
|
|
|
2. Fix everything in the transformer. |
|
optimizer: { |
|
type: "transformer", |
|
base: { |
|
type: "adam", |
|
lr: 0.001 |
|
}, |
|
embeddings_lr: 0.0, |
|
encoder_lr: 0.0, |
|
pooler_lr: 0.0 |
|
} |
|
|
|
Or equivalently (suppose we have 24 layers) |
|
|
|
optimizer: { |
|
type: "transformer", |
|
base: { |
|
type: "adam", |
|
lr: 0.001 |
|
}, |
|
embeddings_lr: 0.0, |
|
layer_fix: 24, |
|
pooler_lr: 0.0 |
|
} |
|
|
|
3. Fix embeddings and the lower 12 encoder layers, set a small learning rate |
|
for the other parts of the transformer |
|
|
|
optimizer: { |
|
type: "transformer", |
|
base: { |
|
type: "adam", |
|
lr: 0.001 |
|
}, |
|
embeddings_lr: 0.0, |
|
layer_fix: 12, |
|
encoder_lr: 1e-5, |
|
pooler_lr: 1e-5 |
|
} |
|
""" |
|
    @classmethod
    def from_params(
            cls: Type[T],
            params: Params,
            model_parameters: List[Tuple[str, torch.nn.Parameter]],
            **_
    ) -> Optimizer:
        param_groups = []

        def remove_param(keyword_: str):
            # Freeze every param whose name contains `keyword_` and drop it from the list,
            # so it is never handed to the base optimizer.
            nonlocal model_parameters
            logger.info(f'Fix params with name matching {keyword_}.')
            for name, param in model_parameters:
                if keyword_ in name:
                    logger.debug(f'Fix param {name}.')
                    param.requires_grad_(False)
            model_parameters = list(filter(lambda x: keyword_ not in x[0], model_parameters))

        # Fix the lowest `layer_fix` encoder layers (default 0: fix nothing).
        for i_layer in range(params.pop('layer_fix', 0)):
            remove_param('transformer_model.encoder.layer.{}.'.format(i_layer))

        # For each transformer component, either give it a dedicated learning rate
        # or (if the rate is 0.0) fix it entirely.
        for specific_lr, keyword in (
                (params.pop('embeddings_lr', None), 'transformer_model.embeddings'),
                (params.pop('encoder_lr', None), 'transformer_model.encoder.layer'),
                (params.pop('pooler_lr', None), 'transformer_model.pooler'),
        ):
            if specific_lr is not None:
                if specific_lr > 0.:
                    pattern = '.*' + keyword.replace('.', r'\.') + '.*'
                    if any(re.match(pattern, name) for name, _ in model_parameters):
                        param_groups.append([[pattern], {'lr': specific_lr}])
                    else:
                        logger.warning(f'{pattern} is set to use lr {specific_lr} but no param matches.')
                else:
                    remove_param(keyword)

        # Pass user-specified parameter groups through to the base optimizer.
        if 'parameter_groups' in params:
            for pg in params.pop('parameter_groups'):
                param_groups.append([pg[0], pg[1].as_dict()])

        # Build the base optimizer over the remaining (trainable) parameters.
        base_params = params.pop('base')
        return Optimizer.by_name(base_params.pop('type'))(
            model_parameters=model_parameters,
            parameter_groups=param_groups,
            **base_params.as_flat_dict()
        )
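# A minimal usage sketch (added for illustration; not part of the original module). It shows
# how this wrapper might be invoked directly from Python. The toy classes below are
# illustrative assumptions: their parameter names merely mimic the HuggingFace-style layout
# ('transformer_model.embeddings', 'transformer_model.encoder.layer.N.',
# 'transformer_model.pooler') that the keywords above rely on. In a normal AllenNLP setup
# the trainer builds this optimizer from the "optimizer" block of the experiment config.
if __name__ == '__main__':
    class _ToyEncoderLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = torch.nn.Linear(4, 4)

    class _ToyTransformer(torch.nn.Module):
        # Stand-in for a model wrapping a pretrained transformer under `transformer_model`.
        def __init__(self, num_layers: int = 4):
            super().__init__()
            self.transformer_model = torch.nn.Module()
            self.transformer_model.embeddings = torch.nn.Embedding(10, 4)
            self.transformer_model.encoder = torch.nn.Module()
            self.transformer_model.encoder.layer = torch.nn.ModuleList(
                [_ToyEncoderLayer() for _ in range(num_layers)]
            )
            self.transformer_model.pooler = torch.nn.Linear(4, 4)

    model = _ToyTransformer()
    # The trainer would normally dispatch on `type: "transformer"`; calling the class
    # directly, only the wrapper's own keys are needed.
    config = Params({
        'base': {'type': 'adam', 'lr': 1e-3},
        'embeddings_lr': 0.0,  # fix the embeddings
        'layer_fix': 2,        # fix the lowest 2 encoder layers
        'encoder_lr': 1e-5,    # small lr for the remaining encoder layers
        'pooler_lr': 1e-5,
    })
    optimizer = TransformerOptimizer.from_params(
        params=config, model_parameters=list(model.named_parameters())
    )
    print(optimizer)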