File size: 2,991 Bytes
3494c6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import transformers
import torch
import os
import numpy as np
import datetime
import struct
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
def get_inner_params(named_parameters, inner_names):
    """Pick out ``(name, parameter)`` pairs for the requested names.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, e.g. the
            result of ``model.named_parameters()``.
        inner_names: names to select, in the order they should be returned.

    Returns:
        list of ``(name, parameter)`` tuples, one per entry of ``inner_names``.

    Raises:
        KeyError: if a requested name is not among ``named_parameters``.
    """
    by_name = dict(named_parameters)
    return [(wanted, by_name[wanted]) for wanted in inner_names]
def param_subset(named_parameters, inner_names):
    """Return just the parameter objects for the requested names.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs.
        inner_names: names to select, in output order.

    Returns:
        list of parameters, one per entry of ``inner_names``.

    Raises:
        KeyError: if a requested name is not among ``named_parameters``.
    """
    lookup = dict(named_parameters)
    selected = [lookup[key] for key in inner_names]
    return selected
def parent_module(model, pname):
    """Resolve the object that directly owns the last component of ``pname``.

    ``pname`` is a dotted path such as ``"transformer.h.0.mlp.c_fc.weight"``.
    Every component except the last is walked. A component is tried as an
    attribute first; if that fails and it is all digits, it is used as an
    integer index (for indexable containers such as ``nn.ModuleList``).

    Args:
        model: root object to start the walk from.
        pname: dotted attribute path.

    Returns:
        The object holding the final component of ``pname``.

    Raises:
        RuntimeError: if any component cannot be resolved.
    """
    *path, leaf = pname.split('.')
    node = model
    for name in path:
        if hasattr(node, name):
            node = getattr(node, name)
            continue
        if not name.isdigit():
            raise RuntimeError(f"Couldn't find child module {name}")
        node = node[int(name)]
    if hasattr(node, leaf):
        return node
    raise RuntimeError(f"Couldn't find child module {leaf}")
def uuid(digits=4):
    """Return a random per-process identifier in ``[0, 10**digits)``.

    The value is drawn from ``os.urandom`` on the first call and cached on the
    function object (``uuid.uuid_value``). Every later call returns that same
    value. Note that ``digits`` only has an effect on the first call.
    """
    try:
        return uuid.uuid_value
    except AttributeError:
        (raw,) = struct.unpack('I', os.urandom(4))
        uuid.uuid_value = raw % int(10**digits)
        return uuid.uuid_value
def ckpt_dir():
    """Return the directory in which to store model checkpoints.

    The directory ``./ckpts/`` (relative to the current working directory) is
    created, together with any missing parents, if it does not exist yet.

    Returns:
        str: the path ``"./ckpts/"``.

    Raises:
        OSError: (``FileExistsError``) if ``./ckpts`` exists but is not a
            directory.
    """
    path = "./ckpts/"
    # exist_ok=True replaces the exists()-then-makedirs() pair, which raced when
    # several processes started at once and both tried to create the directory.
    os.makedirs(path, exist_ok=True)
    return path
def brackets_to_periods(name):
    """Convert index brackets to dotted form, e.g. ``"h[0].w"`` -> ``"h.0.w"``.

    Every ``[`` becomes ``.`` and every ``]`` is removed.
    """
    table = str.maketrans({"[": ".", "]": None})
    return name.translate(table)
def get_params(model):
    """Return ``model.state_dict()``.

    For a torch module, this is an ordered mapping from parameter and
    persistent-buffer names to tensors.
    """
    state = model.state_dict()
    return state
def get_shape(p, model):
    """Return the shape of parameter ``p`` in GPT-2 orientation.

    OpenAI's GPT-2 implements these layers with convolution-style modules
    instead of linear layers, and its weights are stored with the axes swapped
    relative to the other models. For a ``transformers.GPT2LMHeadModel`` the
    shape is returned unchanged (as ``p.shape``). For any other model the
    first two axes are swapped and returned as a plain tuple, so ``p`` must
    have at least two dimensions there.

    NOTE(review): only ``GPT2LMHeadModel`` is matched. Other GPT-2 wrappers
    (e.g. a bare ``GPT2Model``) take the swap branch; confirm that is intended.
    """
    if isinstance(model, transformers.GPT2LMHeadModel):
        return p.shape
    return (p.shape[1], p.shape[0])
def get_logits(x):
    """Unwrap model output: return ``x.logits`` if present, else ``x`` itself."""
    return getattr(x, "logits", x)
def _num_padding(tokens, row, pad_token_id):
    """Return how many positions of ``tokens["input_ids"][row]`` are padding.

    Uses the attention mask when the tokenizer returned one (zeros mark
    padding); otherwise falls back to counting ``pad_token_id`` occurrences.
    """
    if "attention_mask" in tokens:
        return int((tokens["attention_mask"][row] == 0).sum())
    return int((tokens["input_ids"][row] == pad_token_id).sum())


def tokenize(batch, tokenizer, device, test=False):
    """Tokenize an editing batch into model inputs plus loss labels.

    Args:
        batch: mapping with "prompt" and "target_new" entries. Each may be a
            single item or a list; single items are wrapped in a list.
        tokenizer: callable tokenizer returning "input_ids" (and usually
            "attention_mask") tensors. It must expose ``pad_token_id``. Its
            ``padding_side`` ("left"/"right", default "right") is honored.
        device: target passed to ``Tensor.to`` for every returned tensor.
        test: if True (or if there are no labels), only the prompts are
            tokenized and every non-padding position is kept in "labels".

    Returns:
        dict of tensors on ``device``: the tokenizer outputs plus "labels".
        Outside test mode the input is "<prompt> <target_new>". In "labels",
        the prompt positions and padding positions are set to -100, so only
        the target tokens contribute to the loss.
    """
    prompt, label = batch["prompt"], batch["target_new"]
    if not isinstance(prompt, list):
        prompt = [prompt]
    if not isinstance(label, list):
        label = [label]
    mask_token = -100  # ignore_index of CrossEntropyLoss
    pad_id = tokenizer.pad_token_id

    if test or not label:
        tokens = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)
        tokens["labels"] = tokens["input_ids"].clone()
    else:
        full_prompt = [f"{p} {tgt}" for p, tgt in zip(prompt, label)]
        prompt_ids = tokenizer(list(prompt), return_tensors="pt", padding=True, truncation=True)["input_ids"]
        # Prompt length inside "<prompt> <target>". This assumes the prompt
        # tokenizes the same way on its own as it does as a prefix.
        num_prompt_toks = [int((row != pad_id).sum()) for row in prompt_ids]
        tokens = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
        tokens["labels"] = tokens["input_ids"].clone()
        left_padded = getattr(tokenizer, "padding_side", "right") == "left"
        for i, n_prompt in enumerate(num_prompt_toks):
            # With left padding, the prompt starts after the leading pads. Masking
            # from index 0 would hide pads instead of the last prompt tokens.
            offset = _num_padding(tokens, i, pad_id) if left_padded else 0
            tokens["labels"][i][: offset + n_prompt] = mask_token

    # Never compute loss on padding positions.
    tokens["labels"][tokens["input_ids"] == pad_id] = mask_token
    return {key: value.to(device) for key, value in tokens.items()}
|