PROJECT_PATH = 'cleaned_code'

import os
import sys
sys.path.append(PROJECT_PATH)

import json
import pickle
from typing import Optional

import numpy as np
import h5py
import torch
from scipy.special import expit
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser

from src import BertForSemanticEmbedding, getLabelModel
from src import DataTrainingArguments, ModelArguments, CustomTrainingArguments, read_yaml_config
from src import dataset_classification_type
from src import SemSupDataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def compute_tok_score_cart(doc_reps, doc_input_ids, qry_reps, qry_input_ids, qry_attention_mask):
    """COIL-style exact-match token interaction between label (query) tokens and document tokens."""
    qry_input_ids = qry_input_ids.unsqueeze(2).unsqueeze(3)  # Q * LQ * 1 * 1
    doc_input_ids = doc_input_ids.unsqueeze(0).unsqueeze(1)  # 1 * 1 * D * LD
    exact_match = doc_input_ids == qry_input_ids  # Q * LQ * D * LD
    exact_match = exact_match.float()
    scores_no_masking = torch.matmul(
        qry_reps.view(-1, 16),                 # (Q * LQ) * d
        doc_reps.view(-1, 16).transpose(0, 1)  # d * (D * LD)
    )
    scores_no_masking = scores_no_masking.view(
        *qry_reps.shape[:2], *doc_reps.shape[:2])  # Q * LQ * D * LD
    # Keep only scores of exactly matching token ids, take the best document token per query token,
    # then sum over query tokens (skipping the first position), weighted by the query attention mask.
    scores, _ = (scores_no_masking * exact_match).max(dim=3)  # Q * LQ * D
    tok_scores = (scores * qry_attention_mask.reshape(-1, qry_attention_mask.shape[-1]).unsqueeze(2))[:, 1:].sum(1)
    return tok_scores


def coil_fast_eval_forward(
    input_ids: Optional[torch.Tensor] = None,
    doc_reps=None,
    logits: Optional[torch.Tensor] = None,
    desc_input_ids=None,
    desc_attention_mask=None,
    lab_reps=None,
    label_embeddings=None,
):
    """Combine dense label-embedding logits with COIL token-level scores for one set of label descriptions."""
    tok_scores = compute_tok_score_cart(
        doc_reps, input_ids,
        lab_reps, desc_input_ids.reshape(-1, desc_input_ids.shape[-1]), desc_attention_mask
    )
    logits = (logits.unsqueeze(0) @ label_embeddings.T)
    # Realign the token scores with the label-embedding logits before adding them.
    new_tok_scores = torch.zeros(logits.shape, device=logits.device)
    for i in range(tok_scores.shape[1]):
        stride = tok_scores.shape[0] // tok_scores.shape[1]
        new_tok_scores[i] = tok_scores[i * stride: i * stride + stride, i]
    return (logits + new_tok_scores).squeeze()


class DemoModel:
    def __init__(self):
        self.label_list = [x.strip() for x in open(f'{PROJECT_PATH}/datasets/Amzn13K/all_labels.txt')]
        unseen_label_list = [x.strip() for x in open(f'{PROJECT_PATH}/datasets/Amzn13K/unseen_labels_split6500_2.txt')]
        num_labels = len(self.label_list)
        self.label_list.sort()  # For consistency
        l2i = {v: i for i, v in enumerate(self.label_list)}
        unseen_label_indexes = [l2i[x] for x in unseen_label_list]
        self.coil_cluster_map = json.load(open(f'{PROJECT_PATH}/bert_coil_map_dict_lemma255K_isotropic.json'))

        # Load the five precomputed sets of label descriptions and their representations.
        all_lab_reps1, all_label_embeddings1, all_desc_input_ids_orig1, all_desc_input_ids1, all_desc_attention_mask1 = pickle.load(open(f'{PROJECT_PATH}/precomputed/Amzn13K/amzn_base_labels_data1_1.pkl', 'rb'))
        all_lab_reps2, all_label_embeddings2, all_desc_input_ids_orig2, all_desc_input_ids2, all_desc_attention_mask2 = pickle.load(open(f'{PROJECT_PATH}/precomputed/Amzn13K/amzn_base_labels_data1_2.pkl', 'rb'))
        all_lab_reps3, all_label_embeddings3, all_desc_input_ids_orig3, all_desc_input_ids3, all_desc_attention_mask3 = pickle.load(open(f'{PROJECT_PATH}/precomputed/Amzn13K/amzn_base_labels_data1_3.pkl', 'rb'))
        all_lab_reps4, all_label_embeddings4, all_desc_input_ids_orig4, all_desc_input_ids4, all_desc_attention_mask4 = pickle.load(open(f'{PROJECT_PATH}/precomputed/Amzn13K/amzn_base_labels_data1_4.pkl', 'rb'))
        all_lab_reps5, all_label_embeddings5, all_desc_input_ids_orig5, all_desc_input_ids5, \
            all_desc_attention_mask5 = pickle.load(open(f'{PROJECT_PATH}/precomputed/Amzn13K/amzn_base_labels_data1_5.pkl', 'rb'))

        self.all_lab_reps = [all_lab_reps1.to(device), all_lab_reps2.to(device), all_lab_reps3.to(device), all_lab_reps4.to(device), all_lab_reps5.to(device)]
        self.all_label_embeddings = [all_label_embeddings1.to(device), all_label_embeddings2.to(device), all_label_embeddings3.to(device), all_label_embeddings4.to(device), all_label_embeddings5.to(device)]
        self.all_desc_input_ids_orig = [all_desc_input_ids_orig1.to(device), all_desc_input_ids_orig2.to(device), all_desc_input_ids_orig3.to(device), all_desc_input_ids_orig4.to(device), all_desc_input_ids_orig5.to(device)]
        self.all_desc_input_ids = [all_desc_input_ids1.to(device), all_desc_input_ids2.to(device), all_desc_input_ids3.to(device), all_desc_input_ids4.to(device), all_desc_input_ids5.to(device)]
        self.all_desc_attention_mask = [all_desc_attention_mask1.to(device), all_desc_attention_mask2.to(device), all_desc_attention_mask3.to(device), all_desc_attention_mask4.to(device), all_desc_attention_mask5.to(device)]

        # Build the model configuration from the demo YAML config.
        ARGS_FILE = f'{PROJECT_PATH}/configs/ablation_amzn_eda.yml'
        parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
        self.model_args, self.data_args, self.training_args = parser.parse_dict(
            read_yaml_config(ARGS_FILE, output_dir='demo_tmp', extra_args={}))

        config = AutoConfig.from_pretrained(
            self.model_args.config_name if self.model_args.config_name else self.model_args.model_name_or_path,
            finetuning_task=self.data_args.task_name,
            cache_dir=self.model_args.cache_dir,
            revision=self.model_args.model_revision,
            use_auth_token=True if self.model_args.use_auth_token else None,
        )
        config.model_name_or_path = self.model_args.model_name_or_path
        config.problem_type = dataset_classification_type[self.data_args.task_name]
        config.negative_sampling = self.model_args.negative_sampling
        config.semsup = self.model_args.semsup
        config.encoder_model_type = self.model_args.encoder_model_type
        config.arch_type = self.model_args.arch_type
        config.coil = self.model_args.coil
        config.token_dim = self.model_args.token_dim
        config.colbert = self.model_args.colbert

        label_model, label_tokenizer = getLabelModel(self.data_args, self.model_args)
        config.label_hidden_size = label_model.config.hidden_size

        model = BertForSemanticEmbedding(config)
        model.label_model = label_model
        model.label_tokenizer = label_tokenizer
        model.config.label2id = {l: i for i, l in enumerate(self.label_list)}
        model.config.id2label = {id: label for label, id in model.config.label2id.items()}

        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

        model.to(device)
        model.eval()
        torch.set_grad_enabled(False)
        model.load_state_dict(torch.load(f'{PROJECT_PATH}/ckpt/Amzn13K/amzn_main_model.bin', map_location=device))
        self.model = model

        # Decode one description string per label for each of the five description sets;
        # if a label has no description in one set, borrow it from another set that does.
        self.extracted_descs = [self.extract_descriptions(adi) for adi in self.all_desc_input_ids_orig]
        tot_len = len(self.all_desc_input_ids_orig)
        for i in range(len(self.all_desc_input_ids_orig[0])):
            for j in range(tot_len):
                if self.extracted_descs[j][i] == "":
                    for k in range(tot_len):
                        if self.extracted_descs[k][i] != '':
                            self.extracted_descs[j][i] = self.extracted_descs[k][i]
                            break

    def extract_descriptions(self, input_ids):
        """Recover the free-text 'description is ...' part of each tokenized label description."""
        descs = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        new_descs = []
        for desc in descs:
            a = desc.find('description is')
            if a == -1:
                # There is no description to use, so fall back to an empty string.
                new_descs.append("")
                continue
            # Cut the description off at the next field ('label is', 'parents are', 'children are'), if any.
            b = min([desc.find(x, a) if desc.find(x, a) != -1 else 99999999999
                     for x in ['label is', 'parents are', 'children are']])
            if b == 99999999999:
                new_descs.append(desc[a:].strip())
            else:
                new_descs.append(desc[a:b].strip())
        return new_descs

    def classify(self, text, unseen_labels=None):
        self.model.eval()
        with torch.no_grad():
            # Encode the input document and project its token representations.
            item = self.tokenizer(text, padding='max_length', max_length=self.data_args.max_seq_length, truncation=True)
            item = {k: torch.tensor(v, device=device).unsqueeze(0) for k, v in item.items()}
            outputs_doc, logits = self.model.forward_input_encoder(**item)
            doc_reps = self.model.tok_proj(outputs_doc.last_hidden_state)
            # Map document token ids to their COIL cluster ids before exact-match scoring.
            input_ids = torch.tensor([self.coil_cluster_map[str(x.item())] for x in item['input_ids'][0]]).to(device).unsqueeze(0)

            # Score the document against each of the five precomputed description sets.
            all_logits = []
            for adi, ada, alr, ale in zip(self.all_desc_input_ids, self.all_desc_attention_mask, self.all_lab_reps, self.all_label_embeddings):
                all_logits.append(coil_fast_eval_forward(input_ids, doc_reps, logits, adi, ada, alr, ale))

            # Ensemble: average the sigmoid scores across description sets, and remember which set
            # scored highest for each label so its description can be reported alongside the prediction.
            final_logits = sum([expit(x.cpu()) for x in all_logits]) / len(all_logits)
            max_indices = torch.argmax(torch.stack(all_logits), dim=0).cpu().tolist()

            outs = torch.topk(final_logits, k=50)
            preds_dic = dict()
            descs_dic = dict()
            for i, v in zip(outs.indices, outs.values):
                preds_dic[self.label_list[i]] = v.item()
                print(self.extracted_descs[max_indices[i]][i])
                descs_dic[self.label_list[i]] = self.extracted_descs[max_indices[i]][i]
            return preds_dic, descs_dic


if __name__ == '__main__':
    model = DemoModel()
    model.classify('Hello')
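# --- Illustrative usage sketch (added for clarity; not part of the original demo) ---
# Assuming the checkpoint under ckpt/Amzn13K/ and the precomputed label files are in place,
# a more realistic call might look like the hypothetical snippet below; the sample product
# text and the top-5 cutoff are illustrative choices, not values from the original script.
#
#     demo = DemoModel()
#     preds, descs = demo.classify(
#         "Wireless over-ear headphones with active noise cancellation and 30-hour battery life."
#     )
#     for label, score in list(preds.items())[:5]:
#         print(f"{label}\t{score:.3f}\t{descs[label]}")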