PolyFormer / tasks /base_task.py
jiang
init commit
650c5f6
raw
history blame
11 kB
# ------------------------------------------------------------------------
# Modified from OFA (https://github.com/OFA-Sys/OFA)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
# ------------------------------------------------------------------------
# Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
import logging
import os
import math
import torch
from typing import Dict, Optional
from fairseq import search
from fairseq.data import FairseqDataset, iterators, Dictionary
from fairseq.optim.amp_optimizer import AMPOptimizer
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import DictConfig
from torch import Tensor, device, dtype, nn
# Module-level logger, named after this module; used by the loader and task below.
logger = logging.getLogger(__name__)
def load_bert_pretrained_weights(model, ckpt_path):
    """Load a BERT-style PyTorch checkpoint into ``model`` in place.

    The checkpoint is read onto the CPU, parameter names containing
    ``gamma`` / ``beta`` are renamed to ``weight`` / ``bias``, and the tensors
    are then copied into ``model`` and each of its sub-modules by calling
    ``_load_from_state_dict`` recursively, starting from the key prefix
    ``"bert."``. Unused and missing keys are reported through the module
    logger.

    Args:
        model: the module whose parameters are populated.
        ckpt_path: path of the checkpoint file handed to ``torch.load``.

    Raises:
        OSError: if the checkpoint cannot be loaded. The underlying exception
            is attached as ``__cause__`` so the real reason stays visible.
        RuntimeError: if ``_load_from_state_dict`` collected error messages.
    """
    # NOTE(review): ``torch.load`` unpickles the file; only point this at
    # checkpoints that come from a trusted source.
    try:
        state_dict = torch.load(ckpt_path, map_location="cpu")
    except Exception as err:
        # Chain the original exception so the root cause is not lost.
        raise OSError(
            "Unable to load weights from pytorch checkpoint file. "
            "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
        ) from err

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # Convert old format to new format if needed from a PyTorch state_dict.
    # The renames are collected first so the dict is not mutated while it is
    # being iterated.
    renames = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if new_key:
            renames.append((key, new_key))
    for old_key, new_key in renames:
        state_dict[new_key] = state_dict.pop(old_key)

    # Copy state_dict so _load_from_state_dict can modify it; the optional
    # ``_metadata`` attribute has to be carried over to the copy by hand.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's
    # descendants, so we need to apply the function recursively.
    def load(module: nn.Module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = "bert."
    load(model, prefix=start_prefix)

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {ckpt_path} were not used when "
            f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
            f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
            f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {ckpt_path} "
            f"and are newly initialized: {missing_keys}\n"
            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {ckpt_path}.\n"
            f"If your task is similar to the task the model of the ckeckpoint was trained on, "
            f"you can already use {model.__class__.__name__} for predictions without further training."
        )

    if len(error_msgs) > 0:
        raise RuntimeError(
            "Error(s) in loading state_dict for {}:\n\t{}".format(
                model.__class__.__name__, "\n\t".join(error_msgs)
            )
        )
@dataclass
class BaseConfig(FairseqDataclass):
    """Configuration fields for the ``base_task`` task.

    Each field carries its own ``help`` text in ``metadata``; the groups
    below are only a reading aid.
    """

    # --- data location / selection -------------------------------------
    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "comma separated path to data list, will be iterated upon "
            "during epochs in round-robin manner; valid data are always in the last"
        },
    )
    selected_cols: Optional[str] = field(default=None, metadata={"help": "selected cols"})
    bpe_dir: Optional[str] = field(default=None, metadata={"help": "bpe dir"})

    # --- sequence length limits ----------------------------------------
    max_source_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the source sequence"},
    )
    max_target_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the target sequence"},
    )
    max_src_length: int = field(
        default=128,
        metadata={"help": "the maximum src sequence length"},
    )
    max_tgt_length: int = field(
        default=30,
        metadata={"help": "the maximum target sequence length"},
    )

    # --- vocabulary / image settings -----------------------------------
    code_dict_size: int = field(default=8192, metadata={"help": "code dict size"})
    patch_image_size: int = field(default=480, metadata={"help": "patch image size"})
    num_bins: int = field(default=1000, metadata={"help": "number of quantization bins"})
    imagenet_default_mean_and_std: bool = field(
        default=False,
        metadata={"help": "imagenet normalize"},
    )
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"},
    )
@register_task("base_task", dataclass=BaseConfig)
class BaseTask(FairseqTask):
    """Task registered as ``base_task`` that owns a source and a target dictionary."""

    def __init__(self, cfg: BaseConfig, src_dict, tgt_dict):
        """Store the configuration (via the parent class) and both dictionaries."""
        super().__init__(cfg)
        self.src_dict, self.tgt_dict = src_dict, tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Create the dictionaries from ``cfg`` and return a task instance.

        Both dictionaries receive one ``<bin_i_j>`` symbol for every pair
        ``(i, j)`` with ``0 <= i, j < cfg.num_bins``.
        """
        src_dict, tgt_dict = Dictionary(), Dictionary()

        # Add 2D bin tokens: num_bins * num_bins symbols per dictionary.
        bin_count = cfg.num_bins
        for outer in range(bin_count):
            for inner in range(bin_count):
                symbol = "<bin_{}_{}>".format(outer, inner)
                src_dict.add_symbol(symbol)
                tgt_dict.add_symbol(symbol)

        logger.info("source dictionary: {} types".format(len(src_dict)))
        logger.info("target dictionary: {} types".format(len(tgt_dict)))
        return cls(cfg, src_dict, tgt_dict)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Return an ``EpochBatchIterator`` over fixed-size, in-order batches.

        Batches are consecutive index ranges of length ``max_sentences`` (the
        last one may be shorter). ``max_tokens``, ``max_positions``,
        ``ignore_invalid_inputs``, ``required_batch_size_multiple``,
        ``shard_id`` and ``disable_iterator_cache`` are accepted but not used
        by this implementation; ``num_shards`` is only used to compute the
        expected number of batches.
        """
        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # create mini-batches with given size constraints
        dataset_len = len(dataset)
        batch_sampler = []
        for start in range(0, dataset_len, max_sentences):
            stop = min(start + max_sentences, dataset_len)
            batch_sampler.append(list(range(start, stop)))

        # Expected batch count derived from the total row count split across
        # shards; if this dataset yields fewer, pad with one empty batch.
        # NOTE(review): presumably this keeps all shards iterating the same
        # number of steps - confirm against the dataset implementation.
        total_row_count = dataset.dataset.get_total_row_count()
        rows_per_shard = math.ceil(total_row_count / num_shards)
        num_batches = math.ceil(rows_per_shard / max_sentences)
        if len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # return a reusable iterator; sharding is fixed to a single shard here
        # (NOTE(review): presumably the dataset is already sharded - confirm).
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
        )

    def build_model(self, cfg: FairseqDataclass):
        """Build the model via the parent class and attach a GPT-2 BPE to the task.

        The BPE files ``encoder.json`` and ``vocab.bpe`` are looked up inside
        ``self.cfg.bpe_dir``; the resulting BPE object is stored as ``self.bpe``.
        """
        model = super().build_model(cfg)

        bpe_dir = self.cfg.bpe_dir
        bpe_cfg = DictConfig(
            {
                "_name": "gpt2",
                "gpt2_encoder_json": os.path.join(bpe_dir, "encoder.json"),
                "gpt2_vocab_bpe": os.path.join(bpe_dir, "vocab.bpe"),
            }
        )
        self.bpe = self.build_bpe(bpe_cfg)
        return model

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Run one forward/backward pass and return what *criterion* computed
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)

        # Autocast is only switched on when the optimizer is the AMP wrapper.
        amp_enabled = isinstance(optimizer, AMPOptimizer)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                loss, sample_size, logging_output = criterion(
                    model, sample, update_num=update_num
                )

        if ignore_grad:
            # In-place scaling, so backward still runs but on a zeroed loss.
            loss *= 0

        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)

        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the (source, target) position limits taken from the config."""
        cfg = self.cfg
        return (cfg.max_source_positions, cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """The source :class:`~fairseq.data.Dictionary` held by this task."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """The target :class:`~fairseq.data.Dictionary` held by this task."""
        return self.tgt_dict