homeway's picture
Add application file
7713b1f
raw
history blame
12.2 kB
"""
This is a hacky little attempt using the tools from the trigger creation script to identify a
good set of label strings. The idea is to train a linear classifier over the predict token and
then look at the most similar tokens.
"""
import os.path
import numpy as np
import logging
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
BertForMaskedLM, RobertaForMaskedLM, XLNetLMHeadModel, GPTNeoForCausalLM #, LlamaForCausalLM
)
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from tqdm import tqdm
from . import augments, utils, model_wrapper
logger = logging.getLogger(__name__)
def get_final_embeddings(model):
    """Return the module whose forward output is the final hidden representation
    fed to the LM head, so a forward hook on it captures per-token embeddings.

    Dispatch is by concrete class where possible, falling back to substring
    matching on ``model.name_or_path`` for architectures not imported here.

    Raises:
        NotImplementedError: if the model matches none of the known cases.
    """
    if isinstance(model, BertForMaskedLM):
        return model.cls.predictions.transform
    if isinstance(model, RobertaForMaskedLM):
        return model.lm_head.layer_norm
    # GPT-2 and GPT-Neo share the same final layer-norm location.
    if isinstance(model, (GPT2LMHeadModel, GPTNeoForCausalLM)):
        return model.transformer.ln_f
    if isinstance(model, XLNetLMHeadModel):
        return model.transformer.dropout
    # Name-based fallback; check order (opt, glm, llama) is significant.
    name = model.name_or_path
    if "opt" in name:
        return model.model.decoder.final_layer_norm
    if "glm" in name:
        # NOTE(review): hard-coded layer index 35 — assumes a specific GLM depth.
        return model.glm.transformer.layers[35]
    if "llama" in name:
        return model.model.norm
    raise NotImplementedError(f'{model} not currently supported')
def get_word_embeddings(model):
    """Return the output (LM-head) weight matrix used to score vocabulary tokens.

    Dispatch mirrors :func:`get_final_embeddings`: concrete class checks first,
    then substring matching on ``model.name_or_path``.

    Raises:
        NotImplementedError: if the model matches none of the known cases.
    """
    if isinstance(model, BertForMaskedLM):
        return model.cls.predictions.decoder.weight
    if isinstance(model, RobertaForMaskedLM):
        return model.lm_head.decoder.weight
    # GPT-2 and GPT-Neo expose the same tied LM-head weight.
    if isinstance(model, (GPT2LMHeadModel, GPTNeoForCausalLM)):
        return model.lm_head.weight
    if isinstance(model, XLNetLMHeadModel):
        return model.lm_loss.weight
    # Name-based fallback; check order (opt, glm, llama) is significant.
    name = model.name_or_path
    if "opt" in name:
        return model.lm_head.weight
    if "glm" in name:
        # NOTE(review): this is a layer-norm weight, not a (vocab, hidden)
        # embedding matrix — looks inconsistent with the other branches; verify.
        return model.glm.transformer.final_layernorm.weight
    if "llama" in name:
        return model.lm_head.weight
    raise NotImplementedError(f'{model} not currently supported')
def random_prompt(args, tokenizer, device):
    """Sample ``tokenizer.num_prompt_tokens`` distinct token ids uniformly from
    the vocabulary and return them as a tensor of shape (1, num_prompt_tokens)
    on ``device``. ``args`` is unused but kept for call-site uniformity.
    """
    sampled = np.random.choice(
        tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False
    )
    return torch.tensor(sampled.tolist(), device=device).unsqueeze(0)
def topk_search(args, largest=True):
    """Scan (up to ~10k examples of) the training set with random prompts and
    count how often each vocabulary token appears in the model's top-k (or
    bottom-k when ``largest=False``) predictions; save the k most frequent
    token ids under key ``"topk"``/``"lastk"`` in ``args.output``.
    """
    utils.set_seed(args.seed)
    device = args.device
    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    logger.info('Loading datasets')
    collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    predictor = model_wrapper.ModelWrapper(model, tokenizer)
    # Per-token hit counter over the whole vocabulary, kept on CPU.
    mask_cnt = torch.zeros([tokenizer.vocab_size])
    phar = tqdm(enumerate(train_loader))
    with torch.no_grad():
        count = 0
        for step, model_inputs in phar:
            count += len(model_inputs["input_ids"])
            # A fresh random prompt per batch, so counts are not tied to one prompt.
            prompt_ids = random_prompt(args, tokenizer, device)
            logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            _, top = logits.topk(args.k, largest=largest)
            # Accumulate how often each token id appeared in this batch's top-k.
            ids, frequency = torch.unique(top.view(-1), return_counts=True)
            mask_cnt[ids] += frequency.detach().cpu().to(mask_cnt.dtype)
            phar.set_description(f"-> [{step}/{len(train_loader)}] unique:{ids[:5].tolist()}")
            if count > 10000:  # cap the scan; this is a rough estimate, not a full pass
                break
    top_cnt, top_ids = mask_cnt.detach().cpu().topk(args.k)
    tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
    key = "topk" if largest else "lastk"
    print(f"-> {key}-{args.k}:{top_ids.tolist()} top_cnt:{top_cnt.tolist()} tokens:{tokens}")
    # BUG FIX: previously results were saved only when args.output already
    # existed, silently discarding everything on a fresh run. Start from an
    # empty dict when the file is missing and always persist.
    if os.path.exists(args.output):
        best_results = torch.load(args.output)
    else:
        best_results = {}
    best_results[key] = top_ids
    torch.save(best_results, args.output)
class OutputStorage:
    """
    Capture the forward-pass output of a given PyTorch module via a forward
    hook, so intermediate activations (which would otherwise not be retained)
    can be read back after calling the model.
    """

    def __init__(self, module):
        # Holds the most recent forward output of `module`; None until the
        # first forward pass runs.
        self._stored_output = None
        module.register_forward_hook(self.hook)

    def hook(self, module, input, output):
        # Forward-hook signature required by PyTorch; we only keep the output.
        self._stored_output = output

    def get(self):
        """Return the output captured by the most recent forward pass (or None)."""
        return self._stored_output
def label_search(args):
    """Train a linear probe over the [predict-token] hidden state to identify
    good label words: the probe's weight rows are matched against the output
    embedding matrix and the top-k nearest tokens per class are reported and
    saved (with the best dev accuracy) to ``args.output``.
    """
    device = args.device
    utils.set_seed(args.seed)
    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    # Hook the layer feeding the LM head so we can read per-token hidden states.
    final_embeddings = get_final_embeddings(model)
    embedding_storage = OutputStorage(final_embeddings)
    word_embeddings = get_word_embeddings(model)
    label_map = args.label_map
    # Maps class index -> original label name, for reporting.
    reverse_label_map = {y: x for x, y in label_map.items()}
    # The weights of this projection will help identify the best label words.
    projection = torch.nn.Linear(config.hidden_size, len(label_map), dtype=model.dtype)
    projection.to(device)
    # Obtain the initial trigger tokens and label mapping
    if args.prompt:
        prompt_ids = tokenizer.encode(
            args.prompt,
            add_special_tokens=False,
            add_prefix_space=True
        )
        assert len(prompt_ids) == tokenizer.num_prompt_tokens
    else:
        # Causal LMs (llama/gpt/opt) get random prompt tokens; masked LMs use
        # the mask token as a placeholder.
        if "llama" in args.model_name:
            prompt_ids = random_prompt(args, tokenizer, device=args.device).squeeze(0).tolist()
        elif "gpt" in args.model_name:
            #prompt_ids = [tokenizer.unk_token_id] * tokenizer.num_prompt_tokens
            prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
        elif "opt" in args.model_name:
            prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
        else:
            prompt_ids = [tokenizer.mask_token_id] * tokenizer.num_prompt_tokens
    prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
    logger.info('Loading datasets')
    collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=True, collate_fn=collator)
    # Only the probe is trained; the LM itself stays frozen (no_grad below).
    optimizer = torch.optim.SGD(projection.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        int(args.iters * len(train_loader)),
    )
    tot_steps = len(train_loader)
    # Report the (random-init) nearest tokens per class before training.
    projection.to(word_embeddings.device)
    scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
    scores = F.softmax(scores, dim=0)
    for i, row in enumerate(scores):
        _, top = row.topk(args.k)
        decoded = tokenizer.convert_ids_to_tokens(top)
        logger.info(f"-> Top k for class {reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
    best_results = {
        "best_acc": 0.0,
        "template": args.template,
        "model_name": args.model_name,
        "dataset_name": args.dataset_name,
        "task": args.task
    }
    logger.info('Training')
    for iters in range(args.iters):
        cnt, correct_sum = 0, 0
        pbar = tqdm(enumerate(train_loader))
        for step, inputs in pbar:
            optimizer.zero_grad()
            # pop() removes mask entries so `inputs` can be fed to the model.
            prompt_mask = inputs.pop('prompt_mask').to(device)
            predict_mask = inputs.pop('predict_mask').to(device)
            model_inputs = {}
            model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
            model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
            # Splice the prompt token ids into the positions flagged by prompt_mask.
            model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
            with torch.no_grad():
                # Forward pass only to populate the hook; outputs are discarded.
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_mask = predict_mask.to(args.device)
            projection = projection.to(args.device)
            label = inputs["label"].to(args.device)
            # NOTE(review): `and False` makes this branch dead code — the OPT
            # special case is disabled; confirm whether that is intentional.
            if "opt" in args.model_name and False:
                predict_embeddings = embeddings[:, 0].view(embeddings.size(0), -1).contiguous()
            else:
                # Select the hidden state at the predict-token position per example.
                predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            loss = F.cross_entropy(logits, label)
            pred = logits.argmax(dim=1)
            correct = pred.view_as(label).eq(label).sum().detach().cpu()
            loss.backward()
            if "opt" in args.model_name:
                torch.nn.utils.clip_grad_norm_(projection.parameters(), 0.2)
            optimizer.step()
            scheduler.step()
            cnt += len(label)
            correct_sum += correct
            # Grab the current lr (last param group wins) for progress display.
            for param_group in optimizer.param_groups:
                current_lr = param_group['lr']
            del inputs
            pbar.set_description(f'-> [{iters}/{args.iters}] step:[{step}/{tot_steps}] loss: {loss : 0.4f} acc:{correct/label.shape[0] :0.4f} lr:{current_lr :0.4f}')
        train_accuracy = float(correct_sum/cnt)
        # Score every vocab token against each class direction of the probe.
        scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
        scores = F.softmax(scores, dim=0)
        best_results["score"] = scores.detach().cpu().numpy()
        for i, row in enumerate(scores):
            _, top = row.topk(args.k)
            decoded = tokenizer.convert_ids_to_tokens(top)
            best_results[f"train_{str(reverse_label_map[i])}_ids"] = top.detach().cpu()
            best_results[f"train_{str(reverse_label_map[i])}_token"] = ' '.join(decoded)
            print(f"-> [{iters}/{args.iters}] Top-k class={reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
        print()
        # Skip dev evaluation (and checkpointing) for the first 20 iterations.
        if iters < 20:
            continue
        cnt, correct_sum = 0, 0
        pbar = tqdm(dev_loader)
        for inputs in pbar:
            label = inputs["label"].to(device)
            prompt_mask = inputs.pop('prompt_mask').to(device)
            predict_mask = inputs.pop('predict_mask').to(device)
            model_inputs = {}
            model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
            model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
            model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_mask = predict_mask.to(embeddings.device)
            projection = projection.to(embeddings.device)
            label = label.to(embeddings.device)
            predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            pred = logits.argmax(dim=1)
            correct = pred.view_as(label).eq(label).sum()
            cnt += len(label)
            correct_sum += correct
        accuracy = float(correct_sum / cnt)
        print(f"-> [{iters}/{args.iters}] train_acc:{train_accuracy:0.4f} test_acc:{accuracy:0.4f}")
        # Keep the label words from the best-dev-accuracy iteration.
        if accuracy > best_results["best_acc"]:
            best_results["best_acc"] = accuracy
            for i, row in enumerate(scores):
                best_results[f"best_{str(reverse_label_map[i])}_ids"] = best_results[f"train_{str(reverse_label_map[i])}_ids"]
                best_results[f"best_{str(reverse_label_map[i])}_token"] = best_results[f"train_{str(reverse_label_map[i])}_token"]
        print()
        # Persist every evaluated iteration (overwrites the same file).
        torch.save(best_results, args.output)
if __name__ == '__main__':
    # Parse CLI arguments, configure logging verbosity, then run both searches.
    args = augments.get_args()
    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level)
    label_search(args)
    topk_search(args, largest=True)