|
import os |
|
import copy |
|
|
|
import torch |
|
import numpy as np |
|
import random as rn |
|
import pandas as pd |
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
from typing import List, Optional |
|
|
|
|
|
class Namespace:
    """Minimal attribute container: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
|
|
|
|
|
|
|
def load_tok(model_name="gpt2-xl"):
    """Load a tokenizer for one of the supported model names.

    Args:
        model_name (str): one of "gpt2-xl", "gpt-j-6b", "llama-3-8b",
            "mamba-1.4b".

    Returns:
        A HuggingFace tokenizer. For the GPT/Llama models ``pad_token``
        is aliased to ``eos_token``; the Mamba tokenizer keeps its own.

    Raises:
        AssertionError: if ``model_name`` is not supported.
    """
    from transformers import AutoTokenizer

    # Short name -> (HuggingFace hub path, alias pad_token to eos_token?).
    # Table-driven to remove the four near-identical branches.
    HUB_PATHS = {
        "gpt2-xl": ("gpt2-xl", True),
        "gpt-j-6b": ("EleutherAI/gpt-j-6b", True),
        "llama-3-8b": ("meta-llama/Meta-Llama-3-8B", True),
        "mamba-1.4b": ("state-spaces/mamba-1.4b-hf", False),
    }

    if model_name not in HUB_PATHS:
        raise AssertionError("model_name not supported:", model_name)

    hub_path, set_pad = HUB_PATHS[model_name]
    tok = AutoTokenizer.from_pretrained(hub_path)
    if set_pad:
        tok.pad_token = tok.eos_token

    return tok
|
|
|
|
|
def load_model_tok(model_name="gpt2-xl"):
    """Load a causal-LM model and its tokenizer from the transformers package.

    Args:
        model_name (str): one of "gpt2-xl", "gpt-j-6b", "llama-3-8b",
            "mamba-1.4b".

    Returns:
        (model, tok): the loaded model (placed on GPU) and its tokenizer.

    Raises:
        AssertionError: if ``model_name`` is not supported.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if model_name == "gpt-j-6b":
        # `model` first holds the hub path, then is rebound to the model object.
        model = "EleutherAI/gpt-j-6b"
        tok = AutoTokenizer.from_pretrained(model)
        # NOTE(review): device_map="auto" already lets accelerate place the
        # (possibly sharded) weights; the trailing .cuda() may conflict with
        # that placement — confirm it is intentional.
        model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch.float16,
            device_map="auto"
        ).cuda()
        tok.pad_token = tok.eos_token  # GPT-J has no pad token by default

    elif model_name == "gpt2-xl":
        # Small enough to load in full precision directly onto the GPU.
        model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
        tok = AutoTokenizer.from_pretrained(model_name)
        tok.pad_token = tok.eos_token

    elif model_name == 'llama-3-8b':
        model = "meta-llama/Meta-Llama-3-8B"
        tok = AutoTokenizer.from_pretrained(model)
        # NOTE(review): same device_map="auto" + .cuda() concern as above.
        model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch.float16,
            device_map="auto",
        ).cuda()
        tok.pad_token = tok.eos_token

    elif model_name == 'mamba-1.4b':
        # Mamba needs its dedicated model class, imported lazily so the
        # other branches don't require it.
        from transformers import MambaForCausalLM

        model = 'state-spaces/mamba-1.4b-hf'
        tok = AutoTokenizer.from_pretrained(model)
        model = MambaForCausalLM.from_pretrained(model).cuda()
        # pad_token is intentionally left untouched for Mamba.

    else:
        raise AssertionError("model_name not supported:", model_name)

    return model, tok
|
|
|
|
|
|
|
def load_activation(activation_name):
    """Load an activation function from the transformers package.

    Args:
        activation_name (str): case-insensitive name; one of "gelu"
            (NewGELU), "gelu_org" (original GELU), "silu", "relu".

    Returns:
        The corresponding activation callable/module.

    Raises:
        AssertionError: if the name is not supported.
    """
    from transformers import activations

    name = activation_name.lower()  # hoisted: was recomputed per branch

    if name == "gelu":
        activation = activations.NewGELUActivation()
    elif name == "gelu_org":
        activation = activations.GELUActivation()
    elif name == "silu":
        activation = activations.silu
    elif name == "relu":
        activation = activations.ACT2CLS['relu']()
    else:
        raise AssertionError("Activation not supported:", activation_name)
    return activation
|
|
|
|
|
def load_dataset(
    tok = None,
    ds_name = "mcf",
    DATA_DIR = "data",
    selection = None,
    dataset_size_limit = None,
    reverse_selection = False,
    reverse_target = False,
    whole_prompt = True
):
    """Load an editing dataset (MEMIT/ROME style) with optional filtering.

    Args:
        tok: tokenizer passed through to the dataset class.
        ds_name (str): dataset key — "mcf", "cf" or "zsre".
        DATA_DIR (str): root data directory.
        selection: iterable of case_ids to keep, or a path to a JSON file
            holding {'case_ids': [...]}.
        dataset_size_limit (int, optional): cap on the number of samples.
        reverse_selection (bool): if True, drop (instead of keep) the
            selected case_ids.
        reverse_target (bool): if True, swap target_new and target_true.
        whole_prompt (bool): if True, fold the subject into the prompt so
            the fully formatted sentence becomes the "subject" of a '{}'
            prompt.

    Returns:
        (ds, ds_class, ds_eval_method)

    Raises:
        KeyError: if ds_name is not a supported key.
    """
    from dsets import (
        CounterFactDataset,
        MENDQADataset,
        MultiCounterFactDataset,
    )
    from evaluation.py.eval_utils_counterfact import compute_rewrite_quality_counterfact
    from evaluation.py.eval_utils_zsre import compute_rewrite_quality_zsre

    DS_DICT = {
        "mcf": (MultiCounterFactDataset, compute_rewrite_quality_counterfact),
        "cf": (CounterFactDataset, compute_rewrite_quality_counterfact),
        "zsre": (MENDQADataset, compute_rewrite_quality_zsre),
    }

    ds_class, ds_eval_method = DS_DICT[ds_name]
    ds = ds_class(DATA_DIR, tok=tok, size=dataset_size_limit)

    # Some dataset classes keep samples in a private `_data`; expose them
    # uniformly as `ds.data`. (Narrowed from a bare `except:` that would
    # also have swallowed KeyboardInterrupt/SystemExit.)
    try:
        ds.data
    except AttributeError:
        ds.data = ds._data

    if selection:
        # A string is treated as a path to a JSON file with 'case_ids'.
        if isinstance(selection, str):
            selection = loadjson(selection)['case_ids']
        if not reverse_selection:
            ds.data = [d for d in ds.data if (d['case_id'] in selection)]
        else:
            ds.data = [d for d in ds.data if (d['case_id'] not in selection)]
        print('After selection:', len(ds.data), 'elements')

    if reverse_target:
        # Swap targets on a deep copy so nothing shared with the raw
        # dataset records is mutated.
        for i in range(len(ds.data)):
            request = copy.deepcopy(ds.data[i]['requested_rewrite'])

            tmp_true = copy.deepcopy(request['target_true'])
            tmp_new = copy.deepcopy(request['target_new'])

            request['target_new'] = tmp_true
            request['target_true'] = tmp_new

            ds.data[i]['requested_rewrite'] = request

        print('Target new and true reversed')

    if whole_prompt:
        # Replace {prompt, subject} with a trivial '{}' prompt whose
        # "subject" is the fully formatted sentence.
        for i in range(len(ds.data)):
            org_request = copy.deepcopy(ds.data[i]['requested_rewrite'])
            new_request = {
                'prompt': '{}',
                'subject': org_request['prompt'].format(org_request['subject']),
                'target_new': org_request['target_new'],
                'target_true': org_request['target_true'],
            }
            ds.data[i]['requested_rewrite'] = new_request

        print('Whole prompts for dataset samples')

    return ds, ds_class, ds_eval_method
|
|
|
|
|
def assure_path_exists(path, create=True, out=True):
    """Check that the directory of `path` exists; optionally create it.

    Note: the check applies to ``os.path.dirname(path)``, so a directory
    argument should end with a separator — otherwise its last component is
    treated as a file name and dropped.

    Args:
        path (str): folder path or file path.
        create (bool, optional): create the directory if it does not
            exist. Defaults to True.
        out (bool, optional): print a status message. Defaults to True.
    """
    # Renamed from `dir`, which shadowed the builtin.
    dir_path = os.path.dirname(path)

    if not dir_path.endswith(('/', '\\')):
        dir_path = dir_path + '/'

    if not os.path.exists(dir_path):
        if create:
            os.makedirs(dir_path)
            if out: print("PATH CREATED:", path)
        else:
            if out: print("PATH DOES NOT EXIST:", path)
    else:
        if out: print("PATH EXISTS:", path)
|
|
|
def path_all_files(path):
    """Recursively collect the paths of all files under `path` (depth-first)."""
    collected = []
    for entry in os.listdir(path):
        full = os.path.join(path, entry)
        if os.path.isdir(full):
            collected.extend(path_all_files(full))
        else:
            collected.append(full)
    return collected
|
|
|
|
|
|
|
def savepickle(file_name, data):
    """Serialize `data` to `file_name` as a pickle file (highest protocol)."""
    import pickle
    with open(file_name, 'wb') as out_file:
        pickle.dump(data, out_file, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
def loadpickle(file_name):
    """Deserialize and return the contents of a pickle file."""
    import pickle
    with open(file_name, 'rb') as in_file:
        return pickle.load(in_file)
|
|
|
def loadjson(file_name):
    """Read a JSON file and return its parsed contents."""
    import json
    with open(file_name, 'r') as in_file:
        return json.load(in_file)
|
|
|
|
|
def savejson(file_name, data):
    """Write `data` to `file_name` as JSON."""
    import json
    with open(file_name, 'w') as out_file:
        json.dump(data, out_file)
|
|
|
|
|
def load_from_cache(file_path, verbose=False, allow_fail=True):
    """Load a cached pickle file.

    Args:
        file_path (str): path to the cached pickle file.
        verbose (bool): print a progress message before loading.
        allow_fail (bool): if True, raise on a missing or unreadable file;
            if False, return None instead.
            NOTE(review): the name suggests the opposite polarity (True
            tolerating failure). Kept as-is because callers depend on the
            current behavior — worth confirming/renaming.

    Returns:
        The unpickled contents, or None when the file is missing or
        unreadable and allow_fail is False.

    Raises:
        AssertionError: on a missing or unreadable file when allow_fail
            is True.
    """
    if os.path.isfile(file_path):
        try:
            if verbose: print('Loading fcloud from cache...')
            cache_contents = loadpickle(file_path)
            return cache_contents
        # Narrowed from a bare `except:`; chain the cause for debugging.
        except Exception as err:
            if allow_fail: raise AssertionError('Load cache fail:', file_path) from err
    else:
        if allow_fail: raise AssertionError('File not found:', file_path)
    return None
|
|
|
|
|
def comp(item1, item2, out=False, cfn=False, to_list=False):
    """Compare two sequences as sets.

    Returns (only-in-1, only-in-2, in-both), unless `cfn` is set, in which
    case the function only asserts that the intersection is empty (and
    returns None). `out` prints the three counts; `to_list` converts the
    three sets to lists before returning.
    """
    set1, set2 = set(item1), set(item2)
    shared = set1.intersection(set2)
    first_only = set1 - set2
    second_only = set2 - set1

    if out:
        print('No. of items only in variable 1: ', len(first_only))
        print('No. of items only in variable 2: ', len(second_only))
        print('No. of items both variable 1 & 2:', len(shared))

    if to_list:
        first_only = list(first_only)
        second_only = list(second_only)
        shared = list(shared)

    if cfn:
        assert len(shared)==0
    else:
        return first_only, second_only, shared
|
|
|
|
|
def convert_to_subjects_prompts(requests):
    """Split request dicts into parallel 'subjects' and 'prompts' lists."""
    result = {'subjects': [], 'prompts': []}
    for request in requests:
        result['subjects'].append(request['subject'])
        result['prompts'].append(request['prompt'])
    return result
|
|
|
|
|
def smart_matmul(a, b, device='cuda'):
    """Type-independent matrix multiplication.

    Accepts numpy arrays or torch tensors; converts both operands to
    float16 torch tensors, multiplies them on `device` (best-effort), and
    returns a Python scalar for a 0-d result, otherwise a numpy array.
    """
    # Downcast numpy float64/float32 inputs to float16. (For torch-tensor
    # inputs these comparisons are simply False.)
    if a.dtype in [np.float64, np.float32]:
        a = np.array(a, dtype=np.float16)
    if b.dtype in [np.float64, np.float32]:
        b = np.array(b, dtype=np.float16)
    # numpy float16 -> torch tensor.
    if a.dtype == np.float16:
        a = torch.from_numpy(a)
    if b.dtype == np.float16:
        b = torch.from_numpy(b)
    # Torch float32 tensors are halved too, so everything ends up float16.
    if a.dtype == torch.float32:
        a = a.half()
    if b.dtype == torch.float32:
        b = b.half()

    # Best-effort move to the requested device; on failure (e.g. no CUDA)
    # the operands silently stay where they are.
    try:
        a = a.to(device)
        b = b.to(device)
    except:
        pass

    r = torch.matmul(a, b)

    # .item() succeeds only for a 0-d result; otherwise return a numpy array.
    try:
        r = r.cpu().item()
    except:
        r = r.cpu().numpy()
    return r
|
|
|
|
|
|
|
def shuffle(*arrays, **kwargs):
    """Thin pass-through wrapper around sklearn.utils.shuffle."""
    from sklearn.utils import shuffle as sk_shuffle
    return sk_shuffle(*arrays, **kwargs)
|
|
|
def shuffle_list(l):
    """Shuffle `l` in place and return it; non-list inputs are converted first."""
    items = l if type(l) == list else list(l)
    rn.shuffle(items)
    return items
|
|
|
|
|
def generate_mask(list1, list2):
    """Boolean mask over `list1` marking positions whose value occurs in `list2`.

    Args:
        list1: array-like of values to mask.
        list2: array-like of values to look for.

    Returns:
        np.ndarray of bool, same length as `list1`.
    """
    # Coerce to ndarray so elementwise `==` works for plain lists too.
    # (Previously `list1 == value` on a Python list evaluated to a scalar
    # False, so np.where found nothing and the mask stayed empty.)
    list1 = np.asarray(list1)
    mask = np.zeros(len(list1))
    for value in list2:
        mask[np.where(list1 == value)[0]] = 1
    return np.array(mask, dtype=bool)
|
|
|
def generate_loc(list1, list2, inverse=False, verbose=0):
    """Locations of each item of `list2` within `list1`.

    Uses the first occurrence when a value appears multiple times (and
    prints a warning). With `inverse=True`, returns instead the indices of
    `list1` that were not matched.
    """
    arr1 = np.array(list1)
    arr2 = np.array(list2)

    locs = []
    for value in arr2:
        hits = np.where(arr1 == value)[0]
        if len(hits) > 1:
            print('Found multiples of', value)
        locs.append(hits[0])

    if inverse:
        # Complement: every index of arr1 that is not in locs.
        unmatched = set(np.arange(len(arr1))) - set(locs)
        return np.array(list(unmatched), dtype=int)

    return np.array(locs, dtype=int)
|
|
|
|
|
def filter_for_selection(dictionary, boolean_mask):
    """Filter every list/ndarray value of `dictionary` by a boolean mask, in place.

    List values are converted to numpy arrays in the process; values of any
    other type are left untouched.

    Args:
        dictionary (dict): mapping whose list/ndarray values are filtered.
        boolean_mask: boolean index array applied to each filtered value.

    Returns:
        The same (mutated) dictionary.
    """
    for key in dictionary:
        value = dictionary[key]
        # isinstance instead of `type(...) ==`: also accepts subclasses.
        if isinstance(value, np.ndarray):
            dictionary[key] = value[boolean_mask]
        elif isinstance(value, list):
            dictionary[key] = np.array(value)[boolean_mask]
    return dictionary
|
|
|
|
|
def smart_mean_std(data, axis=None):
    """Mean and standard deviation of `data`, ignoring NaN and Inf values."""
    arr = np.array(data)
    # Replace non-finite entries with NaN so the nan-aware reductions skip them.
    cleaned = np.where(np.isfinite(arr), arr, np.nan)
    return np.nanmean(cleaned, axis=axis), np.nanstd(cleaned, axis=axis)
|
|
|
|
|
def smart_mean(data, axis=None):
    """Mean of `data`, ignoring NaN and Inf values."""
    arr = np.array(data)
    # Non-finite entries become NaN so nanmean skips them.
    cleaned = np.where(np.isfinite(arr), arr, np.nan)
    return np.nanmean(cleaned, axis=axis)
|
|
|
def smart_std(data, axis=None):
    """Calculate the standard deviation of data, ignoring NaN and Inf values.

    Args:
        data: array-like input.
        axis (int, optional): axis along which to reduce; None reduces all.

    Returns:
        The nan/inf-ignoring standard deviation.
    """
    data = np.array(data)

    # Map non-finite entries (+/-inf) to NaN so nanstd skips them along
    # with any existing NaNs.
    mask = np.isfinite(data)
    filtered_data = np.where(mask, data, np.nan)

    std_value = np.nanstd(filtered_data, axis=axis)

    return std_value
|
|
|
def extract_requests(ds):
    """Extract essential edit requests (with case_id) from a dataset.

    Args:
        ds: dataset object with a `.data` list of dicts, each holding a
            'requested_rewrite' dict and a 'case_id'.

    Returns:
        np.ndarray (object dtype) of request dicts, each a shallow copy of
        the sample's 'requested_rewrite' plus its 'case_id'.
    """
    requests = []
    for r in ds.data:
        # Copy before adding 'case_id' — previously the key was injected
        # into the dataset's own requested_rewrite dict as a side effect.
        req = dict(r['requested_rewrite'])
        req['case_id'] = r['case_id']
        requests.append(req)
    return np.array(requests)
|
|
|
|
|
def print_single_request(r):
    """Print one request's formatted sentence alongside its subject."""
    filled = r['prompt'].format(r['subject'])
    print(f"Sentence: {filled} | Subject: {r['subject']}")
|
|
|
|
|
def print_request(rs):
    """Print one request dict, or each request in a sequence of them."""
    requests = [rs] if type(rs) == dict else rs
    for request in requests:
        print_single_request(request)
|
|