# Source: Hugging Face model repository "RobertaMultiHeadCRF-C32-3", file modeling_multiheadcrf.py
# (uploaded by T-Almeida, commit 3ea34af, 21.1 kB).
import os
from typing import Optional, Union, List
from transformers import AutoModel, PreTrainedModel, AutoConfig, AutoModel, RobertaModel, BertModel
from transformers.modeling_outputs import TokenClassifierOutput
from torch import nn
from torch.nn import CrossEntropyLoss
import torch
from itertools import islice
from.configuration_multiheadcrf import MultiHeadCRFConfig
# Number of parameter tensors that one encoder layer contributes to
# `encoder.named_parameters()`; manage_freezing multiplies it by the number of
# layers to freeze and slices that many parameters off the front of the encoder.
NUM_PER_LAYER: int = 16
class RobertaMultiHeadCRFModel(PreTrainedModel):
    """Token classifier with one CRF-decoded head per entity class.

    A transformer backbone produces contextual token embeddings. For every
    class name in ``config.classes`` a head of ``Linear -> GELU`` layers, a
    ``Linear`` classifier and a :class:`CRF` turn them into tag scores.
    Subclasses swap the backbone through the ``transformers_backbone_*``
    class attributes.
    """

    config_class = MultiHeadCRFConfig
    # HF needs the backbone stored under "roberta" for RoBERTa models and
    # "bert" for BERT models so that checkpoint weights can be loaded; see
    # https://github.com/huggingface/transformers/blob/b487096b02307cd6e0f132b676cdcc7255fe8e74/src/transformers/models/roberta/modeling_roberta.py#L1170C16-L1170C20
    transformers_backbone_name = "roberta"
    transformers_backbone_class = RobertaModel
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """Build the backbone and one dense/classifier/CRF head per class.

        Reads from ``config``: ``num_labels``, ``number_of_layer_per_head``,
        ``classes``, ``hidden_size``, ``hidden_dropout_prob``,
        ``crf_reduction``, ``freeze`` (and ``num_frozen_encoder`` when
        freezing).
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.number_of_layer_per_head = config.number_of_layer_per_head
        self.heads = config.classes  # expected: an iterable of the class names we predict

        # The backbone can be BERT, RoBERTa or another BERT variant; it is stored
        # under `transformers_backbone_name` (see the class attribute comment).
        setattr(
            self,
            self.transformers_backbone_name,
            self.transformers_backbone_class(config, add_pooling_layer=False),
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        print(sorted(self.heads))

        for ent in self.heads:
            for i in range(self.number_of_layer_per_head):
                setattr(self, f"{ent}_dense_{i}", nn.Linear(config.hidden_size, config.hidden_size))
                setattr(self, f"{ent}_dense_activation_{i}", nn.GELU(approximate='none'))
            setattr(self, f"{ent}_classifier", nn.Linear(config.hidden_size, config.num_labels))
            setattr(self, f"{ent}_crf", CRF(num_tags=config.num_labels, batch_first=True))
            setattr(self, f"{ent}_reduction", config.crf_reduction)
        self.reduction = config.crf_reduction

        if self.config.freeze:
            self.manage_freezing()

    def training_mode(self):
        """Re-initialise every head's dense layers, classifier and CRF.

        The authors observed that these layers were not initialised correctly,
        probably because of the lifecycle of HF's ``from_pretrained``. This
        resets them explicitly and re-applies the CRF transition masking.
        """
        for ent in self.heads:
            for i in range(self.number_of_layer_per_head):
                getattr(self, f"{ent}_dense_{i}").reset_parameters()
            getattr(self, f"{ent}_classifier").reset_parameters()
            crf = getattr(self, f"{ent}_crf")
            crf.reset_parameters()
            crf.mask_impossible_transitions()

    def manage_freezing(self):
        """Freeze the backbone embeddings and the first ``config.num_frozen_encoder`` encoder layers."""
        backbone = getattr(self, self.transformers_backbone_name)
        for _, param in backbone.embeddings.named_parameters():
            param.requires_grad = False

        num_encoders_to_freeze = self.config.num_frozen_encoder
        if num_encoders_to_freeze > 0:
            # Assumes encoder.named_parameters() yields parameters layer by layer
            # and that each layer contributes exactly NUM_PER_LAYER tensors, so
            # the first N * NUM_PER_LAYER parameters are the first N layers.
            first_params = islice(
                backbone.encoder.named_parameters(),
                num_encoders_to_freeze * NUM_PER_LAYER,
            )
            for _, param in first_params:
                param.requires_grad = False

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None
                ):
        """Run the backbone and every head.

        Args:
            labels: when given, a mapping from head name to a tag tensor that
                has the same leading two dimensions as that head's logits.

        Returns:
            With ``labels``: a tuple ``(loss, logits)``. ``loss`` is the sum over
            heads of the CRF negative log-likelihood, reduced with
            ``config.crf_reduction``. ``logits`` maps each head name to its
            emission scores.
            Without ``labels``: a list of tensors built from the Viterbi-decoded
            tag ids, one per head, ordered by ``sorted(self.heads)``.
        """
        # Default `model.config.use_return_dict` is `True`
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        backbone_outputs = getattr(self, self.transformers_backbone_name)(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(backbone_outputs[0])  # B S E

        last_layer = self.number_of_layer_per_head - 1
        logits = {}
        for ent in self.heads:
            hidden = sequence_output
            if last_layer >= 0:
                # NOTE(review): each dense layer of a head reads `sequence_output`
                # directly rather than the previous layer's output. Only the last
                # layer can ever reach the logits; layers 0..n-2 were computed,
                # discarded and never received gradients. Evaluating only the last
                # layer yields the same logits without the wasted matmuls. The
                # unused layers are still built in __init__ so state_dict keys
                # stay compatible. Chaining them (probably the original intent)
                # would change the outputs of checkpoints trained with this code.
                hidden = getattr(self, f"{ent}_dense_{last_layer}")(sequence_output)
                hidden = getattr(self, f"{ent}_dense_activation_{last_layer}")(hidden)
            # With zero dense layers the classifier reads the backbone output directly.
            logits[ent] = getattr(self, f"{ent}_classifier")(hidden)

        if labels is not None:
            # Train/eval: labels are not passed during inference.
            # No mask is given to the CRF, so every position contributes to the loss.
            losses = {
                ent: getattr(self, f"{ent}_crf")(logits[ent], labels[ent], reduction=self.reduction)
                for ent in self.heads
            }
            return sum(losses.values()), logits

        # Prediction: Viterbi-decode each head.
        # NOTE: a multi-GPU gather does not work on this output, because
        # torch.Tensor(...) builds the tensors on the CPU.
        decoded = {
            ent: torch.Tensor(getattr(self, f"{ent}_crf").decode(logits[ent]))
            for ent in self.heads
        }
        return [decoded[ent] for ent in sorted(self.heads)]
class BertMultiHeadCRFModel(RobertaMultiHeadCRFModel):
    """Multi-head CRF token classifier on a BERT backbone.

    Only the backbone differs from the RoBERTa variant. It is stored under the
    attribute name ``bert`` so that BERT checkpoints load into it.
    """

    transformers_backbone_class = BertModel
    transformers_backbone_name = "bert"
    config_class = MultiHeadCRFConfig
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
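# Adapted from https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py.
# Local changes:
# - the uint8 masks are cast to bool before torch.where, which avoids the uint8 warning;
# - forward returns the negative log-likelihood;
# - impossible BIO transitions are masked.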
# Score assigned to transitions that must never be taken (effectively -inf,
# but finite so sums and logsumexp stay well defined).
LARGE_NEGATIVE_NUMBER = -1e9


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the *negative* log likelihood of the given sequence of tags
    and emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    On construction, transitions that are impossible under a BIO layout (tag
    index 0 = ``O``, tag index 2 = ``I``) are pinned to ``LARGE_NEGATIVE_NUMBER``;
    see :meth:`mask_impossible_transitions`. That layout is assumed, not checked.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``; entry ``[i, j]`` scores moving from tag ``i``
            to tag ``j``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.
    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()
        self.mask_impossible_transitions()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def mask_impossible_transitions(self) -> None:
        """Set the value of impossible transitions to LARGE_NEGATIVE_NUMBER.

        Assumes tag index 0 is ``O`` and tag index 2 is ``I``. Two entries are set:

        - the start transition into ``I`` (a sequence cannot start inside an entity);
        - the transition score of ``O -> I``.

        With fewer than three tags there is no ``I`` index, so nothing is masked
        (rather than raising ``IndexError`` from the constructor).
        """
        if self.num_tags < 3:
            return
        with torch.no_grad():
            self.start_transitions[2] = LARGE_NEGATIVE_NUMBER
            self.transitions[0][2] = LARGE_NEGATIVE_NUMBER

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the negative conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
                Defaults to all ones.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The negative log likelihood, i.e. a loss to minimise. It has
            size ``(batch_size,)`` if reduction is ``none``, ``()`` otherwise.

        Raises:
            ValueError: on malformed shapes, a mask whose first timestep is not all on,
                or an unknown ``reduction``.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator
        nllh = -llh

        if reduction == 'none':
            return nllh
        if reduction == 'sum':
            return nllh.sum()
        if reduction == 'mean':
            return nllh.mean()
        assert reduction == 'token_mean'
        return nllh.sum() / mask.type_as(emissions).sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
                Defaults to all ones.

        Returns:
            List of list containing the best tag sequence for each batch; each inner
            list has as many entries as that sample's unmasked timesteps.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        """Raise ValueError if the argument shapes or the mask are inconsistent."""
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            # The first timestep of every sequence must be valid (no empty sequences).
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        """Return the unnormalised score of the given tag path for each sample."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.type_as(emissions)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        """Return log Z, the log-sum-exp of the scores of all tag paths, per sample."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Entry [b, i, j] is the log-sum of scores of all paths so far that end
            # by transitioning from tag i to tag j and emitting at tag j.
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags; in score space a sum becomes a
            # log-sum-exp.
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1).
            # The mask is cast to bool because torch.where warns on uint8 conditions.
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        """Return the highest-scoring tag path for each sample (time-major inputs)."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        # history[t] holds, for each sample and each tag at step t+1, the best
        # previous tag; it is used to trace back the best path.
        history = []

        # score[b, j] is the score of the best tag sequence so far that ends with tag j.
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Entry [b, i, j] is the score of the best sequence so far that ends by
            # transitioning from tag i to tag j and emitting at tag j.
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # The tag that maximizes the final score is the best tag for the last timestep.
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # Trace back where the best last tag came from, only over this
            # sample's valid timesteps.
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list