"""Shared project utilities: enum types, file I/O helpers, chunking, HuggingFace/torch helpers, and colored logging."""
#############################
# Imports and Constants #
#############################
# Python modules
from enum import Enum
import functools
import json
import os
import time
# Remote packages
import torch
#############################
# utilities
#############################
class ScoringType(Enum):
    """Identifiers for answer-scoring strategies (values are the string
    forms used in config/CLI selection)."""
    DEFAULT = 'default'
    MAX_PROB = 'max-prob'
    INTERPOL = 'interpol'
    CONSTRAINT = 'constraint'
    MULTIPLE_CHOICE = 'multiple_choice'
class LossType(Enum):
    """Identifiers for training-loss variants (values are the string forms
    used in config/CLI selection)."""
    DEFAULT = 'default'
    CP_RP_DEF = 'cp-rp-def'
    CP_DEF = 'cp-def'
    PRP_NRP_DEF = 'prp-nrp-def'
class Head_Mask(Enum):
    """Which attention heads to mask: all, none, a random subset, or a
    specific (caller-chosen) set."""
    ALL = 'all'
    NONE = 'none'
    RANDOM = 'random'
    SPECIFIC = 'specific'
class KGType(Enum):
    """Supported knowledge-graph sources."""
    SWOW = 'swow'
    CSKG = 'cskg'
    CONCEPTNET = 'conceptnet'
class Model_Type(Enum):
    """Model architecture variants handled by this project."""
    RELATIONS = 'relations'
    MASK = 'mask'
    DEFAULT = 'default'

    def is_simple_mask_commonsense(self):
        """True only for the plain MASK variant."""
        return self is Model_Type.MASK

    def there_is_difference_between_relations(self):
        """True only when individual relations are distinguished (RELATIONS)."""
        return self is Model_Type.RELATIONS
class Data_Type(Enum):
    """Datasets known to the pipeline (values are their canonical ids)."""
    ELI5 = 'eli5'
    COMMONSENSE_QA = 'commonsense_qa'
    COMMONGEN_QA = 'commongen_qa'
    STACK_EXCHANGE = 'stackexchange_qa'
    ASK_SCIENCE = 'ask_science_qa'
    NATURAL_QUESTIONS = 'natural_questions'
    LAMA = 'lama'
    CONCEPTNET = 'conceptnet'
    CUSTOM = 'custom'
    COMMONGEN = 'commongen'

    @staticmethod
    def data_types_to_str(data_types):
        """Join the ids of *data_types* with '-' into a single tag string."""
        return '-'.join(dt.value for dt in data_types)
#############################
# Models
#############################
# Map of short model aliases to HuggingFace Hub checkpoint names.
MODELS_PRETRAINING_NAME = {
    "bart_large": "facebook/bart-large",
    "bart_large_fp32": "patrickvonplaten/bart-large-fp32",
    "bart_large_tweak": "",  # NOTE(review): empty checkpoint name — presumably resolved elsewhere; verify
    "bart_base": "facebook/bart-base"
}
# Checkpoint alias selected for the current runs.
CURRENT_PRETRAINING_NAME = MODELS_PRETRAINING_NAME.get('bart_large_fp32')
#############################
# Files Management #
#############################
def create_directory(output_dir):
    """Create *output_dir* (including parents) if it does not exist.

    Prints a notice when the directory is already present, matching the
    original behaviour. Never raises on a creation race.
    """
    if os.path.exists(output_dir):
        print(f"Output directory {output_dir} already exists")
        return
    # exist_ok=True closes the TOCTOU window between the existence check
    # above and the actual creation (the original needed a try/except
    # FileExistsError for the same race).
    os.makedirs(output_dir, exist_ok=True)
def read_simple_text_file_2_vec(filename, store_dir='.'):
    """Read ``<store_dir>/<filename>`` and return its content split on newlines.

    BUG FIX: the original path f-string did not interpolate the *filename*
    argument, so the parameter was silently ignored.
    """
    with open(f'{store_dir}/{filename}', 'r') as f:
        return f.read().split('\n')
def write_dict_2_json_file(json_object, filename, store_dir='.'):
    """Serialize *json_object* as pretty-printed UTF-8 JSON to
    ``<store_dir>/<filename>``, creating *store_dir* if needed.

    BUG FIX: the original path f-string did not interpolate the *filename*
    argument, so the parameter was silently ignored.
    """
    create_directory(store_dir)
    with open(f'{store_dir}/{filename}', 'w', encoding='utf-8') as file:
        # ensure_ascii=False keeps non-ASCII characters readable on disk.
        json.dump(json_object, file, ensure_ascii=False, indent=4)
def read_json_file_2_dict(filename, store_dir='.'):
    """Load and return the JSON object stored at ``<store_dir>/<filename>``.

    BUG FIX: the original path f-string did not interpolate the *filename*
    argument, so the parameter was silently ignored.
    """
    with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
        return json.load(file)
def read_jsonl_file_2_dict(filename, store_dir='.'):
    """Read a JSON-Lines file at ``<store_dir>/<filename>`` and return a
    list with one parsed object per line.

    BUG FIX: the original path f-string did not interpolate the *filename*
    argument, so the parameter was silently ignored.
    """
    with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file]
def read_txt_2_list(filename, store_dir='.'):
    """Read ``<store_dir>/<filename>`` (UTF-8) and return its lines as a list.

    BUG FIX: the original path f-string did not interpolate the *filename*
    argument, so the parameter was silently ignored.
    """
    with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
        return file.read().split('\n')
#############################
# Data Structures helper functions
#############################
def get_chunks(lst, n):
    """Yield *lst* split into contiguous chunks of ``len(lst) // n`` items.

    Note: when ``len(lst)`` is not divisible by *n* this yields MORE than
    *n* chunks (the remainder forms extra chunk(s)) — kept for
    backward compatibility.

    BUG FIX: when ``n > len(lst)`` the original computed a step of 0 and
    ``range`` raised ``ValueError``; the step is now clamped to at least 1,
    so short/empty inputs degrade gracefully to one-item chunks.
    """
    jump = max(1, len(lst) // n)
    for start in range(0, len(lst), jump):
        yield lst[start:start + jump]
def get_jump_chunks(lst, jump):
    """Yield successive slices of *lst*, each at most *jump* items long."""
    for offset in range(0, len(lst), jump):
        chunk = lst[offset:offset + jump]
        yield chunk
def join_str_first(sep_str, lis):
    """Join *lis* with *sep_str*, prefix it with *sep_str*, and strip
    surrounding whitespace from the result."""
    prefixed = sep_str + sep_str.join(lis)
    return prefixed.strip()
#############################
# Huggingface
#############################
def inputs_introspection_print(tokenizer, inputs):
    """Decode a tokenized batch and print the first two inputs and labels
    (special tokens kept) for quick visual inspection."""
    decoded_inputs = tokenizer.batch_decode(
        inputs.get('input_ids', None), skip_special_tokens=False)
    decoded_labels = tokenizer.batch_decode(
        inputs.get('labels', None), skip_special_tokens=False)
    print('orginal input:', decoded_inputs[:2])
    print("::::::::::::::::::::::::::")
    print('orginal labels:', decoded_labels[:2])
    print("==========|||||==========")
def tok_data_2_text(tokenizer, all_inputs):
    """Decode tokenized examples back to text.

    Returns a pair ``(input_texts, label_texts)``: inputs are decoded with
    special tokens kept and then manually trimmed (everything after the
    first EOS dropped, BOS removed); labels are decoded with special
    tokens skipped.
    """
    def _strip_specials(text):
        # Keep only what precedes the first EOS, then drop the BOS marker.
        before_eos = text.split(tokenizer.eos_token)[0]
        return before_eos.replace(tokenizer.bos_token, '').strip()

    all_input_text = []
    all_labels_text = []
    for example in all_inputs:
        decoded_input = tokenizer.decode(
            example.get('input_ids', None), skip_special_tokens=False)
        decoded_labels = tokenizer.decode(
            example.get('labels', None), skip_special_tokens=True)
        all_input_text.append(_strip_specials(decoded_input))
        all_labels_text.append(decoded_labels)
    return all_input_text, all_labels_text
#############################
# Torch
#############################
def get_device(verbose: bool = True):
    """Return the torch device to use: CUDA when available, else CPU.

    When *verbose*, prints the GPU count and the name of GPU 0 (or a
    CPU-fallback notice).
    """
    if not torch.cuda.is_available():
        if verbose:
            print('No GPU available, using the CPU instead.')
        return torch.device("cpu")
    gpu_count = torch.cuda.device_count()
    gpu_name = torch.cuda.get_device_name(0)
    if verbose:
        print(f'There are {gpu_count} GPU(s) available.')
        print(f'GPU gonna be used: {gpu_name}')
    return torch.device("cuda")
#############################
# Timing
#############################
def timing_decorator(func):
    """Decorator that prints the wall-clock time spent in each call to *func*.

    FIX: applies ``functools.wraps`` so the wrapper preserves the wrapped
    function's ``__name__``/``__doc__`` (the original clobbered them).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        original_return_val = func(*args, **kwargs)
        end = time.time()
        print("time elapsed in ", func.__name__, ": ", end - start, sep='')
        return original_return_val
    return wrapper
#############################
# PRINTING UTILS
#############################
class LOGGER_COLORS:
    """ANSI escape codes used to colorize log-message prefixes."""
    HEADER = '\033[95m'     # magenta
    OKBLUE = '\033[94m'     # blue
    INFOCYAN = '\033[96m'   # cyan
    OKGREEN = '\033[92m'    # green
    WARNING = '\033[93m'    # yellow
    FAIL = '\033[91m'       # red
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
def print_info(logger, message):
    """Log *message* at info level with a cyan ``[INFO]`` prefix."""
    prefix = f'{LOGGER_COLORS.INFOCYAN}[INFO]{LOGGER_COLORS.ENDC}'
    logger.info(f'{prefix}: {message}')
def print_success(logger, message):
    """Log *message* at info level with a green ``[SUCCESS]`` prefix."""
    prefix = f'{LOGGER_COLORS.OKGREEN}[SUCCESS]{LOGGER_COLORS.ENDC}'
    logger.info(f'{prefix}: {message}')
def print_warning(logger, message):
    """Log *message* with a yellow ``[WARNING]`` prefix.

    NOTE(review): emitted via ``logger.info`` (not ``logger.warning``) —
    kept as-is to preserve behaviour; confirm this is intentional.
    """
    prefix = f'{LOGGER_COLORS.WARNING}[WARNING]{LOGGER_COLORS.ENDC}'
    logger.info(f'{prefix}: {message}')
def print_fail(logger, message):
    """Log *message* with a red ``[FAIL]`` prefix.

    NOTE(review): emitted via ``logger.info`` (not ``logger.error``) —
    kept as-is to preserve behaviour; confirm this is intentional.
    """
    prefix = f'{LOGGER_COLORS.FAIL}[FAIL]{LOGGER_COLORS.ENDC}'
    logger.info(f'{prefix}: {message}')