# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

from dataclasses import dataclass, field
import logging
import math
import os
from typing import Dict, Optional

import torch
from fairseq import search
from fairseq.data import FairseqDataset, iterators
from fairseq.optim.amp_optimizer import AMPOptimizer
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import DictConfig

from data.audio_utils import AUDIO_CFG

logger = logging.getLogger(__name__)


@dataclass
class OFAConfig(FairseqDataclass):
    """Configuration options for the ``unival`` task (see ``OFATask``)."""

    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "comma separated path to data list, will be iterated upon during epochs "
                    "in round-robin manner; valid data are always in the last"
        },
    )
    selected_cols: Optional[str] = field(
        default=None,
        metadata={"help": "selected cols"},
    )
    bpe: Optional[str] = field(
        default='gpt2',
        metadata={"help": "which bpe to use"},
    )
    bpe_dir: Optional[str] = field(
        default=None,
        metadata={"help": "bpe dir"},
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the source sequence"}
    )
    max_target_positions: int = field(
        default=1024, metadata={"help": "max number of tokens in the target sequence"}
    )
    max_src_length: int = field(
        default=128, metadata={"help": "the maximum src sequence length"}
    )
    max_tgt_length: int = field(
        default=30, metadata={"help": "the maximum target sequence length"}
    )

    code_dict_size: int = field(
        default=8192, metadata={"help": "code dict size"}
    )
    patch_image_size: int = field(
        default=480, metadata={"help": "patch image size"}
    )
    orig_patch_image_size: int = field(
        default=256, metadata={"help": "patch image size"}
    )
    num_bins: int = field(
        default=1000, metadata={"help": "number of quantization bins"}
    )

    imagenet_default_mean_and_std: bool = field(
        default=False,
        metadata={"help": "imagenet normalize"},
    )
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"}
    )

    ## data
    read_from_img_path: bool = field(
        default=False,
        metadata={"help": "read from image paths, don't convert images to str"},
    )
    # NOTE(review): the default is a site-specific absolute path; override it
    # when running anywhere else.
    image_dir: str = field(
        default="/lus/scratch/NAT/gda2204/SHARED/data",
        metadata={"help": "path to datasets"},
    )
    # NOTE(review): help text is identical to read_from_img_path's — confirm
    # the intended wording for this flag.
    no_image_transform: bool = field(
        default=False,
        metadata={"help": "read from image paths, don't convert images to str"},
    )
    patch_frame_size: int = field(
        default=224, metadata={"help": "patch_frame_size"}
    )
    num_frames: int = field(
        default=4, metadata={"help": "num_frames"}
    )

    # audio
    max_audio_len: int = field(
        default=480000, metadata={"help": "max_audio_len"}
    )
    # NOTE(review): the help text is a sample JSON string rather than a
    # description — confirm what this option is expected to contain.
    audio_cfg: Optional[str] = field(
        default=None,
        metadata={"help": '{"beam":5,"unnormalized":true,"temperature":1.0}'}
    )
    sample_type: str = field(
        default="rand",
        metadata={"help": "sample_type"},
    )
    sample_rate: int = field(
        default=48000, metadata={"help": "sample rate"}
    )
    # NOTE(review): boolean flag whose help text duplicates sample_rate's —
    # confirm its meaning against the dataset code.
    audio_sample_rate: bool = field(
        default=False,
        metadata={"help": "sample rate"},
    )
    ast: bool = field(
        default=False,
        metadata={"help": "Use AST audio encoder"},
    )
    mel_bins: int = field(
        default=64, metadata={"help": "mel_bins"}
    )
    use_dataaug: bool = field(
        default=False,
        metadata={"help": "Use data augmentation"},
    )


@register_task("unival", dataclass=OFAConfig)
class OFATask(FairseqTask):
    """Fairseq task registered as ``unival``.

    Holds the source/target dictionaries and provides batch iteration, model
    and BPE construction, generator construction and the training step.
    """

    def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
        """Store the task configuration and the two dictionaries."""
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Setup the task.

        Loads ``dict.txt`` from ``cfg.bpe_dir`` once for the source and once
        for the target side, then extends both dictionaries with one mask
        symbol, ``cfg.code_dict_size`` code symbols and ``cfg.num_bins``
        quantization-bin symbols.

        Args:
            cfg (DictConfig): task configuration; ``bpe_dir``,
                ``code_dict_size`` and ``num_bins`` are read here.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            OFATask: a task instance holding the extended dictionaries.
        """
        # load dictionaries (two separate objects built from the same file)
        dict_path = os.path.join(cfg.bpe_dir, "dict.txt")
        src_dict = cls.load_dictionary(dict_path)
        tgt_dict = cls.load_dictionary(dict_path)

        # The symbol literals were previously empty strings, so every
        # ``"".format(i)`` produced the same "" symbol and the dictionaries
        # never grew by code_dict_size + num_bins entries.
        # NOTE(review): the placeholder names below follow the upstream OFA
        # convention — confirm they match the vocabulary of the checkpoints
        # in use.
        for dictionary in (src_dict, tgt_dict):
            dictionary.add_symbol("<mask>")
            for i in range(cfg.code_dict_size):
                dictionary.add_symbol("<code_{}>".format(i))
            # quantization
            for i in range(cfg.num_bins):
                dictionary.add_symbol("<bin_{}>".format(i))

        logger.info("source dictionary: %d types", len(src_dict))
        logger.info("target dictionary: %d types", len(tgt_dict))
        return cls(cfg, src_dict, tgt_dict)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Build an epoch iterator over fixed-size, contiguous batches.

        Batches are consecutive index ranges of length ``max_sentences`` (the
        last one may be shorter). ``max_tokens``, ``max_positions``,
        ``ignore_invalid_inputs``, ``required_batch_size_multiple``,
        ``shard_id`` and ``disable_iterator_cache`` are accepted for
        interface compatibility but are not used by this implementation.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch; it must
                expose ``dataset.dataset.get_total_row_count()``.
            max_sentences (int): number of samples per batch; required.
            seed (int): seed forwarded to the iterator.
            num_shards (int): only used to compute the expected number of
                batches per shard.
            num_workers (int): forwarded to the iterator.
            epoch (int): starting epoch, set on the dataset and the iterator.
            data_buffer_size (int): forwarded as ``buffer_size``.

        Returns:
            ~fairseq.data.iterators.EpochBatchIterator: the batch iterator.

        Raises:
            ValueError: if ``max_sentences`` is ``None`` or not positive.
        """
        assert isinstance(dataset, FairseqDataset)
        if max_sentences is None or max_sentences <= 0:
            # Previously this surfaced as an opaque TypeError/ValueError from
            # range(); fail early with an actionable message instead.
            raise ValueError(
                "get_batch_iterator requires a positive max_sentences "
                "(batch size), got {!r}".format(max_sentences)
            )

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # create mini-batches with given size constraints
        num_samples = len(dataset)
        batch_sampler = [
            list(range(start, min(start + max_sentences, num_samples)))
            for start in range(0, num_samples, max_sentences)
        ]

        # Pad with one empty batch when this dataset yields fewer batches
        # than the per-shard count derived from the total row count.
        # NOTE(review): presumably this keeps all shards in step — confirm.
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
        if len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # return a reusable iterator
        # NOTE(review): the iterator is deliberately built with num_shards=1
        # and shard_id=0 regardless of the arguments — presumably the dataset
        # already exposes only this worker's shard; confirm.
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size
        )

    def build_model(self, cfg: FairseqDataclass):
        """Build the model and attach the configured BPE to ``self.bpe``.

        A BERT BPE is built when ``self.cfg.bpe == 'bert'``; any other value
        falls back to GPT-2 BPE. Vocabulary files are read from
        ``self.cfg.bpe_dir``.

        Args:
            cfg (FairseqDataclass): model configuration passed to the parent
                ``build_model``.

        Returns:
            the model built by the parent class.
        """
        model = super().build_model(cfg)
        bpe_dir = self.cfg.bpe_dir
        if self.cfg.bpe == 'bert':
            bpe_dict = {
                "_name": "bert",
                "bpe_vocab_file": os.path.join(bpe_dir, "vocab.txt"),
                "bpe_cased": False
            }
        else:
            bpe_dict = {
                "_name": "gpt2",
                "gpt2_encoder_json": os.path.join(bpe_dir, "encoder.json"),
                "gpt2_vocab_bpe": os.path.join(bpe_dir, "vocab.bpe")
            }
        self.bpe = self.build_bpe(DictConfig(bpe_dict))
        return model

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this task.

        When ``args.score_reference`` is set, a
        :class:`~fairseq.sequence_scorer.SequenceScorer` is returned instead.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            seq_gen_cls: generator class to instantiate. When ``None``,
                ``SequenceGeneratorWithAlignment`` is used if
                ``args.print_alignment`` is set, otherwise the project's
                ``models.sequence_generator.SequenceGenerator``.
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator. The mapping is copied and never
                modified in place.
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.

        Raises:
            ValueError: if more than one of sampling, diverse beam groups,
                match-source-len and diversity rate is requested.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import SequenceGeneratorWithAlignment
        # The default generator is the project-local implementation, not
        # fairseq's: it accepts constraint_range and stop_on_max_len below.
        from models.sequence_generator import SequenceGenerator

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        stop_on_max_len = getattr(args, "stop_on_max_len", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            # NOTE(review): this branch triggers for any rate above -1, while
            # the exclusivity check above only counts rates above 0.
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        # Work on a copy so the caller's mapping is never mutated when
        # print_alignment is injected below.
        extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=match_source_len,
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            constraint_range=self.cfg.constraint_range,
            stop_on_max_len=stop_on_max_len,
            **extra_gen_cls_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        The forward pass runs under CUDA autocast only when *optimizer* is an
        ``AMPOptimizer``.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True
            **extra_kwargs: accepted for interface compatibility; unused.

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        use_amp = isinstance(optimizer, AMPOptimizer)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=use_amp):
                # update_num is also forwarded to the criterion.
                loss, sample_size, logging_output = criterion(model, sample, update_num=update_num)
        if ignore_grad:
            # Zero the loss in place; backward still runs below.
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict