File size: 5,293 Bytes
8d73145 0e93feb 8d73145 0e93feb f36c5fb 8d73145 0e93feb 8d73145 5dbef48 8d73145 0e93feb 8d73145 0e93feb 8d73145 0e93feb 5dbef48 0e93feb 5dbef48 8d73145 0e93feb 8d73145 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os
from .configuration_stacked import ImpressoConfig
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
def get_info(label_map):
    """Return a mapping from each task name to how many labels it has.

    Args:
        label_map: Mapping of task name -> sized collection of labels.

    Returns:
        dict: task name -> ``len`` of that task's label collection, in the
        same key order as ``label_map``.
    """
    counts = {}
    for task_name, task_labels in label_map.items():
        counts[task_name] = len(task_labels)
    return counts
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """Multi-task token classifier built on a pretrained encoder.

    Architecture: pretrained backbone (``AutoModel``) -> dropout -> two extra
    ``nn.TransformerEncoder`` layers -> one ``nn.Linear`` head per task.
    The tasks and their label counts are read from ``label_map.json`` located
    in the same directory as this module.
    """

    config_class = ImpressoConfig
    # Missing weight keys matching these patterns are not reported on load.
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the backbone, the extra encoder stack and the task heads.

        Args:
            config: The model configuration (``ImpressoConfig``). This method
                reads ``pretrained_config`` (and its ``"_name_or_path"``
                entry), ``hidden_size``, ``num_attention_heads`` and, when
                present, ``classifier_dropout`` (falling back to
                ``hidden_dropout_prob`` when it is ``None``).

        Raises:
            OSError: If ``label_map.json`` cannot be opened.
            json.JSONDecodeError: If ``label_map.json`` is not valid JSON.
        """
        super().__init__(config)
        # Resolve label_map.json relative to this module, not the working dir.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logger.debug("Current folder path: %s", current_dir)
        label_map_path = os.path.join(current_dir, "label_map.json")
        # Use a context manager so the file handle is always closed.
        with open(label_map_path, "r", encoding="utf-8") as label_map_file:
            label_map = json.load(label_map_file)
        # task name -> number of labels for that task
        self.num_token_labels_dict = get_info(label_map)
        self.config = config

        # Backbone encoder.
        # NOTE(review): ``pretrained_config`` is indexed like a dict here;
        # confirm AutoModel.from_pretrained honours a dict passed as config=.
        self.bert = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"],
            config=config.pretrained_config,
        )

        # Only an attribute set on the config instance counts; otherwise 0.1.
        if "classifier_dropout" not in config.__dict__:
            classifier_dropout = 0.1
        elif config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)

        # Additional transformer layers on top of the backbone output.
        # nn.TransformerEncoderLayer defaults to batch_first=False, which is
        # why forward() transposes around this stack.
        # NOTE(review): forward() passes no src_key_padding_mask, so these
        # layers also attend to padding positions.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        # One linear classification head per task.
        self.token_classifiers = nn.ModuleDict(
            {
                task: nn.Linear(config.hidden_size, num_labels)
                for task, num_labels in self.num_token_labels_dict.items()
            }
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""Run the backbone and every task head; optionally compute the loss.

        token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Labels for computing the token classification loss. Keys should match the tasks.
        labels:
            Accepted for signature compatibility but not used by this method.

        Returns:
            With ``return_dict``: a ``TokenClassifierOutput`` whose ``logits``
            is a dict task -> logits and whose ``loss`` is the sum of the
            per-task cross-entropy losses for tasks present in
            ``token_labels`` (``None`` when no task labels were given).
            Otherwise a tuple ``(task_logits, *outputs[2:])``, prefixed with
            the loss when one was computed.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        bert_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        # llama/deberta backbones are not given these arguments (presumably
        # their forward() does not accept them).
        if any(
            keyword in self.config.name_or_path.lower()
            for keyword in ["llama", "deberta"]
        ):
            bert_kwargs.pop("token_type_ids")
            bert_kwargs.pop("head_mask")

        outputs = self.bert(**bert_kwargs)

        # First backbone output; assumed (batch, seq, hidden) -- TODO confirm.
        token_output = outputs[0]
        token_output = self.dropout(token_output)
        # Swap to (seq, batch, hidden) for the seq-first encoder and back.
        token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
            0, 1
        )

        # Collect the logits and compute the loss for each task.
        task_logits = {}
        # None (not 0) so "no labels" is distinguishable from a zero loss and
        # TokenClassifierOutput omits the field, as in other HF heads.
        total_loss = None
        loss_fct = CrossEntropyLoss()
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss = loss if total_loss is None else total_loss + loss

        if not return_dict:
            output = (task_logits,) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return TokenClassifierOutput(
            loss=total_loss,
            logits=task_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|