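"""Extract frozen BERT sequence embeddings (the position-0 hidden state of
each sequence) for the pretraining or validation split and pickle them to
disk."""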
import argparse
import pickle
from itertools import combinations

import numpy
import torch
import torch.nn as nn
import tqdm
from torch.utils.data import DataLoader

from bert import BERT
from vocab import Vocab
from dataset import TokenizerDataset


def generate_subset(s):
    """Return a dict mapping an index to every subset of ``s``, each as a list."""
    subsets = []
    for r in range(len(s) + 1):
        combinations_result = combinations(s, r)
        if r == 1:
            subsets.extend([item] for sublist in combinations_result for item in sublist)
        else:
            subsets.extend(list(sublist) for sublist in combinations_result)
    # Note: the loop variable is named ``subset`` so it does not shadow the argument ``s``.
    subsets_dict = {i: subset for i, subset in enumerate(subsets)}
    return subsets_dict
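
# Illustrative example (subset order within a given size depends on set
# iteration order):
#   generate_subset({'a', 'b'}) -> {0: [], 1: ['a'], 2: ['b'], 3: ['a', 'b']}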


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-workspace_name', type=str, default=None)
    parser.add_argument("-seq_len", type=int, default=100, help="maximum sequence length")
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so these are plain presence/absence flags instead.
    parser.add_argument('-pretrain', action='store_true')
    parser.add_argument('-masked_pred', action='store_true')
    parser.add_argument('-epoch', type=str, default=None)
    # parser.add_argument('-set_label', type=bool, default=False)
    # parser.add_argument('--label_standard', nargs='+', type=str, help='List of optional tasks')
    options = parser.parse_args()

    folder_path = options.workspace_name + "/" if options.workspace_name else ""
    # if options.set_label:
    #     label_standard = generate_subset({'optional-tasks-1', 'optional-tasks-2'})
    #     pickle.dump(label_standard, open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "wb"))
    # else:
    #     label_standard = pickle.load(open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "rb"))
    # print(f"options.label_standard: {options.label_standard}")

    vocab_path = f"{folder_path}check/pretraining/vocab.txt"
    # vocab_path = f"{folder_path}pretraining/vocab.txt"
    print("Loading Vocab", vocab_path)
    vocab_obj = Vocab(vocab_path)
    vocab_obj.load_vocab()
    print("Vocab Size:", len(vocab_obj.vocab))

    # label_standard = list(pickle.load(open(f"dataset/CL4999_1920/{options.workspace_name}/unique_problems_list.pkl", "rb")))
    # label_standard = generate_subset({'optional-tasks-1', 'optional-tasks-2', 'OptionalTask_1', 'OptionalTask_2'})
    # pickle.dump(label_standard, open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "wb"))

    if options.masked_pred:
        str_code = "masked_prediction"
        output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
    else:
        str_code = "masked"
        output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"

    folder_path = folder_path + "check/"
    # folder_path = folder_path

    if options.pretrain:
        pretrain_file = f"{folder_path}pretraining/pretrain.txt"
        pretrain_label = f"{folder_path}pretraining/pretrain_opt.pkl"
        # pretrain_file = f"{folder_path}finetuning/train.txt"
        # pretrain_label = f"{folder_path}finetuning/train_label.txt"
        embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"

        print("Loading Pretrain Dataset", pretrain_file)
        pretrain_dataset = TokenizerDataset(pretrain_file, pretrain_label, vocab_obj, seq_len=options.seq_len)

        print("Creating Dataloader")
        pretrain_data_loader = DataLoader(pretrain_dataset, batch_size=32, num_workers=4)
    else:
        val_file = f"{folder_path}pretraining/test.txt"
        val_label = f"{folder_path}pretraining/test_opt.txt"
        # val_file = f"{folder_path}finetuning/test.txt"
        # val_label = f"{folder_path}finetuning/test_label.txt"
        embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"

        print("Loading Validation Dataset", val_file)
        val_dataset = TokenizerDataset(val_file, val_label, vocab_obj, seq_len=options.seq_len)

        print("Creating Dataloader")
        val_data_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    print("Load Pre-trained BERT model...")
    print(output_name)
    bert = torch.load(output_name, map_location=device)
    # Freeze the encoder: it is used here purely for feature extraction.
    for param in bert.parameters():
        param.requires_grad = False
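    # Not in the original script: assuming the checkpoint is a full nn.Module,
    # eval mode disables dropout so the extracted embeddings are deterministic.
    bert.eval()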

    if options.pretrain:
        print("Pretrain-embeddings....")
        data_iter = tqdm.tqdm(enumerate(pretrain_data_loader),
                              desc="pre-train",
                              total=len(pretrain_data_loader),
                              bar_format="{l_bar}{r_bar}")
        pretrain_embeddings = []
        for i, data in data_iter:
            data = {key: value.to(device) for key, value in data.items()}
            hrep = bert(data["bert_input"], data["segment_label"])
            # Keep the hidden state at position 0 as the sequence embedding.
            embeddings = [h for h in hrep[:, 0].cpu().detach().numpy()]
            pretrain_embeddings.extend(embeddings)
        with open(embedding_file_path, "wb") as f:
            pickle.dump(pretrain_embeddings, f)
        # pickle.dump(pretrain_embeddings, open("embeddings/finetune_cfa_train_embeddings.pkl", "wb"))
    else:
        print("Validation-embeddings....")
        data_iter = tqdm.tqdm(enumerate(val_data_loader),
                              desc="validation",
                              total=len(val_data_loader),
                              bar_format="{l_bar}{r_bar}")
        val_embeddings = []
        for i, data in data_iter:
            data = {key: value.to(device) for key, value in data.items()}
            hrep = bert(data["bert_input"], data["segment_label"])
            # Keep the hidden state at position 0 as the sequence embedding.
            embeddings = [h for h in hrep[:, 0].cpu().detach().numpy()]
            val_embeddings.extend(embeddings)
        with open(embedding_file_path, "wb") as f:
            pickle.dump(val_embeddings, f)
        # pickle.dump(val_embeddings, open("embeddings/finetune_cfa_test_embeddings.pkl", "wb"))
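
    # Illustrative usage (not part of the original script): the pickled
    # embeddings can later be reloaded with
    #   with open(embedding_file_path, "rb") as f:
    #       embeddings = pickle.load(f)
    # giving one vector (the position-0 hidden state) per input sequence.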