| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.distributed as dist |
| |
|
| | import transformers |
| | from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaClassificationHead |
| | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions |
| |
|
class ResidualBlock(nn.Module):
    """Fully connected residual block computing ``relu(fc(x)) + x``.

    The linear layer maps ``dim`` features back to ``dim`` features, so the
    skip connection is added without any projection.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the activated linear transform of ``x`` plus ``x`` itself."""
        activated = self.relu(self.fc(x))
        return activated + x
| |
|
class SemanticModel(nn.Module):
    """MLP that maps an ``input_dim`` vector to an ``output_dim`` vector.

    The stack is: ``Linear(input_dim, hidden_dim)``, then ``num_layers``
    ``ResidualBlock(hidden_dim)`` modules, then
    ``Linear(hidden_dim, output_dim)``. No activation is applied outside the
    residual blocks.

    Args:
        num_layers: Number of residual blocks between the two linear layers.
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width used by the residual blocks.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, num_layers=2, input_dim=768, hidden_dim=512, output_dim=384):
        super().__init__()
        # Modules are created in the same order as before (input projection,
        # residual blocks, output projection) so parameter initialisation and
        # the ``layers.<idx>`` state-dict keys are unchanged.
        self.layers = nn.ModuleList([
            nn.Linear(input_dim, hidden_dim),
            *(ResidualBlock(hidden_dim) for _ in range(num_layers)),
            nn.Linear(hidden_dim, output_dim),
        ])

    def forward(self, x):
        """Apply every layer in order and return the final activation."""
        for layer in self.layers:
            x = layer(x)
        return x
| |
|
class Similarity(nn.Module):
    """
    Temperature-scaled cosine similarity.

    ``forward(x, y)`` returns the cosine similarity of ``x`` and ``y`` taken
    along the last dimension, divided by ``temp``.
    """

    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        cosine = self.cos(x, y)
        return cosine / self.temp
| |
|
| |
|
class RobertaClassificationHeadForEmbedding(RobertaClassificationHead):
    """RoBERTa head that returns an embedding rather than class logits.

    ``forward`` takes the features at position 0 of the sequence dimension
    (presumably the ``<s>`` token), applies dropout and the dense layer, and
    returns the result. ``out_proj`` is still constructed here but is never
    applied in ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Use the dedicated classifier dropout when it is configured,
        # otherwise fall back to the general hidden dropout probability.
        dropout_prob = config.classifier_dropout
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        first_token = features[:, 0, :]
        return self.dense(self.dropout(first_token))
| | |
def cl_init(cls, config):
    """
    Set up the contrastive-learning state on a model instance.

    Stores a ``Similarity`` module built with temperature
    ``cls.model_args.temp`` as ``cls.sim`` and then calls
    ``cls.init_weights()``. ``config`` is accepted but not used here.
    """
    temperature = cls.model_args.temp
    cls.sim = Similarity(temp=temperature)
    cls.init_weights()
| |
|
def remove_diagonal_elements(input_tensor):
    """
    Drop the main diagonal of a square (bs, bs) matrix.

    Returns a new (bs, bs-1) tensor in which each row keeps its off-diagonal
    entries in their original order. Raises ``ValueError`` when the first two
    dimensions differ.
    """
    n_rows, n_cols = input_tensor.size(0), input_tensor.size(1)
    if n_rows != n_cols:
        raise ValueError("Input tensor must be square (bs, bs).")

    # True everywhere except on the diagonal.
    off_diagonal = ~torch.eye(n_rows, dtype=torch.bool, device=input_tensor.device)
    kept = input_tensor[off_diagonal]
    return kept.view(n_rows, n_rows - 1)
| |
|
def cl_forward(cls,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    latter_sentiment_spoof_mask=None,
):
    """Training forward pass: triplet loss plus a sign-balance term.

    ``input_ids`` / ``attention_mask`` are treated as
    ``(batch, num_sent, seq_len)``: along dim 1, index 0 is the original
    sentence, the next ``model_args.num_paraphrased`` entries are paraphrases
    (positives) and the following ``model_args.num_negative`` entries are
    negatives, paired in order with ``model_args.spoofing_cnames``.

    Args:
        cls: Model instance providing ``roberta``, ``classifier``, ``map``,
            ``sim``, ``config``, ``model_args`` and ``training``.
        input_ids, attention_mask, token_type_ids: Flattened to
            ``(batch * num_sent, seq_len)`` before being given to the encoder.
        position_ids, head_mask, inputs_embeds, output_attentions: Forwarded
            to the encoder unchanged (they are not flattened).
        labels: Unused.
        output_hidden_states: Unused; the encoder is always called with
            ``output_hidden_states=False``.
        return_dict: Falls back to ``cls.config.use_return_dict`` when None.
            Only the dict form is implemented.
        latter_sentiment_spoof_mask: Per-sample mask (expected shape
            ``(batch,)``) applied only to the ``'latter_sentiment_spoof_0'``
            column: samples whose mask is 0 are left out of that column's
            triplet terms, and only samples whose mask is 1 enter its
            similarity metric.

    Returns:
        A dict with ``loss`` (``loss_gr + loss_tl``), ``loss_gr``,
        ``loss_tl``, the monitoring values ``sim_paraphrase``, ``sim_other``
        and one ``sim_<column>`` per spoofing column, plus the encoder's
        ``hidden_states`` and ``attentions``.

    Raises:
        NotImplementedError: If ``return_dict`` resolves to False, or when
            called in training mode with ``torch.distributed`` initialised.
        ValueError: If the triplet normaliser would be zero, or if the masked
            column is configured but no mask was supplied.
    """
    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict

    # Fail fast: neither unsupported mode needs the encoder to run first.
    if not return_dict:
        raise NotImplementedError
    if dist.is_initialized() and cls.training:
        raise NotImplementedError

    model_args = cls.model_args
    num_paraphrased = model_args.num_paraphrased
    num_negative = model_args.num_negative
    temp = model_args.temp
    margin = model_args.margin
    spoofing_cnames = model_args.spoofing_cnames
    masked_cname = 'latter_sentiment_spoof_0'

    batch_size = input_ids.size(0)
    num_sent = input_ids.size(1)

    # The triplet loss is averaged over batch * paraphrases * columns, also
    # counting terms that the mask later removes.
    num_triplet_terms = batch_size * num_paraphrased * len(spoofing_cnames)
    if num_triplet_terms == 0:
        raise ValueError(
            "cl_forward needs a non-empty batch, at least one paraphrase and "
            "at least one spoofing column to compute the triplet loss."
        )
    if masked_cname in spoofing_cnames and latter_sentiment_spoof_mask is None:
        raise ValueError(
            f"latter_sentiment_spoof_mask is required when '{masked_cname}' "
            "is listed in model_args.spoofing_cnames."
        )

    # Flatten (batch, num_sent, seq_len) -> (batch * num_sent, seq_len) so
    # every sentence is encoded independently.
    input_ids = input_ids.view((-1, input_ids.size(-1)))
    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))
    if token_type_ids is not None:
        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))

    outputs = cls.roberta(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=False,
        return_dict=True,
    )

    # Sentence embedding from the head, regrouped per sample and then mapped
    # through ``cls.map``.
    sequence_output = outputs[0]
    pooler_output = cls.classifier(sequence_output)
    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))
    pooler_output = cls.map(pooler_output)

    # Split along the sentence axis: original | paraphrases | negatives.
    # The negative range is empty when num_negative == 0.
    first_negative = num_paraphrased + 1
    original = pooler_output[:, 0]
    paraphrase_list = [pooler_output[:, i] for i in range(1, first_negative)]
    negative_list = [
        pooler_output[:, i]
        for i in range(first_negative, first_negative + num_negative)
    ]

    # A very steep tanh saturates towards +/-1 for all but near-zero inputs.
    original = torch.tanh(original * 1000)
    paraphrase_list = [torch.tanh(p * 1000) for p in paraphrase_list]
    negative_list = [torch.tanh(n * 1000) for n in negative_list]
    negative_dict = dict(zip(spoofing_cnames, negative_list))

    # ``cls.sim`` output is multiplied by ``temp`` throughout; assuming
    # ``cls.sim`` was built by ``cl_init`` (cosine / temp), this recovers the
    # plain cosine similarity.
    # pos_sim[b, j] = sim(original[b], paraphrase_j[b]); shape (batch, P).
    pos_sim = torch.stack([cls.sim(original, p) for p in paraphrase_list], dim=1)
    pos_cos = pos_sim * temp

    # Triplet loss, vectorised per column: for every (sample, paraphrase)
    # pair, hinge on cos(ori, neg) - cos(ori, pos) + margin.
    loss_triplet = 0
    for cname in spoofing_cnames:
        neg_cos = cls.sim(original, negative_dict[cname]) * temp  # (batch,)
        hinge = F.relu(neg_cos.unsqueeze(1) - pos_cos + margin)   # (batch, P)
        if cname == masked_cname:
            # Leave out samples whose mask is 0 for this column.
            hinge = hinge[latter_sentiment_spoof_mask != 0]
        loss_triplet = loss_triplet + hinge.sum()
    loss_triplet = loss_triplet / num_triplet_terms

    def sign_loss(x):
        # NOTE(review): for a 2-D ``x`` both terms equal the absolute value of
        # the global mean (a mean of per-column means and a mean of per-row
        # means are the same number), so this is |mean(x)|. If a per-row /
        # per-column balance was intended, abs() would have to be applied
        # before the outer mean -- confirm the intent before changing it.
        row = torch.abs(torch.mean(torch.mean(x, dim=0)))
        col = torch.abs(torch.mean(torch.mean(x, dim=1)))
        return (row + col) / 2

    loss_gr = sign_loss(original)

    # Monitoring only: mean cosine between each original and its paraphrases.
    loss_3 = pos_sim.mean() * temp

    # Monitoring only: mean cosine between each original and each negative
    # column (restricted to mask == 1 samples for the masked column).
    negative_sample_loss = {}
    for cname in spoofing_cnames:
        negatives = negative_dict[cname]
        originals = original
        if cname == masked_cname:
            selected = latter_sentiment_spoof_mask == 1
            negatives = negatives[selected]
            originals = originals[selected]
        negative_sample_loss[cname] = cls.sim(originals, negatives).mean() * temp

    # Monitoring only: mean cosine between originals of different samples
    # (the self-similarity diagonal is removed first).
    ori_ori_cos = cls.sim(original.unsqueeze(1), original.unsqueeze(0))
    ori_ori_cos_removed = remove_diagonal_elements(ori_ori_cos)
    loss_5 = ori_ori_cos_removed.mean() * temp

    loss = loss_gr + loss_triplet

    result = {
        'loss': loss,
        'loss_gr': loss_gr,
        'sim_paraphrase': loss_3,
        'sim_other': loss_5,
        'hidden_states': outputs.hidden_states,
        'attentions': outputs.attentions,
    }

    for cname, column_sim in negative_sample_loss.items():
        key = f"sim_{cname.replace('_spoof_0', '')}"
        result[key] = column_sim

    result['loss_tl'] = loss_triplet

    return result
| |
|
| |
|
def sentemb_forward(
    cls,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """Inference forward pass that returns the mapped sentence embedding.

    Runs ``cls.roberta`` on the inputs, feeds the first encoder output to
    ``cls.classifier`` and then ``cls.map``, and reports the result as
    ``pooler_output``. ``labels`` and ``output_hidden_states`` are accepted
    but not used; the encoder is always called with
    ``output_hidden_states=False`` and ``return_dict=True``.

    Returns:
        ``BaseModelOutputWithPoolingAndCrossAttentions`` when ``return_dict``
        (or, if None, ``cls.config.use_return_dict``) is truthy; otherwise the
        tuple ``(outputs[0], embedding) + outputs[2:]``.
    """
    if return_dict is None:
        return_dict = cls.config.use_return_dict

    encoder_kwargs = {
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "position_ids": position_ids,
        "head_mask": head_mask,
        "inputs_embeds": inputs_embeds,
        "output_attentions": output_attentions,
        "output_hidden_states": False,
        "return_dict": True,
    }
    outputs = cls.roberta(input_ids, **encoder_kwargs)

    # Head embedding, then the mapping network on top of it.
    embedding = cls.map(cls.classifier(outputs[0]))

    if return_dict:
        return BaseModelOutputWithPoolingAndCrossAttentions(
            pooler_output=embedding,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )

    return (outputs[0], embedding) + outputs[2:]
| |
|
| |
|
class RobertaForCL(RobertaForSequenceClassification):
    """RoBERTa encoder with an embedding head and a mapping network.

    ``forward`` dispatches to ``sentemb_forward`` (when ``sent_emb`` is true)
    or to ``cl_forward`` (the training path with the contrastive losses).
    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, *model_args, **model_kargs):
        """Build the model.

        Args:
            config: Model configuration passed to the parent class and to the
                embedding head.
            *model_args: Accepted but unused.
            **model_kargs: May contain ``model_args``; when it is truthy,
                ``cl_init`` is run to attach the similarity module.
        """
        super().__init__(config)
        self.model_args = model_kargs.get("model_args", None)

        # Replace the stock classification head with one that returns the
        # dense-layer embedding instead of logits.
        self.classifier = RobertaClassificationHeadForEmbedding(config)

        if self.model_args:
            cl_init(self, config)

        # The mapping network consumes the head's output, whose width is
        # ``config.hidden_size``; reading it from the config (rather than
        # hard-coding 768) keeps the shapes consistent for encoders of any
        # width while being identical for 768-wide checkpoints.
        self.map = SemanticModel(input_dim=config.hidden_size)

        self.post_init()

    def forward(self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        sent_emb=False,
        latter_sentiment_spoof_mask=None,
    ):
        """Run the sentence-embedding path or the contrastive training path.

        When ``sent_emb`` is true the call is delegated to
        ``sentemb_forward``; otherwise to ``cl_forward``, which additionally
        receives ``latter_sentiment_spoof_mask``. All other arguments are
        passed through unchanged.
        """
        shared_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if sent_emb:
            return sentemb_forward(self, **shared_kwargs)
        return cl_forward(
            self,
            latter_sentiment_spoof_mask=latter_sentiment_spoof_mask,
            **shared_kwargs,
        )
| |
|
| |
|