'''Initial code taken from the SemSup repository.'''
import torch
from torch import nn
import sys
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification

# Import configs
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.bert.configuration_bert import BertConfig

import numpy as np

# Loss functions
from torch.nn import BCEWithLogitsLoss

from typing import Optional, Union, Tuple, Dict, List
import itertools

MODEL_FOR_SEMANTIC_EMBEDDING = {
    "roberta": "RobertaForSemanticEmbedding",
    "bert": "BertForSemanticEmbedding",
}

MODEL_TO_CONFIG = {
    "roberta": RobertaConfig,
    "bert": BertConfig,
}


def getLabelModel(data_args, model_args):
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.label_model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModel.from_pretrained(
        model_args.label_model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    return model, tokenizer


class AutoModelForMultiLabelClassification:
    """
    Class for choosing the right model class automatically.
    Loosely based on the AutoModel classes in HuggingFace.
    """

    @staticmethod
    def from_pretrained(*args, **kwargs):
        # Check what type of model it is
        for key in MODEL_TO_CONFIG.keys():
            if type(kwargs['config']) == MODEL_TO_CONFIG[key]:
                class_name = getattr(sys.modules[__name__], MODEL_FOR_SEMANTIC_EMBEDDING[key])
                return class_name.from_pretrained(*args, **kwargs)
        # If none of the models matched
        raise ValueError(
            "This model type is not supported. Please choose one of {}".format(MODEL_FOR_SEMANTIC_EMBEDDING.keys())
        )


from transformers import BertForSequenceClassification, BertTokenizer
from transformers import RobertaForSequenceClassification, RobertaTokenizer
from transformers import XLNetForSequenceClassification, XLNetTokenizer
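
# BertForSemanticEmbedding (below) scores an input document against label-description
# embeddings (SemSup). Three document heads are supported:
#   arch_type 1: concatenate the [CLS] states of the last 5 hidden layers, apply dropout and fc1;
#   arch_type 2: use the encoder's pooled output directly (required when config.coil is set);
#   arch_type 3: pass the pooled output through dropout and a 256-d fc1 projection.
# When config.coil is set, token-level COIL/ColBERT-style interaction scores are added on top
# of the dense document-label similarities.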

class BertForSemanticEmbedding(nn.Module):
    def __init__(self, config):
        # super().__init__(config)
        super().__init__()
        self.config = config
        self.coil = config.coil
        if self.coil:
            assert config.arch_type == 2
            self.token_dim = config.token_dim
        try:  # Try catch was added to handle the ongoing hyper search experiments.
            self.arch_type = config.arch_type
        except:
            self.arch_type = 2
        try:
            self.colbert = config.colbert
        except:
            self.colbert = False

        if config.encoder_model_type == 'bert':
            # self.encoder = BertModel(config)
            if self.arch_type == 1:
                self.encoder = AutoModelForSequenceClassification.from_pretrained(
                    'bert-base-uncased', output_hidden_states=True)
            else:
                self.encoder = AutoModel.from_pretrained(
                    config.model_name_or_path
                )
                # self.encoder = AutoModelForSequenceClassification.from_pretrained(
                #     'bert-base-uncased', output_hidden_states=True).bert
        elif config.encoder_model_type == 'roberta':
            self.encoder = RobertaForSequenceClassification.from_pretrained(
                'roberta-base', num_labels=config.num_labels, output_hidden_states=True)
        elif config.encoder_model_type == 'xlnet':
            self.encoder = XLNetForSequenceClassification.from_pretrained(
                'xlnet-base-cased', num_labels=config.num_labels, output_hidden_states=True)

        print('Config is', config)

        if config.negative_sampling == 'none':
            if config.arch_type == 1:
                self.fc1 = nn.Linear(5 * config.hidden_size, 512 if config.semsup else config.num_labels)
            elif self.arch_type == 3:
                self.fc1 = nn.Linear(config.hidden_size, 256 if config.semsup else config.num_labels)

        if self.coil:
            self.tok_proj = nn.Linear(self.encoder.config.hidden_size, self.token_dim)

        self.dropout = nn.Dropout(0.1)
        self.candidates_topk = 10

        if config.negative_sampling != 'none':
            self.group_y = np.array([np.array([l for l in group]) for group in config.group_y])
            # np.load('datasets/EUR-Lex/label_group_lightxml_0.npy', allow_pickle=True)

        self.negative_sampling = config.negative_sampling
        self.min_positive_samples = 20
        self.semsup = config.semsup
        self.label_projection = None
        if self.semsup:  # and config.hidden_size != config.label_hidden_size:
            if self.arch_type == 1:
                self.label_projection = nn.Linear(512, config.label_hidden_size, bias=False)
            elif self.arch_type == 2:
                self.label_projection = nn.Linear(self.encoder.config.hidden_size, config.label_hidden_size, bias=False)
            elif self.arch_type == 3:
                self.label_projection = nn.Linear(256, config.label_hidden_size, bias=False)
        # self.post_init()

    def compute_tok_score_cart(self, doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
        if not self.colbert:
            qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3)  # Q * LQ * 1 * 1
            doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1)  # 1 * 1 * D * LD
            exact_match = doc_input_ids == qry_input_ids  # Q * LQ * D * LD
            exact_match = exact_match.float()
        scores_no_masking = torch.matmul(
            qry_reps.view(-1, self.token_dim),  # (Q * LQ) * d
            doc_reps.view(-1, self.token_dim).transpose(0, 1)  # d * (D * LD)
        )
        scores_no_masking = scores_no_masking.view(
            *qry_reps.shape[:2], *doc_reps.shape[:2])  # Q * LQ * D * LD
        if self.colbert:
            scores, _ = scores_no_masking.max(dim=3)
        else:
            scores, _ = (scores_no_masking * exact_match).max(dim=3)  # Q * LQ * D
        tok_scores = (scores * qry_attention_mask.reshape(-1, qry_attention_mask.shape[-1]).unsqueeze(2))[:, 1:].sum(1)
        return tok_scores
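
    # Shape walk-through for compute_tok_score_cart above (illustrative numbers, not taken
    # from the repository): with Q = 8 label descriptions of length LQ = 64, D = 4 documents
    # of length LD = 128, and token_dim d = 32,
    #   qry_reps is Q x LQ x d and doc_reps is D x LD x d;
    #   scores_no_masking is Q x LQ x D x LD (every description token vs. every document token);
    #   without colbert, exact_match zeroes out pairs whose token ids differ before the max
    #   over LD (COIL-style), whereas with colbert the max is taken over all document tokens;
    #   masking by qry_attention_mask, dropping the first query position (the [CLS] slot) and
    #   summing over LQ leaves tok_scores with shape Q x D.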

    def coil_eval_forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        desc_input_ids=None,
        desc_attention_mask=None,
        lab_reps=None,
        label_embeddings=None,
        clustered_input_ids=None,
        clustered_desc_ids=None,
    ):
        outputs_doc, logits = self.forward_input_encoder(input_ids, attention_mask, token_type_ids)
        doc_reps = self.tok_proj(outputs_doc.last_hidden_state)  # D * LD * d
        # lab_reps = self.tok_proj(outputs_lab.last_hidden_state @ self.label_projection.weight)  # Q * LQ * d
        if clustered_input_ids is None:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, input_ids,
                lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask
            )
        else:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, clustered_input_ids,
                lab_reps, clustered_desc_ids.reshape(-1, clustered_desc_ids.shape[-1]), desc_attention_mask
            )
        logits = self.semsup_forward(
            logits,
            label_embeddings.reshape(desc_input_ids.shape[0], desc_input_ids.shape[1], -1).contiguous(),
            same_labels=True)
        new_tok_scores = torch.zeros(logits.shape, device=logits.device)
        for i in range(tok_scores.shape[1]):
            stride = tok_scores.shape[0] // tok_scores.shape[1]
            new_tok_scores[i] = tok_scores[i * stride: i * stride + stride, i]
        logits += new_tok_scores.contiguous()
        return logits

    def coil_forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        desc_input_ids: Optional[List[int]] = None,
        desc_attention_mask: Optional[List[int]] = None,
        desc_inputs_embeds: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        clustered_input_ids=None,
        clustered_desc_ids=None,
        ignore_label_embeddings_and_out_lab=None,
    ):
        # print(desc_input_ids.shape, desc_attention_mask.shape, desc_inputs_embeds.shape)
        outputs_doc, logits = self.forward_input_encoder(input_ids, attention_mask, token_type_ids)
        if ignore_label_embeddings_and_out_lab is not None:
            # Reuse precomputed label encoder outputs and label embeddings when provided.
            outputs_lab, label_embeddings = ignore_label_embeddings_and_out_lab
        else:
            outputs_lab, label_embeddings, _, _ = self.forward_label_embeddings(
                None, None,
                desc_input_ids=desc_input_ids,
                desc_attention_mask=desc_attention_mask,
                return_hidden_states=True,
                desc_inputs_embeds=desc_inputs_embeds)

        doc_reps = self.tok_proj(outputs_doc.last_hidden_state)  # D * LD * d
        lab_reps = self.tok_proj(outputs_lab.last_hidden_state @ self.label_projection.weight)  # Q * LQ * d
        if clustered_input_ids is None:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, input_ids,
                lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask
            )
        else:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, clustered_input_ids,
                lab_reps, clustered_desc_ids.reshape(-1, clustered_desc_ids.shape[-1]), desc_attention_mask
            )
        logits = self.semsup_forward(
            logits,
            label_embeddings.reshape(desc_input_ids.shape[0], desc_input_ids.shape[1], -1).contiguous(),
            same_labels=True)
        new_tok_scores = torch.zeros(logits.shape, device=logits.device)
        for i in range(tok_scores.shape[1]):
            stride = tok_scores.shape[0] // tok_scores.shape[1]
            new_tok_scores[i] = tok_scores[i * stride: i * stride + stride, i]
        logits += new_tok_scores.contiguous()

        loss_fn = BCEWithLogitsLoss()
        loss = loss_fn(logits, labels)

        if not return_dict:
            output = (logits,) + outputs_doc[2:] + (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs_doc.hidden_states,
            attentions=outputs_doc.attentions,
        )
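
    # Note on the stride loop used in coil_eval_forward and coil_forward above (B and K are
    # illustrative names, not variables in the code): for a batch of B documents, each paired
    # with K label descriptions, tok_scores has shape (B*K) x B while logits is B x K.
    # Rows i*K:(i+1)*K hold the K descriptions supplied for document i, and column i scores
    # them against document i itself, so that slice becomes new_tok_scores[i] and is added to
    # the dense SemSup similarities.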

    def semsup_forward(self, input_embeddings, label_embeddings, num_candidates=-1, list_to_set_mapping=None, same_labels=False):
        '''
        If same_labels is True, directly apply matrix multiplication.
        Otherwise, num_candidates must not be -1 and list_to_set_mapping must not be None.
        '''
        if same_labels:
            logits = torch.bmm(input_embeddings.unsqueeze(1), label_embeddings.transpose(2, 1)).squeeze(1)
        else:
            # TODO: Can we optimize this? Perhaps torch.bmm?
            logits = torch.stack(
                # For each batch point, calculate the product with its candidate label embeddings
                [
                    logit @ label_embeddings[list_to_set_mapping[i * num_candidates: (i + 1) * num_candidates]].T
                    for i, logit in enumerate(input_embeddings)
                ]
            )
        return logits

    def forward_label_embeddings(self, all_candidate_labels, label_desc_ids, desc_input_ids=None, desc_attention_mask=None, desc_inputs_embeds=None, return_hidden_states=False):
        # Given the candidate labels and their corresponding description numbers,
        # returns the embeddings for the unique label descriptions.
        if desc_attention_mask is None:
            num_candidates = all_candidate_labels.shape[1]
            # Create a set to perform the minimal number of operations on common labels
            label_desc_ids_list = list(zip(itertools.chain(*label_desc_ids.detach().cpu().tolist()),
                                           itertools.chain(*all_candidate_labels.detach().cpu().tolist())))
            print('Original Length: ', len(label_desc_ids_list))
            label_desc_ids_set = torch.tensor(list(set(label_desc_ids_list)))
            print('New Length: ', label_desc_ids_set.shape)
            m1 = {tuple(x): i for i, x in enumerate(label_desc_ids_set.tolist())}
            list_to_set_mapping = torch.tensor([m1[x] for x in label_desc_ids_list])

            descs = [
                self.tokenizedDescriptions[self.config.id2label[desc_lb[1].item()]][desc_lb[0]]
                for desc_lb in label_desc_ids_set
            ]
            label_input_ids = torch.cat([desc['input_ids'] for desc in descs])
            label_attention_mask = torch.cat([desc['attention_mask'] for desc in descs])
            label_token_type_ids = torch.cat([desc['token_type_ids'] for desc in descs])

            label_input_ids = label_input_ids.to(label_desc_ids.device)
            label_attention_mask = label_attention_mask.to(label_desc_ids.device)
            label_token_type_ids = label_token_type_ids.to(label_desc_ids.device)

            label_embeddings = self.label_model(
                label_input_ids,
                attention_mask=label_attention_mask,
                token_type_ids=label_token_type_ids,
            ).pooler_output
        else:
            list_to_set_mapping = None
            num_candidates = None
            if desc_inputs_embeds is not None:
                outputs = self.label_model(
                    inputs_embeds=desc_inputs_embeds.reshape(
                        desc_inputs_embeds.shape[0] * desc_inputs_embeds.shape[1],
                        desc_inputs_embeds.shape[2], desc_inputs_embeds.shape[3]).contiguous(),
                    attention_mask=desc_attention_mask.reshape(-1, desc_input_ids.shape[-1]).contiguous(),
                )
            else:
                outputs = self.label_model(
                    desc_input_ids.reshape(-1, desc_input_ids.shape[-1]).contiguous(),
                    attention_mask=desc_attention_mask.reshape(-1, desc_input_ids.shape[-1]).contiguous(),
                )
            label_embeddings = outputs.pooler_output

        if self.label_projection is not None:
            if return_hidden_states:
                return outputs, label_embeddings @ self.label_projection.weight, list_to_set_mapping, num_candidates
            else:
                return label_embeddings @ self.label_projection.weight, list_to_set_mapping, num_candidates
        else:
            return label_embeddings, list_to_set_mapping, num_candidates
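
    # Illustration of the dedup in forward_label_embeddings above (hypothetical numbers):
    # with 2 documents and 3 candidate labels each, label_desc_ids_list holds 6
    # (description_idx, label_id) pairs, e.g. [(0, 5), (0, 7), (1, 5), (0, 5), (1, 5), (0, 2)].
    # Only the 4 unique pairs are looked up in self.tokenizedDescriptions and encoded by
    # label_model, and list_to_set_mapping records, for each of the 6 original slots, the row
    # of its pair in the deduplicated batch, so semsup_forward can gather per-document label
    # embeddings via list_to_set_mapping[i * num_candidates: (i + 1) * num_candidates].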

    def forward_input_encoder(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
    ):
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True if self.arch_type == 1 else False,
        )
        # Currently, the method specified in LightXML is used
        if self.arch_type in [2, 3]:
            logits = outputs[1]
        elif self.arch_type == 1:
            logits = torch.cat([outputs.hidden_states[-i][:, 0] for i in range(1, 5 + 1)], dim=-1)
        if self.arch_type in [1, 3]:
            logits = self.dropout(logits)

        # No Sampling
        if self.arch_type in [1, 3]:
            logits = self.fc1(logits)
        return outputs, logits

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cluster_labels: Optional[torch.Tensor] = None,
        all_candidate_labels: Optional[torch.Tensor] = None,
        label_desc_ids: Optional[List[int]] = None,
        desc_inputs_embeds: Optional[torch.Tensor] = None,
        desc_input_ids: Optional[List[int]] = None,
        desc_attention_mask: Optional[List[int]] = None,
        label_embeddings: Optional[torch.Tensor] = None,
        clustered_input_ids: Optional[torch.Tensor] = None,
        clustered_desc_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        if self.coil:
            return self.coil_forward(
                input_ids,
                attention_mask,
                token_type_ids,
                labels,
                desc_input_ids,
                desc_attention_mask,
                desc_inputs_embeds,
                return_dict,
                clustered_input_ids,
                clustered_desc_ids,
            )

        # STEP 2: Forward pass through the input model
        outputs, logits = self.forward_input_encoder(input_ids, attention_mask, token_type_ids)

        if self.semsup:
            if desc_input_ids is None:
                all_candidate_labels = torch.arange(labels.shape[1]).repeat((labels.shape[0], 1))
                label_embeddings, list_to_set_mapping, num_candidates = self.forward_label_embeddings(all_candidate_labels, label_desc_ids)
                logits = self.semsup_forward(logits, label_embeddings, num_candidates, list_to_set_mapping)
            else:
                label_embeddings, _, _ = self.forward_label_embeddings(
                    None, None,
                    desc_input_ids=desc_input_ids,
                    desc_attention_mask=desc_attention_mask,
                    desc_inputs_embeds=desc_inputs_embeds)
                logits = self.semsup_forward(
                    logits,
                    label_embeddings.reshape(desc_input_ids.shape[0], desc_input_ids.shape[1], -1).contiguous(),
                    same_labels=True)
        elif label_embeddings is not None:
            logits = self.semsup_forward(logits, label_embeddings.contiguous() @ self.label_projection.weight, same_labels=True)

        loss_fn = BCEWithLogitsLoss()
        loss = loss_fn(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:] + (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
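

# A minimal usage sketch, not part of the original repository: the config fields below are
# inferred from how __init__ and forward() read them, the document/label strings are dummy
# data, and attaching `label_model` by hand stands in for getLabelModel(), which normally
# builds it from model_args. This exercises the plain SemSup path (coil=False, arch_type=2).
if __name__ == "__main__":
    from types import SimpleNamespace

    sketch_config = SimpleNamespace(
        coil=False,
        colbert=False,
        arch_type=2,
        encoder_model_type="bert",
        model_name_or_path="bert-base-uncased",
        negative_sampling="none",
        semsup=True,
        # Kept equal to the BERT hidden size so the `@ label_projection.weight` products are well-formed.
        label_hidden_size=768,
    )
    model = BertForSemanticEmbedding(sketch_config)
    model.label_model = AutoModel.from_pretrained("bert-base-uncased")  # stand-in for getLabelModel()
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    docs = ["a directive on corporate tax reporting", "a regulation on wind energy subsidies"]
    label_descriptions = ["taxation", "renewable energy", "public health"]

    doc_inputs = tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
    desc_inputs = tokenizer(label_descriptions, padding=True, truncation=True, return_tensors="pt")

    batch_size, num_labels = len(docs), len(label_descriptions)
    # Every document is scored against the same label descriptions: batch x num_labels x seq_len.
    desc_input_ids = desc_inputs["input_ids"].unsqueeze(0).expand(batch_size, -1, -1)
    desc_attention_mask = desc_inputs["attention_mask"].unsqueeze(0).expand(batch_size, -1, -1)
    labels = torch.zeros(batch_size, num_labels)
    labels[0, 0] = labels[1, 1] = 1.0

    with torch.no_grad():
        loss, logits = model(
            input_ids=doc_inputs["input_ids"],
            attention_mask=doc_inputs["attention_mask"],
            token_type_ids=doc_inputs["token_type_ids"],
            labels=labels,
            desc_input_ids=desc_input_ids,
            desc_attention_mask=desc_attention_mask,
        )[:2]
    print("loss:", loss.item(), "logits shape:", tuple(logits.shape))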