'''
Initial code taken from the SemSup repository.
'''
import torch
from torch import nn
import sys
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
# Import configs
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.models.bert.configuration_bert import BertConfig
import numpy as np
# Loss functions
from torch.nn import BCEWithLogitsLoss
from typing import Optional, Union, Tuple, Dict, List
import itertools

MODEL_FOR_SEMANTIC_EMBEDDING = {
    "roberta": "RobertaForSemanticEmbedding",
    "bert": "BertForSemanticEmbedding",
}
MODEL_TO_CONFIG = {
    "roberta": RobertaConfig,
    "bert": BertConfig,
}

def getLabelModel(data_args, model_args):
    # Load the label encoder and its tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.label_model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModel.from_pretrained(
        model_args.label_model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    return model, tokenizer

class AutoModelForMultiLabelClassification:
    """
    Class for choosing the right model class automatically.
    Loosely based on the AutoModel classes in HuggingFace.
    """
    def from_pretrained(*args, **kwargs):
        # Check what type of model it is by matching the config class
        for key in MODEL_TO_CONFIG.keys():
            if type(kwargs['config']) == MODEL_TO_CONFIG[key]:
                class_name = getattr(sys.modules[__name__], MODEL_FOR_SEMANTIC_EMBEDDING[key])
                return class_name.from_pretrained(*args, **kwargs)
        # If none of the models were chosen
        raise ValueError("This model type is not supported. Please choose one of {}".format(MODEL_FOR_SEMANTIC_EMBEDDING.keys()))

from transformers import BertForSequenceClassification, BertTokenizer
from transformers import RobertaForSequenceClassification, RobertaTokenizer
from transformers import XLNetForSequenceClassification, XLNetTokenizer

class BertForSemanticEmbedding(nn.Module):
    def __init__(self, config):
        # super().__init__(config)
        super().__init__()
        self.config = config
        self.coil = config.coil
        if self.coil:
            assert config.arch_type == 2
            self.token_dim = config.token_dim
        # Fallbacks were added to handle the ongoing hyper-search experiments,
        # whose configs may not define these fields.
        self.arch_type = getattr(config, 'arch_type', 2)
        self.colbert = getattr(config, 'colbert', False)
        if config.encoder_model_type == 'bert':
            # self.encoder = BertModel(config)
            if self.arch_type == 1:
                self.encoder = AutoModelForSequenceClassification.from_pretrained(
                    'bert-base-uncased', output_hidden_states=True)
            else:
                self.encoder = AutoModel.from_pretrained(
                    config.model_name_or_path
                )
                # self.encoder = AutoModelForSequenceClassification.from_pretrained(
                #     'bert-base-uncased', output_hidden_states=True).bert
        elif config.encoder_model_type == 'roberta':
            self.encoder = RobertaForSequenceClassification.from_pretrained(
                'roberta-base', num_labels=config.num_labels, output_hidden_states=True)
        elif config.encoder_model_type == 'xlnet':
            self.encoder = XLNetForSequenceClassification.from_pretrained(
                'xlnet-base-cased', num_labels=config.num_labels, output_hidden_states=True)
        print('Config is', config)
        if config.negative_sampling == 'none':
            if self.arch_type == 1:
                self.fc1 = nn.Linear(5 * config.hidden_size, 512 if config.semsup else config.num_labels)
            elif self.arch_type == 3:
                self.fc1 = nn.Linear(config.hidden_size, 256 if config.semsup else config.num_labels)
        if self.coil:
            self.tok_proj = nn.Linear(self.encoder.config.hidden_size, self.token_dim)
        self.dropout = nn.Dropout(0.1)
        self.candidates_topk = 10
        if config.negative_sampling != 'none':
            self.group_y = np.array([np.array([l for l in group]) for group in config.group_y])
            # np.load('datasets/EUR-Lex/label_group_lightxml_0.npy', allow_pickle=True)
        self.negative_sampling = config.negative_sampling
        self.min_positive_samples = 20
        self.semsup = config.semsup
        # Projects label embeddings into the same space as the document representation.
        self.label_projection = None
        if self.semsup:  # and config.hidden_size != config.label_hidden_size:
            if self.arch_type == 1:
                self.label_projection = nn.Linear(512, config.label_hidden_size, bias=False)
            elif self.arch_type == 2:
                self.label_projection = nn.Linear(self.encoder.config.hidden_size, config.label_hidden_size, bias=False)
            elif self.arch_type == 3:
                self.label_projection = nn.Linear(256, config.label_hidden_size, bias=False)
        # self.post_init()

    def compute_tok_score_cart(self, doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
        if not self.colbert:
            qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3)  # Q * LQ * 1 * 1
            doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1)  # 1 * 1 * D * LD
            exact_match = doc_input_ids == qry_input_ids  # Q * LQ * D * LD
            exact_match = exact_match.float()
        scores_no_masking = torch.matmul(
            qry_reps.view(-1, self.token_dim),  # (Q * LQ) * d
            doc_reps.view(-1, self.token_dim).transpose(0, 1)  # d * (D * LD)
        )
        scores_no_masking = scores_no_masking.view(
            *qry_reps.shape[:2], *doc_reps.shape[:2])  # Q * LQ * D * LD
        if self.colbert:
            scores, _ = scores_no_masking.max(dim=3)
        else:
            scores, _ = (scores_no_masking * exact_match).max(dim=3)  # Q * LQ * D
        tok_scores = (scores * qry_attention_mask.reshape(-1, qry_attention_mask.shape[-1]).unsqueeze(2))[:, 1:].sum(1)
        return tok_scores

    def coil_eval_forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        desc_input_ids=None,
        desc_attention_mask=None,
        lab_reps=None,
        label_embeddings=None,
        clustered_input_ids=None,
        clustered_desc_ids=None,
    ):
        outputs_doc, logits = self.forward_input_encoder(input_ids, attention_mask, token_type_ids)
        doc_reps = self.tok_proj(outputs_doc.last_hidden_state)  # D * LD * d
        # lab_reps = self.tok_proj(outputs_lab.last_hidden_state @ self.label_projection.weight)  # Q * LQ * d
        if clustered_input_ids is None:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, input_ids,
                lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask
            )
        else:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, clustered_input_ids,
                lab_reps, clustered_desc_ids.reshape(-1, clustered_desc_ids.shape[-1]), desc_attention_mask
            )
        logits = self.semsup_forward(logits, label_embeddings.reshape(desc_input_ids.shape[0], desc_input_ids.shape[1], -1).contiguous(), same_labels=True)
        new_tok_scores = torch.zeros(logits.shape, device=logits.device)
        for i in range(tok_scores.shape[1]):
            stride = tok_scores.shape[0] // tok_scores.shape[1]
            new_tok_scores[i] = tok_scores[i * stride: i * stride + stride, i]
        logits += new_tok_scores.contiguous()
        return logits

    def coil_forward(self,
                     input_ids: Optional[torch.Tensor] = None,
                     attention_mask: Optional[torch.Tensor] = None,
                     token_type_ids: Optional[torch.Tensor] = None,
                     labels: Optional[torch.Tensor] = None,
                     desc_input_ids: Optional[List[int]] = None,
                     desc_attention_mask: Optional[List[int]] = None,
                     desc_inputs_embeds: Optional[torch.Tensor] = None,
                     return_dict: Optional[bool] = None,
                     clustered_input_ids=None,
                     clustered_desc_ids=None,
                     ignore_label_embeddings_and_out_lab=None,
                     ):
        # print(desc_input_ids.shape, desc_attention_mask.shape, desc_inputs_embeds.shape)
        outputs_doc, logits = self.forward_input_encoder(input_ids, attention_mask, token_type_ids)
        if ignore_label_embeddings_and_out_lab is not None:
            # Reuse precomputed label-encoder outputs and embeddings when they are passed in.
            outputs_lab, label_embeddings = ignore_label_embeddings_and_out_lab
        else:
            outputs_lab, label_embeddings, _, _ = self.forward_label_embeddings(None, None, desc_input_ids=desc_input_ids, desc_attention_mask=desc_attention_mask, return_hidden_states=True, desc_inputs_embeds=desc_inputs_embeds)
        doc_reps = self.tok_proj(outputs_doc.last_hidden_state)  # D * LD * d
        lab_reps = self.tok_proj(outputs_lab.last_hidden_state @ self.label_projection.weight)  # Q * LQ * d
        if clustered_input_ids is None:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, input_ids,
                lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask
            )
        else:
            tok_scores = self.compute_tok_score_cart(
                doc_reps, clustered_input_ids,
                lab_reps, clustered_desc_ids.reshape(-1, clustered_desc_ids.shape[-1]), desc_attention_mask
            )
        logits = self.semsup_forward(logits, label_embeddings.reshape(desc_input_ids.shape[0], desc_input_ids.shape[1], -1).contiguous(), same_labels=True)
        # Add each document's token-level (COIL) scores for its own label descriptions to the dense logits.
        new_tok_scores = torch.zeros(logits.shape, device=logits.device)
        for i in range(tok_scores.shape[1]):
            stride = tok_scores.shape[0] // tok_scores.shape[1]
            new_tok_scores[i] = tok_scores[i * stride: i * stride + stride, i]
        logits += new_tok_scores.contiguous()
        loss_fn = BCEWithLogitsLoss()
        loss = loss_fn(logits, labels)
        if not return_dict:
            output = (logits,) + outputs_doc[2:] + (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs_doc.hidden_states,
            attentions=outputs_doc.attentions,
        )

    def semsup_forward(self, input_embeddings, label_embeddings, num_candidates=-1, list_to_set_mapping=None, same_labels=False):
        '''
        If same_labels is True, every input shares the same label set, so the logits are a
        single batched matrix multiplication.
        Otherwise, num_candidates must not be -1 and list_to_set_mapping must not be None.
        '''
        if same_labels:
            logits = torch.bmm(input_embeddings.unsqueeze(1), label_embeddings.transpose(2, 1)).squeeze(1)
        else:
            # TODO: Can we optimize this? Perhaps torch.bmm?
            logits = torch.stack(
                # For each batch point, calculate the product with its own candidate label embeddings
                [
                    logit @ label_embeddings[list_to_set_mapping[i * num_candidates: (i + 1) * num_candidates]].T for i, logit in enumerate(input_embeddings)
                ]
            )
        return logits

    def forward_label_embeddings(self, all_candidate_labels, label_desc_ids, desc_input_ids=None, desc_attention_mask=None, desc_inputs_embeds=None, return_hidden_states=False):
        # Given the candidate labels and the corresponding description numbers,
        # returns the embeddings for the unique label descriptions.
        if desc_attention_mask is None:
            num_candidates = all_candidate_labels.shape[1]
            # Create a set to perform the minimal number of operations on common labels
            label_desc_ids_list = list(zip(itertools.chain(*label_desc_ids.detach().cpu().tolist()), itertools.chain(*all_candidate_labels.detach().cpu().tolist())))
            print('Original Length: ', len(label_desc_ids_list))
            label_desc_ids_set = torch.tensor(list(set(label_desc_ids_list)))
            print('New Length: ', label_desc_ids_set.shape)
            m1 = {tuple(x): i for i, x in enumerate(label_desc_ids_set.tolist())}
            list_to_set_mapping = torch.tensor([m1[x] for x in label_desc_ids_list])
            descs = [
                self.tokenizedDescriptions[self.config.id2label[desc_lb[1].item()]][desc_lb[0]] for desc_lb in label_desc_ids_set
            ]
            label_input_ids = torch.cat([
                desc['input_ids'] for desc in descs
            ])
            label_attention_mask = torch.cat([
                desc['attention_mask'] for desc in descs
            ])
            label_token_type_ids = torch.cat([
                desc['token_type_ids'] for desc in descs
            ])
            label_input_ids = label_input_ids.to(label_desc_ids.device)
            label_attention_mask = label_attention_mask.to(label_desc_ids.device)
            label_token_type_ids = label_token_type_ids.to(label_desc_ids.device)
            label_embeddings = self.label_model(
                label_input_ids,
                attention_mask=label_attention_mask,
                token_type_ids=label_token_type_ids,
            ).pooler_output
        else:
            list_to_set_mapping = None
            num_candidates = None
            if desc_inputs_embeds is not None:
                outputs = self.label_model(
                    inputs_embeds=desc_inputs_embeds.reshape(desc_inputs_embeds.shape[0] * desc_inputs_embeds.shape[1], desc_inputs_embeds.shape[2], desc_inputs_embeds.shape[3]).contiguous(),
                    attention_mask=desc_attention_mask.reshape(-1, desc_input_ids.shape[-1]).contiguous(),
                )
            else:
                outputs = self.label_model(
                    desc_input_ids.reshape(-1, desc_input_ids.shape[-1]).contiguous(),
                    attention_mask=desc_attention_mask.reshape(-1, desc_input_ids.shape[-1]).contiguous(),
                )
            label_embeddings = outputs.pooler_output
        if self.label_projection is not None:
            if return_hidden_states:
                return outputs, label_embeddings @ self.label_projection.weight, list_to_set_mapping, num_candidates
            else:
                return label_embeddings @ self.label_projection.weight, list_to_set_mapping, num_candidates
        else:
            return label_embeddings, list_to_set_mapping, num_candidates

    def forward_input_encoder(self, input_ids, attention_mask, token_type_ids):
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True if self.arch_type == 1 else False,
        )
        # Currently, the method specified in LightXML is used
        if self.arch_type in [2, 3]:
            logits = outputs[1]
        elif self.arch_type == 1:
            logits = torch.cat([outputs.hidden_states[-i][:, 0] for i in range(1, 5 + 1)], dim=-1)
        if self.arch_type in [1, 3]:
            logits = self.dropout(logits)
            # No sampling
            logits = self.fc1(logits)
        return outputs, logits

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cluster_labels: Optional[torch.Tensor] = None,
        all_candidate_labels: Optional[torch.Tensor] = None,
        label_desc_ids: Optional[List[int]] = None,
        desc_inputs_embeds: Optional[torch.Tensor] = None,
        desc_input_ids: Optional[List[int]] = None,
        desc_attention_mask: Optional[List[int]] = None,
        label_embeddings: Optional[torch.Tensor] = None,
        clustered_input_ids: Optional[torch.Tensor] = None,
        clustered_desc_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        if self.coil:
            return self.coil_forward(
                input_ids,
                attention_mask,
                token_type_ids,
                labels,
                desc_input_ids,
                desc_attention_mask,
                desc_inputs_embeds,
                return_dict,
                clustered_input_ids,
                clustered_desc_ids,
            )
        # STEP 2: Forward pass through the input model
        outputs, logits = self.forward_input_encoder(input_ids, attention_mask, token_type_ids)
        if self.semsup:
            if desc_input_ids is None:
                all_candidate_labels = torch.arange(labels.shape[1]).repeat((labels.shape[0], 1))
                label_embeddings, list_to_set_mapping, num_candidates = self.forward_label_embeddings(all_candidate_labels, label_desc_ids)
                logits = self.semsup_forward(logits, label_embeddings, num_candidates, list_to_set_mapping)
            else:
                label_embeddings, _, _ = self.forward_label_embeddings(None, None, desc_input_ids=desc_input_ids, desc_attention_mask=desc_attention_mask, desc_inputs_embeds=desc_inputs_embeds)
                logits = self.semsup_forward(logits, label_embeddings.reshape(desc_input_ids.shape[0], desc_input_ids.shape[1], -1).contiguous(), same_labels=True)
        elif label_embeddings is not None:
            logits = self.semsup_forward(logits, label_embeddings.contiguous() @ self.label_projection.weight, same_labels=True)
        loss_fn = BCEWithLogitsLoss()
        loss = loss_fn(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:] + (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
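

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original SemSup repository): it wires a
    # document encoder and a label encoder together for a single forward pass.
    # The config fields and argument objects below are assumptions inferred from how
    # this file reads them, not the project's actual training configuration.
    from types import SimpleNamespace

    # Hypothetical stand-in for the real model_args dataclass.
    model_args = SimpleNamespace(
        label_model_name_or_path='bert-base-uncased', cache_dir=None,
        use_fast_tokenizer=True, model_revision='main', use_auth_token=False,
    )
    label_model, label_tokenizer = getLabelModel(None, model_args)

    # Hypothetical stand-in for the real config object (arch_type 2, no COIL, SemSup on).
    config = SimpleNamespace(
        coil=False, colbert=False, arch_type=2,
        encoder_model_type='bert', model_name_or_path='bert-base-uncased',
        negative_sampling='none', semsup=True,
        hidden_size=768, label_hidden_size=768, num_labels=4,
    )
    model = BertForSemanticEmbedding(config)
    model.label_model = label_model  # the label encoder is attached externally

    doc_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    docs = doc_tokenizer(["an example input document"], return_tensors='pt')
    # One description per label, reshaped to (batch, num_labels, seq_len).
    descs = label_tokenizer(
        ["first label description", "second label description",
         "third label description", "fourth label description"],
        return_tensors='pt', padding=True,
    )
    desc_input_ids = descs['input_ids'].unsqueeze(0)
    desc_attention_mask = descs['attention_mask'].unsqueeze(0)

    out = model(
        input_ids=docs['input_ids'], attention_mask=docs['attention_mask'],
        labels=torch.zeros(1, config.num_labels),
        desc_input_ids=desc_input_ids, desc_attention_mask=desc_attention_mask,
        return_dict=True,
    )
    print(out.logits.shape)  # expected: torch.Size([1, 4])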