#! python3 # -*- encoding: utf-8 -*- from transformers.models.ernie.modeling_ernie import * import torch.utils.checkpoint from torch import nn from transformers.utils import logging import inspect from typing import Set, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import re import math from typing import Optional, Tuple import torch from fairseq import utils from fairseq.modules.fairseq_dropout import FairseqDropout from fairseq.modules.quant_noise import quant_noise from torch import Tensor, nn from torch.hub import load_state_dict_from_url import torch.distributed as dist logger = logging.get_logger(__name__) from torch.hub import load_state_dict_from_url import torch.distributed as dist PRETRAINED_MODEL_URLS = { "pcqm4mv1_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv1.pt", "pcqm4mv2_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv2.pt", "oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt", # this pretrained model is temporarily unavailable "pcqm4mv1_graphormer_base_for_molhiv":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_base_preln_pcqm4mv1_for_hiv.pt", } def load_pretrained_model(pretrained_model_name): if pretrained_model_name not in PRETRAINED_MODEL_URLS: raise ValueError("Unknown pretrained model name %s", pretrained_model_name) if not dist.is_initialized(): return load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True)["model"] else: pretrained_model = load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True, file_name=f"{pretrained_model_name}_{dist.get_rank()}")["model"] dist.barrier() return pretrained_model class MultiheadAttention(nn.Module): """Multi-headed attention. See "Attention Is All You Need" for more details. """ def __init__( self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, self_attention=False, q_noise=0.0, qn_block_size=8, ): super().__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout_module = FairseqDropout( dropout, module_name=self.__class__.__name__ ) self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim ), "embed_dim must be divisible by num_heads" self.scaling = self.head_dim ** -0.5 self.self_attention = self_attention assert self.self_attention, "Only support self attention" assert not self.self_attention or self.qkv_same_dim, ( "Self-attention requires query, key and " "value to be of the same size" ) self.k_proj = quant_noise( nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size ) self.v_proj = quant_noise( nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size ) self.q_proj = quant_noise( nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size ) self.out_proj = quant_noise( nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size ) self.reset_parameters() self.onnx_trace = False def prepare_for_onnx_export_(self): raise NotImplementedError def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with # the scaled initialization nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) else: nn.init.xavier_uniform_(self.k_proj.weight) nn.init.xavier_uniform_(self.v_proj.weight) nn.init.xavier_uniform_(self.q_proj.weight) nn.init.xavier_uniform_(self.out_proj.weight) if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0) def forward( self, query, key: Optional[Tensor], value: Optional[Tensor], attn_bias: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, attn_mask: Optional[Tensor] = None, before_softmax: bool = False, need_head_weights: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: """Input shape: Time x Batch x Channel Args: key_padding_mask (ByteTensor, optional): mask to exclude keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s. need_weights (bool, optional): return the attention weights, averaged over heads (default: False). attn_mask (ByteTensor, optional): typically used to implement causal attention, where the mask prevents the attention from looking forward in time (default: None). before_softmax (bool, optional): return the raw attention weights and values before the attention softmax. need_head_weights (bool, optional): return the attention weights for each head. Implies *need_weights*. Default: return the average attention weights over all heads. """ if need_head_weights: need_weights = True tgt_len, bsz, embed_dim = query.size() src_len = tgt_len assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" assert list(query.size()) == [tgt_len, bsz, embed_dim] if key is not None: src_len, key_bsz, _ = key.size() if not torch.jit.is_scripting(): assert key_bsz == bsz assert value is not None assert src_len, bsz == value.shape[:2] q = self.q_proj(query) k = self.k_proj(query) v = self.v_proj(query) q *= self.scaling q = ( q.contiguous() .view(tgt_len, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) if k is not None: k = ( k.contiguous() .view(-1, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) if v is not None: v = ( v.contiguous() .view(-1, bsz * self.num_heads, self.head_dim) .transpose(0, 1) ) assert k is not None assert k.size(1) == src_len # This is part of a workaround to get around fork/join parallelism # not supporting Optional types. if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len attn_weights = torch.bmm(q, k.transpose(1, 2)) attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] if attn_bias is not None: attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len) if attn_mask is not None: attn_mask = attn_mask.unsqueeze(0) attn_weights += attn_mask if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if before_softmax: return attn_weights, v attn_weights_float = utils.softmax( attn_weights, dim=-1, onnx_trace=self.onnx_trace ) attn_weights = attn_weights_float.type_as(attn_weights) attn_probs = self.dropout_module(attn_weights) assert v is not None attn = torch.bmm(attn_probs, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) attn_weights: Optional[Tensor] = None if need_weights: attn_weights = attn_weights_float.view( bsz, self.num_heads, tgt_len, src_len ).transpose(1, 0) if not need_head_weights: # average attention weights over heads attn_weights = attn_weights.mean(dim=0) return attn, attn_weights def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): return attn_weights def upgrade_state_dict_named(self, state_dict, name): prefix = name + "." if name != "" else "" items_to_add = {} keys_to_remove = [] for k in state_dict.keys(): if k.endswith(prefix + "in_proj_weight"): # in_proj_weight used to be q + k + v with same dimensions dim = int(state_dict[k].shape[0] / 3) items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] keys_to_remove.append(k) k_bias = prefix + "in_proj_bias" if k_bias in state_dict.keys(): dim = int(state_dict[k].shape[0] / 3) items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ dim : 2 * dim ] items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] keys_to_remove.append(prefix + "in_proj_bias") for k in keys_to_remove: del state_dict[k] for key, value in items_to_add.items(): state_dict[key] = value def init_graphormer_params(module): """ Initialize the weights specific to the Graphormer Model. """ def normal_(data): # with FSDP, module params will be on CUDA, so we cast them back to CPU # so that the RNG is consistent with and without FSDP data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) if isinstance(module, nn.Linear): normal_(module.weight.data) if module.bias is not None: module.bias.data.zero_() if isinstance(module, nn.Embedding): normal_(module.weight.data) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() if isinstance(module, MultiheadAttention): normal_(module.q_proj.weight.data) normal_(module.k_proj.weight.data) normal_(module.v_proj.weight.data) def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") return fn return docstring_decorator def add_start_docstrings_to_model_forward(*docstr): def docstring_decorator(fn): docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") class_name = f"[`{fn.__qualname__.split('.')[0]}`]" intro = f" The {class_name} forward method, overrides the `__call__` special method." note = r""" Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`] instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them. """ fn.__doc__ = intro + note + docstring return fn return docstring_decorator def add_end_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr) return fn return docstring_decorator PT_RETURN_INTRODUCTION = r""" Returns: [`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the configuration ([`{config_class}`]) and inputs. """ TF_RETURN_INTRODUCTION = r""" Returns: [`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the configuration ([`{config_class}`]) and inputs. """ def _get_indent(t): """Returns the indentation in the first line of t""" search = re.search(r"^(\s*)\S", t) return "" if search is None else search.groups()[0] def _convert_output_args_doc(output_args_doc): """Convert output_args_doc to display properly.""" # Split output_arg_doc in blocks argument/description indent = _get_indent(output_args_doc) blocks = [] current_block = "" for line in output_args_doc.split("\n"): # If the indent is the same as the beginning, the line is the name of new arg. if _get_indent(line) == indent: if len(current_block) > 0: blocks.append(current_block[:-1]) current_block = f"{line}\n" else: # Otherwise it's part of the description of the current arg. # We need to remove 2 spaces to the indentation. current_block += f"{line[2:]}\n" blocks.append(current_block[:-1]) # Format each block for proper rendering for i in range(len(blocks)): blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i]) blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i]) return "\n".join(blocks) def _prepare_output_docstrings(output_type, config_class, min_indent=None): """ Prepares the return part of the docstring using `output_type`. """ output_docstring = output_type.__doc__ # Remove the head of the docstring to keep the list of args only lines = output_docstring.split("\n") i = 0 while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None: i += 1 if i < len(lines): params_docstring = "\n".join(lines[(i + 1):]) params_docstring = _convert_output_args_doc(params_docstring) # Add the return introduction full_output_type = f"{output_type.__module__}.{output_type.__name__}" intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION intro = intro.format(full_output_type=full_output_type, config_class=config_class) result = intro + params_docstring # Apply minimum indent if necessary if min_indent is not None: lines = result.split("\n") # Find the indent of the first nonempty line i = 0 while len(lines[i]) == 0: i += 1 indent = len(_get_indent(lines[i])) # If too small, add indentation to all nonempty lines if indent < min_indent: to_add = " " * (min_indent - indent) lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines] result = "\n".join(lines) return result PT_TOKEN_CLASSIFICATION_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer( ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt" ... ) >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> predicted_token_class_ids = logits.argmax(-1) >>> # Note that tokens are classified rather then input words which means that >>> # there might be more predicted token classes than words. >>> # Multiple token classes might account for the same word >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]] >>> predicted_tokens_classes {expected_output} ``` ```python >>> labels = predicted_token_class_ids >>> loss = model(**inputs, labels=labels).loss >>> round(loss.item(), 2) {expected_loss} ``` """ PT_QUESTION_ANSWERING_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" >>> inputs = tokenizer(question, text, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> answer_start_index = outputs.start_logits.argmax() >>> answer_end_index = outputs.end_logits.argmax() >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1] >>> tokenizer.decode(predict_answer_tokens) {expected_output} ``` ```python >>> # target is "nice puppet" >>> target_start_index = torch.tensor([{qa_target_start_index}]) >>> target_end_index = torch.tensor([{qa_target_end_index}]) >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) >>> loss = outputs.loss >>> round(loss.item(), 2) {expected_loss} ``` """ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r""" Example of single-label classification: ```python >>> import torch >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> predicted_class_id = logits.argmax().item() >>> model.config.id2label[predicted_class_id] {expected_output} ``` ```python >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` >>> num_labels = len(model.config.id2label) >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels) >>> labels = torch.tensor([1]) >>> loss = model(**inputs, labels=labels).loss >>> round(loss.item(), 2) {expected_loss} ``` Example of multi-label classification: ```python >>> import torch >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> predicted_class_id = logits.argmax().item() >>> model.config.id2label[predicted_class_id] {expected_output} ``` ```python >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` >>> num_labels = len(model.config.id2label) >>> model = {model_class}.from_pretrained( ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification" ... ) >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to( ... torch.float ... ) >>> loss = model(**inputs, labels=labels).loss >>> loss.backward() # doctest: +IGNORE_RESULT ``` """ PT_MASKED_LM_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> # retrieve index of {mask} >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) >>> tokenizer.decode(predicted_token_id) {expected_output} ``` ```python >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] >>> # mask labels of non-{mask} tokens >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) >>> outputs = model(**inputs, labels=labels) >>> round(outputs.loss.item(), 2) {expected_loss} ``` """ PT_BASE_MODEL_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state ``` """ PT_MULTIPLE_CHOICE_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> choice0 = "It is eaten with a fork and a knife." >>> choice1 = "It is eaten while held in the hand." >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) >>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1 >>> # the linear classifier still needs to be trained >>> loss = outputs.loss >>> logits = outputs.logits ``` """ PT_CAUSAL_LM_SAMPLE = r""" Example: ```python >>> import torch >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs, labels=inputs["input_ids"]) >>> loss = outputs.loss >>> logits = outputs.logits ``` """ PT_SPEECH_BASE_MODEL_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> from datasets import load_dataset >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> processor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> # audio file is decoded on the fly >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state >>> list(last_hidden_states.shape) {expected_output} ``` """ PT_SPEECH_CTC_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> from datasets import load_dataset >>> import torch >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> processor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> # audio file is decoded on the fly >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> predicted_ids = torch.argmax(logits, dim=-1) >>> # transcribe speech >>> transcription = processor.batch_decode(predicted_ids) >>> transcription[0] {expected_output} ``` ```python >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids >>> # compute loss >>> loss = model(**inputs).loss >>> round(loss.item(), 2) {expected_loss} ``` """ PT_SPEECH_SEQ_CLASS_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> from datasets import load_dataset >>> import torch >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> # audio file is decoded on the fly >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> predicted_class_ids = torch.argmax(logits, dim=-1).item() >>> predicted_label = model.config.id2label[predicted_class_ids] >>> predicted_label {expected_output} ``` ```python >>> # compute loss - target_label is e.g. "down" >>> target_label = model.config.id2label[0] >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]]) >>> loss = model(**inputs).loss >>> round(loss.item(), 2) {expected_loss} ``` """ PT_SPEECH_FRAME_CLASS_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> from datasets import load_dataset >>> import torch >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> # audio file is decoded on the fly >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate) >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> probabilities = torch.sigmoid(logits[0]) >>> # labels is a one-hot array of shape (num_frames, num_speakers) >>> labels = (probabilities > 0.5).long() >>> labels[0].tolist() {expected_output} ``` """ PT_SPEECH_XVECTOR_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> from datasets import load_dataset >>> import torch >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> # audio file is decoded on the fly >>> inputs = feature_extractor( ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True ... ) >>> with torch.no_grad(): ... embeddings = model(**inputs).embeddings >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu() >>> # the resulting embeddings can be used for cosine similarity-based retrieval >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1) >>> similarity = cosine_sim(embeddings[0], embeddings[1]) >>> threshold = 0.7 # the optimal threshold is dataset-dependent >>> if similarity < threshold: ... print("Speakers are not the same!") >>> round(similarity.item(), 2) {expected_output} ``` """ PT_VISION_BASE_MODEL_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> from datasets import load_dataset >>> dataset = load_dataset("huggingface/cats-image") >>> image = dataset["test"]["image"][0] >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = feature_extractor(image, return_tensors="pt") >>> with torch.no_grad(): ... outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state >>> list(last_hidden_states.shape) {expected_output} ``` """ PT_VISION_SEQ_CLASS_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import torch >>> from datasets import load_dataset >>> dataset = load_dataset("huggingface/cats-image") >>> image = dataset["test"]["image"][0] >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = feature_extractor(image, return_tensors="pt") >>> with torch.no_grad(): ... logits = model(**inputs).logits >>> # model predicts one of the 1000 ImageNet classes >>> predicted_label = logits.argmax(-1).item() >>> print(model.config.id2label[predicted_label]) {expected_output} ``` """ PT_SAMPLE_DOCSTRINGS = { "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE, "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE, "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE, "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE, "MaskedLM": PT_MASKED_LM_SAMPLE, "LMHead": PT_CAUSAL_LM_SAMPLE, "BaseModel": PT_BASE_MODEL_SAMPLE, "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE, "CTC": PT_SPEECH_CTC_SAMPLE, "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE, "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE, "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE, "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE, "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE, } TF_TOKEN_CLASSIFICATION_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer( ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf" ... ) >>> logits = model(**inputs).logits >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1) >>> # Note that tokens are classified rather then input words which means that >>> # there might be more predicted token classes than words. >>> # Multiple token classes might account for the same word >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()] >>> predicted_tokens_classes {expected_output} ``` ```python >>> labels = predicted_token_class_ids >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss) >>> round(float(loss), 2) {expected_loss} ``` """ TF_QUESTION_ANSWERING_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" >>> inputs = tokenizer(question, text, return_tensors="tf") >>> outputs = model(**inputs) >>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0]) >>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0]) >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1] >>> tokenizer.decode(predict_answer_tokens) {expected_output} ``` ```python >>> # target is "nice puppet" >>> target_start_index = tf.constant([{qa_target_start_index}]) >>> target_end_index = tf.constant([{qa_target_end_index}]) >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) >>> loss = tf.math.reduce_mean(outputs.loss) >>> round(float(loss), 2) {expected_loss} ``` """ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") >>> logits = model(**inputs).logits >>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0]) >>> model.config.id2label[predicted_class_id] {expected_output} ``` ```python >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` >>> num_labels = len(model.config.id2label) >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels) >>> labels = tf.constant(1) >>> loss = model(**inputs, labels=labels).loss >>> round(float(loss), 2) {expected_loss} ``` """ TF_MASKED_LM_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf") >>> logits = model(**inputs).logits >>> # retrieve index of {mask} >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0]) >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index) >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1) >>> tokenizer.decode(predicted_token_id) {expected_output} ``` ```python >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"] >>> # mask labels of non-{mask} tokens >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) >>> outputs = model(**inputs, labels=labels) >>> round(float(outputs.loss), 2) {expected_loss} ``` """ TF_BASE_MODEL_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") >>> outputs = model(inputs) >>> last_hidden_states = outputs.last_hidden_state ``` """ TF_MULTIPLE_CHOICE_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> choice0 = "It is eaten with a fork and a knife." >>> choice1 = "It is eaten while held in the hand." >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True) >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}} >>> outputs = model(inputs) # batch size is 1 >>> # the linear classifier still needs to be trained >>> logits = outputs.logits ``` """ TF_CAUSAL_LM_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") >>> outputs = model(inputs) >>> logits = outputs.logits ``` """ TF_SPEECH_BASE_MODEL_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> from datasets import load_dataset >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> processor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> # audio file is decoded on the fly >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state >>> list(last_hidden_states.shape) {expected_output} ``` """ TF_SPEECH_CTC_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> from datasets import load_dataset >>> import tensorflow as tf >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") >>> dataset = dataset.sort("id") >>> sampling_rate = dataset.features["audio"].sampling_rate >>> processor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> # audio file is decoded on the fly >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf") >>> logits = model(**inputs).logits >>> predicted_ids = tf.math.argmax(logits, axis=-1) >>> # transcribe speech >>> transcription = processor.batch_decode(predicted_ids) >>> transcription[0] {expected_output} ``` ```python >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids >>> # compute loss >>> loss = model(**inputs).loss >>> round(float(loss), 2) {expected_loss} ``` """ TF_VISION_BASE_MODEL_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> from datasets import load_dataset >>> dataset = load_dataset("huggingface/cats-image") >>> image = dataset["test"]["image"][0] >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = feature_extractor(image, return_tensors="tf") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state >>> list(last_hidden_states.shape) {expected_output} ``` """ TF_VISION_SEQ_CLASS_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> import tensorflow as tf >>> from datasets import load_dataset >>> dataset = load_dataset("huggingface/cats-image") >>> image = dataset["test"]["image"][0] >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = feature_extractor(image, return_tensors="tf") >>> logits = model(**inputs).logits >>> # model predicts one of the 1000 ImageNet classes >>> predicted_label = int(tf.math.argmax(logits, axis=-1)) >>> print(model.config.id2label[predicted_label]) {expected_output} ``` """ TF_SAMPLE_DOCSTRINGS = { "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE, "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE, "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE, "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE, "MaskedLM": TF_MASKED_LM_SAMPLE, "LMHead": TF_CAUSAL_LM_SAMPLE, "BaseModel": TF_BASE_MODEL_SAMPLE, "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE, "CTC": TF_SPEECH_CTC_SAMPLE, "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE, "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE, } FLAX_TOKEN_CLASSIFICATION_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") >>> outputs = model(**inputs) >>> logits = outputs.logits ``` """ FLAX_QUESTION_ANSWERING_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" >>> inputs = tokenizer(question, text, return_tensors="jax") >>> outputs = model(**inputs) >>> start_scores = outputs.start_logits >>> end_scores = outputs.end_logits ``` """ FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") >>> outputs = model(**inputs) >>> logits = outputs.logits ``` """ FLAX_MASKED_LM_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax") >>> outputs = model(**inputs) >>> logits = outputs.logits ``` """ FLAX_BASE_MODEL_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state ``` """ FLAX_MULTIPLE_CHOICE_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> choice0 = "It is eaten with a fork and a knife." >>> choice1 = "It is eaten while held in the hand." >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True) >>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}}) >>> logits = outputs.logits ``` """ FLAX_CAUSAL_LM_SAMPLE = r""" Example: ```python >>> from transformers import {processor_class}, {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") >>> outputs = model(**inputs) >>> # retrieve logts for next token >>> next_token_logits = outputs.logits[:, -1] ``` """ FLAX_SAMPLE_DOCSTRINGS = { "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE, "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE, "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE, "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE, "MaskedLM": FLAX_MASKED_LM_SAMPLE, "BaseModel": FLAX_BASE_MODEL_SAMPLE, "LMHead": FLAX_CAUSAL_LM_SAMPLE, } def add_code_sample_docstrings( *docstr, processor_class=None, checkpoint=None, output_type=None, config_class=None, mask="[MASK]", qa_target_start_index=14, qa_target_end_index=15, model_cls=None, modality=None, expected_output="", expected_loss="", ): def docstring_decorator(fn): # model_class defaults to function's class if not specified otherwise model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls if model_class[:2] == "TF": sample_docstrings = TF_SAMPLE_DOCSTRINGS elif model_class[:4] == "Flax": sample_docstrings = FLAX_SAMPLE_DOCSTRINGS else: sample_docstrings = PT_SAMPLE_DOCSTRINGS # putting all kwargs for docstrings in a dict to be used # with the `.format(**doc_kwargs)`. Note that string might # be formatted with non-existing keys, which is fine. doc_kwargs = dict( model_class=model_class, processor_class=processor_class, checkpoint=checkpoint, mask=mask, qa_target_start_index=qa_target_start_index, qa_target_end_index=qa_target_end_index, expected_output=expected_output, expected_loss=expected_loss, ) if "SequenceClassification" in model_class and modality == "audio": code_sample = sample_docstrings["AudioClassification"] elif "SequenceClassification" in model_class: code_sample = sample_docstrings["SequenceClassification"] elif "QuestionAnswering" in model_class: code_sample = sample_docstrings["QuestionAnswering"] elif "TokenClassification" in model_class: code_sample = sample_docstrings["TokenClassification"] elif "MultipleChoice" in model_class: code_sample = sample_docstrings["MultipleChoice"] elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]: code_sample = sample_docstrings["MaskedLM"] elif "LMHead" in model_class or "CausalLM" in model_class: code_sample = sample_docstrings["LMHead"] elif "CTC" in model_class: code_sample = sample_docstrings["CTC"] elif "AudioFrameClassification" in model_class: code_sample = sample_docstrings["AudioFrameClassification"] elif "XVector" in model_class and modality == "audio": code_sample = sample_docstrings["AudioXVector"] elif "Model" in model_class and modality == "audio": code_sample = sample_docstrings["SpeechBaseModel"] elif "Model" in model_class and modality == "vision": code_sample = sample_docstrings["VisionBaseModel"] elif "Model" in model_class or "Encoder" in model_class: code_sample = sample_docstrings["BaseModel"] elif "ImageClassification" in model_class: code_sample = sample_docstrings["ImageClassification"] else: raise ValueError(f"Docstring can't be built for model {model_class}") func_doc = (fn.__doc__ or "") + "".join(docstr) output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class) built_doc = code_sample.format(**doc_kwargs) fn.__doc__ = func_doc + output_doc + built_doc return fn return docstring_decorator def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear: """ Prune a linear layer to keep only entries in index. Used to remove heads. Args: layer (`torch.nn.Linear`): The layer to prune. index (`torch.LongTensor`): The indices to keep in the layer. dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices. Returns: `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`. """ index = index.to(layer.weight.device) W = layer.weight.index_select(dim, index).clone().detach() if layer.bias is not None: if dim == 1: b = layer.bias.clone().detach() else: b = layer.bias[index].clone().detach() new_size = list(layer.weight.size()) new_size[dim] = len(index) new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device) new_layer.weight.requires_grad = False new_layer.weight.copy_(W.contiguous()) new_layer.weight.requires_grad = True if layer.bias is not None: new_layer.bias.requires_grad = False new_layer.bias.copy_(b.contiguous()) new_layer.bias.requires_grad = True return new_layer def apply_chunking_to_forward( forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors ) -> torch.Tensor: """ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory. If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly applying `forward_fn` to `input_tensors`. Args: forward_fn (`Callable[..., torch.Tensor]`): The forward function of the model. chunk_size (`int`): The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`. chunk_dim (`int`): The dimension over which the `input_tensors` should be chunked. input_tensors (`Tuple[torch.Tensor]`): The input tensors of `forward_fn` which will be chunked Returns: `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`. Examples: ```python # rename the usual forward() fn to forward_chunk() def forward_chunk(self, hidden_states): hidden_states = self.decoder(hidden_states) return hidden_states # implement a chunked forward function def forward(self, hidden_states): return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) ```""" assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors" # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) if num_args_in_forward_chunk_fn != len(input_tensors): raise ValueError( f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input " "tensors are given" ) if chunk_size > 0: tensor_shape = input_tensors[0].shape[chunk_dim] for input_tensor in input_tensors: if input_tensor.shape[chunk_dim] != tensor_shape: raise ValueError( f"All input tenors have to be of the same shape: {tensor_shape}, " f"found shape {input_tensor.shape[chunk_dim]}" ) if input_tensors[0].shape[chunk_dim] % chunk_size != 0: raise ValueError( f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk " f"size {chunk_size}" ) num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size # chunk input tensor into tuples input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors) # apply forward fn to every tuple output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks)) # concatenate output at same dimension return torch.cat(output_chunks, dim=chunk_dim) return forward_fn(*input_tensors) def find_pruneable_heads_and_indices( heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int] ) -> Tuple[Set[int], torch.LongTensor]: """ Finds the heads and their indices taking `already_pruned_heads` into account. Args: heads (`List[int]`): List of the indices of heads to prune. n_heads (`int`): The number of heads in the model. head_size (`int`): The size of each head. already_pruned_heads (`Set[int]`): A set of already pruned heads. Returns: `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices. """ mask = torch.ones(n_heads, head_size) heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads for head in heads: # Compute how many pruned heads are before the head and move the index accordingly head = head - sum(1 if h < head else 0 for h in already_pruned_heads) mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) index: torch.LongTensor = torch.arange(len(mask))[mask].long() return heads, index