RA-BART / utils.py
MrVicente's picture
added demo base code
6cf191b
raw
history blame contribute delete
No virus
6.87 kB
#############################
# Imports and Contants #
#############################
# Python modules
from enum import Enum
import os
import json
import time
# Remote packages
import torch
#############################
# utilities
#############################
class ScoringType(Enum):
DEFAULT = 'default'
MAX_PROB = 'max-prob'
INTERPOL = 'interpol'
CONSTRAINT = 'constraint'
MULTIPLE_CHOICE = 'multiple_choice'
class LossType(Enum):
DEFAULT = 'default'
CP_RP_DEF = 'cp-rp-def'
CP_DEF = 'cp-def'
PRP_NRP_DEF = 'prp-nrp-def'
class Head_Mask(Enum):
ALL = 'all'
NONE = 'none'
RANDOM = 'random'
SPECIFIC = 'specific'
class KGType(Enum):
SWOW = 'swow'
CSKG = 'cskg'
CONCEPTNET = 'conceptnet'
class Model_Type(Enum):
RELATIONS = 'relations'
MASK = 'mask'
DEFAULT = 'default'
def is_simple_mask_commonsense(self):
return self == Model_Type.MASK
def there_is_difference_between_relations(self):
return self == Model_Type.RELATIONS
class Data_Type(Enum):
ELI5 = 'eli5'
COMMONSENSE_QA = 'commonsense_qa'
COMMONGEN_QA = 'commongen_qa'
STACK_EXCHANGE = 'stackexchange_qa'
ASK_SCIENCE = 'ask_science_qa'
NATURAL_QUESTIONS = 'natural_questions'
LAMA = 'lama'
CONCEPTNET = 'conceptnet'
CUSTOM = 'custom'
COMMONGEN = 'commongen'
@staticmethod
def data_types_to_str(data_types):
datasets_str = '-'.join([x.value for x in data_types])
return datasets_str
#############################
# Models
#############################
MODELS_PRETRAINING_NAME = {
"bart_large": "facebook/bart-large",
"bart_large_fp32": "patrickvonplaten/bart-large-fp32",
"bart_large_tweak": "",
"bart_base": "facebook/bart-base"
}
CURRENT_PRETRAINING_NAME = MODELS_PRETRAINING_NAME.get('bart_large_fp32')
#############################
# Files Managment #
#############################
def create_directory(output_dir):
# Create output directory if needed
if not os.path.exists(output_dir):
try:
os.makedirs(output_dir)
except FileExistsError as _:
return
else:
print(f"Output directory {output_dir} already exists")
def read_simple_text_file_2_vec(filename, store_dir='.'):
with open(f'{store_dir}/{filename}', 'r') as f:
return f.read().split('\n')
def write_dict_2_json_file(json_object, filename, store_dir='.'):
create_directory(store_dir)
with open(f'{store_dir}/{filename}', 'w', encoding='utf-8') as file:
json.dump(json_object, file, ensure_ascii=False, indent=4)
def read_json_file_2_dict(filename, store_dir='.'):
with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
return json.load(file)
def read_jsonl_file_2_dict(filename, store_dir='.'):
elements = []
with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
for line in file:
elements.append(json.loads(line))
return elements
def read_txt_2_list(filename, store_dir='.'):
with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
return file.read().split('\n')
#############################
# Data Structures helper functions
#############################
def get_chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
jump = len(lst)//n
for i in range(0, len(lst), jump):
yield lst[i:i + jump]
def get_jump_chunks(lst, jump):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), jump):
yield lst[i:i + jump]
def join_str_first(sep_str, lis):
return '{1}{0}'.format(sep_str.join(lis), sep_str).strip()
#############################
# Huggingface
#############################
def inputs_introspection_print(tokenizer, inputs):
input_ids = inputs.get('input_ids', None)
input_text = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
labels_ids = inputs.get('labels', None)
labels_text = tokenizer.batch_decode(labels_ids, skip_special_tokens=False)
print('orginal input:', input_text[:2])
print("::::::::::::::::::::::::::")
print('orginal labels:', labels_text[:2])
print("==========|||||==========")
def tok_data_2_text(tokenizer, all_inputs):
def clean_input_text(text):
real_text = text.split(tokenizer.eos_token)[0]
real_text = real_text.replace(tokenizer.bos_token, '').strip()
return real_text
all_input_text, all_labels_text = [], []
for inputs in all_inputs:
input_ids = inputs.get('input_ids', None)
input_text = tokenizer.decode(input_ids, skip_special_tokens=False)
labels_ids = inputs.get('labels', None)
labels_text = tokenizer.decode(labels_ids, skip_special_tokens=True)
#print('input_text:', input_text)
#print('labels_text:', labels_text)
input_text = clean_input_text(input_text)
all_input_text.append(input_text)
all_labels_text.append(labels_text)
return all_input_text, all_labels_text
#############################
# Torch
#############################
def get_device(verbose:bool=True):
# If there's a GPU available...
if torch.cuda.is_available():
device = torch.device("cuda")
n_gpus = torch.cuda.device_count()
first_gpu = torch.cuda.get_device_name(0)
if verbose:
print(f'There are {n_gpus} GPU(s) available.')
print(f'GPU gonna be used: {first_gpu}')
else:
if verbose:
print('No GPU available, using the CPU instead.')
device = torch.device("cpu")
return device
#############################
# Timing
#############################
def timing_decorator(func):
def wrapper(*args, **kwargs):
start = time.time()
original_return_val = func(*args, **kwargs)
end = time.time()
print("time elapsed in ", func.__name__, ": ", end - start, sep='')
return original_return_val
return wrapper
#############################
# PRINTING UTILS
#############################
class LOGGER_COLORS:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
INFOCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
def print_info(logger, message):
logger.info(f'{LOGGER_COLORS.INFOCYAN}[INFO]{LOGGER_COLORS.ENDC}: {message}')
def print_success(logger, message):
logger.info(f'{LOGGER_COLORS.OKGREEN}[SUCCESS]{LOGGER_COLORS.ENDC}: {message}')
def print_warning(logger, message):
logger.info(f'{LOGGER_COLORS.WARNING}[WARNING]{LOGGER_COLORS.ENDC}: {message}')
def print_fail(logger, message):
logger.info(f'{LOGGER_COLORS.FAIL}[FAIL]{LOGGER_COLORS.ENDC}: {message}')