""" | |
This is a hacky little attempt using the tools from the trigger creation script to identify a | |
good set of label strings. The idea is to train a linear classifier over the predict token and | |
then look at the most similar tokens. | |
""" | |
import os.path
import logging

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    BertForMaskedLM, RobertaForMaskedLM, XLNetLMHeadModel, GPTNeoForCausalLM  # , LlamaForCausalLM
)
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from tqdm import tqdm

from . import augments, utils, model_wrapper

logger = logging.getLogger(__name__)
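

# Return the module whose output is the final hidden state fed to the LM head; label_search hooks
# this module (see OutputStorage) and uses its output as features for the linear projection.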
def get_final_embeddings(model):
    if isinstance(model, BertForMaskedLM):
        return model.cls.predictions.transform
    elif isinstance(model, RobertaForMaskedLM):
        return model.lm_head.layer_norm
    elif isinstance(model, GPT2LMHeadModel):
        return model.transformer.ln_f
    elif isinstance(model, GPTNeoForCausalLM):
        return model.transformer.ln_f
    elif isinstance(model, XLNetLMHeadModel):
        return model.transformer.dropout
    elif "opt" in model.name_or_path:
        return model.model.decoder.final_layer_norm
    elif "glm" in model.name_or_path:
        # Hard-coded layer index; assumes a GLM checkpoint with at least 36 transformer layers.
        return model.glm.transformer.layers[35]
    elif "llama" in model.name_or_path:
        return model.model.norm
    else:
        raise NotImplementedError(f'{model} not currently supported')
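

# Return the output (word) embedding matrix used by the LM head; for most models this is the
# LM-head weight of shape [vocab_size, hidden_size]. Label candidates are scored by the similarity
# between these rows and the trained projection weights.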
def get_word_embeddings(model):
    if isinstance(model, BertForMaskedLM):
        return model.cls.predictions.decoder.weight
    elif isinstance(model, RobertaForMaskedLM):
        return model.lm_head.decoder.weight
    elif isinstance(model, GPT2LMHeadModel):
        return model.lm_head.weight
    elif isinstance(model, GPTNeoForCausalLM):
        return model.lm_head.weight
    elif isinstance(model, XLNetLMHeadModel):
        return model.lm_loss.weight
    elif "opt" in model.name_or_path:
        return model.lm_head.weight
    elif "glm" in model.name_or_path:
        return model.glm.transformer.final_layernorm.weight
    elif "llama" in model.name_or_path:
        return model.lm_head.weight
    else:
        raise NotImplementedError(f'{model} not currently supported')
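

# Draw tokenizer.num_prompt_tokens distinct random vocabulary ids as a prompt, shaped [1, num_prompt_tokens].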
def random_prompt(args, tokenizer, device):
    prompt = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
    prompt_ids = torch.tensor(prompt, device=device).unsqueeze(0)
    return prompt_ids
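

# Feed a sample of the training data through the (frozen) model with a random prompt and count
# which vocabulary ids appear most often in the per-example top-k (or bottom-k, if largest=False)
# predictions.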
def topk_search(args, largest=True):
    utils.set_seed(args.seed)
    device = args.device
    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    logger.info('Loading datasets')
    collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    predictor = model_wrapper.ModelWrapper(model, tokenizer)
    # Accumulate how often each vocabulary id appears in the per-example top-k predictions.
    mask_cnt = torch.zeros([tokenizer.vocab_size])
    pbar = tqdm(enumerate(train_loader))
    with torch.no_grad():
        count = 0
        for step, model_inputs in pbar:
            count += len(model_inputs["input_ids"])
            prompt_ids = random_prompt(args, tokenizer, device)
            logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            _, top = logits.topk(args.k, largest=largest)
            ids, frequency = torch.unique(top.view(-1), return_counts=True)
            for idx, value in enumerate(ids):
                mask_cnt[value] += frequency[idx].detach().cpu()
            pbar.set_description(f"-> [{step}/{len(train_loader)}] unique:{ids[:5].tolist()}")
            if count > 10000:
                # A sample of roughly 10k examples is enough to rank the candidates.
                break
    top_cnt, top_ids = mask_cnt.detach().cpu().topk(args.k)
    tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
    key = "topk" if largest else "lastk"
    print(f"-> {key}-{args.k}:{top_ids.tolist()} top_cnt:{top_cnt.tolist()} tokens:{tokens}")
    # The results file is created by label_search (see __main__); only update it if it exists.
    if os.path.exists(args.output):
        best_results = torch.load(args.output)
        best_results[key] = top_ids
        torch.save(best_results, args.output)

class OutputStorage:
    """
    This object stores the output of the given PyTorch module via a forward hook, since it
    otherwise might not be retained.
    """
    def __init__(self, module):
        self._stored_output = None
        module.register_forward_hook(self.hook)

    def hook(self, module, input, output):
        self._stored_output = output

    def get(self):
        return self._stored_output
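

# Train a linear classifier ("projection") on the frozen model's predict-token hidden states,
# then read off the vocabulary tokens whose output embeddings score highest for each class row
# of the projection; these are the candidate label words.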
def label_search(args):
    device = args.device
    utils.set_seed(args.seed)
    logger.info('Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    # Hook the module that feeds the LM head so its output (the final hidden states) can be read back.
    final_embeddings = get_final_embeddings(model)
    embedding_storage = OutputStorage(final_embeddings)
    word_embeddings = get_word_embeddings(model)
    label_map = args.label_map
    reverse_label_map = {y: x for x, y in label_map.items()}
    # The weights of this projection will help identify the best label words.
    projection = torch.nn.Linear(config.hidden_size, len(label_map), dtype=model.dtype)
    projection.to(device)
    # Obtain the initial trigger tokens and label mapping
    if args.prompt:
        prompt_ids = tokenizer.encode(
            args.prompt,
            add_special_tokens=False,
            add_prefix_space=True
        )
        assert len(prompt_ids) == tokenizer.num_prompt_tokens
    else:
        if "llama" in args.model_name:
            prompt_ids = random_prompt(args, tokenizer, device=args.device).squeeze(0).tolist()
        elif "gpt" in args.model_name:
            # prompt_ids = [tokenizer.unk_token_id] * tokenizer.num_prompt_tokens
            prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
        elif "opt" in args.model_name:
            prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
        else:
            # Masked LMs: initialise the prompt with [MASK] tokens.
            prompt_ids = [tokenizer.mask_token_id] * tokenizer.num_prompt_tokens
    prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
    logger.info('Loading datasets')
    collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=True, collate_fn=collator)
    optimizer = torch.optim.SGD(projection.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        int(args.iters * len(train_loader)),
    )
    tot_steps = len(train_loader)
    # Before training, log the tokens closest to the (randomly initialised) projection rows.
    projection.to(word_embeddings.device)
    scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
    scores = F.softmax(scores, dim=0)
    for i, row in enumerate(scores):
        _, top = row.topk(args.k)
        decoded = tokenizer.convert_ids_to_tokens(top)
        logger.info(f"-> Top k for class {reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
    best_results = {
        "best_acc": 0.0,
        "template": args.template,
        "model_name": args.model_name,
        "dataset_name": args.dataset_name,
        "task": args.task
    }
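    # Each iteration: one pass over the training data updating only the projection, then a report
    # of the current top-k label candidates and (after the first 20 iterations) a dev-set evaluation.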
    logger.info('Training')
    for iters in range(args.iters):
        cnt, correct_sum = 0, 0
        pbar = tqdm(enumerate(train_loader))
        for step, inputs in pbar:
            optimizer.zero_grad()
            prompt_mask = inputs.pop('prompt_mask').to(device)
            predict_mask = inputs.pop('predict_mask').to(device)
            model_inputs = {}
            model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
            model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
            # Splice the fixed prompt tokens into the input at the positions marked by prompt_mask.
            model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
            # Only the projection is trained; the language model itself stays frozen.
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_mask = predict_mask.to(args.device)
            projection = projection.to(args.device)
            label = inputs["label"].to(args.device)
            # Select the hidden state at the predict-token position of each example.
            # (A disabled OPT-specific variant used the first position instead:
            #  embeddings[:, 0].view(embeddings.size(0), -1).contiguous())
            predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            loss = F.cross_entropy(logits, label)
            pred = logits.argmax(dim=1)
            correct = pred.view_as(label).eq(label).sum().detach().cpu()
            loss.backward()
            if "opt" in args.model_name:
                torch.nn.utils.clip_grad_norm_(projection.parameters(), 0.2)
            optimizer.step()
            scheduler.step()
            cnt += len(label)
            correct_sum += correct
            current_lr = optimizer.param_groups[0]['lr']
            del inputs
            pbar.set_description(
                f'-> [{iters}/{args.iters}] step:[{step}/{tot_steps}] '
                f'loss:{loss:0.4f} acc:{correct / label.shape[0]:0.4f} lr:{current_lr:0.4f}'
            )
        train_accuracy = float(correct_sum / cnt)
        # Re-score the vocabulary against the updated projection and report the top-k tokens per class.
        scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
        scores = F.softmax(scores, dim=0)
        best_results["score"] = scores.detach().cpu().numpy()
        for i, row in enumerate(scores):
            _, top = row.topk(args.k)
            decoded = tokenizer.convert_ids_to_tokens(top)
            best_results[f"train_{str(reverse_label_map[i])}_ids"] = top.detach().cpu()
            best_results[f"train_{str(reverse_label_map[i])}_token"] = ' '.join(decoded)
            print(f"-> [{iters}/{args.iters}] Top-k class={reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
        print()
        # Skip dev evaluation during the first 20 iterations.
        if iters < 20:
            continue
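        # Evaluate the current projection and its label candidates on the dev set.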
        cnt, correct_sum = 0, 0
        pbar = tqdm(dev_loader)
        for inputs in pbar:
            label = inputs["label"].to(device)
            prompt_mask = inputs.pop('prompt_mask').to(device)
            predict_mask = inputs.pop('predict_mask').to(device)
            model_inputs = {}
            model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
            model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
            model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
            with torch.no_grad():
                model(**model_inputs)
            embeddings = embedding_storage.get()
            predict_mask = predict_mask.to(embeddings.device)
            projection = projection.to(embeddings.device)
            label = label.to(embeddings.device)
            predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
            logits = projection(predict_embeddings)
            pred = logits.argmax(dim=1)
            correct = pred.view_as(label).eq(label).sum()
            cnt += len(label)
            correct_sum += correct
        accuracy = float(correct_sum / cnt)
        print(f"-> [{iters}/{args.iters}] train_acc:{train_accuracy:0.4f} test_acc:{accuracy:0.4f}")
        if accuracy > best_results["best_acc"]:
            # New best dev accuracy: promote this iteration's candidates to the "best_*" slots.
            best_results["best_acc"] = accuracy
            for i, row in enumerate(scores):
                best_results[f"best_{str(reverse_label_map[i])}_ids"] = best_results[f"train_{str(reverse_label_map[i])}_ids"]
                best_results[f"best_{str(reverse_label_map[i])}_token"] = best_results[f"train_{str(reverse_label_map[i])}_token"]
        print()
        # Persist the (possibly updated) results after every evaluated iteration.
        torch.save(best_results, args.output)
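

# CLI entry point: parse arguments, run the label search (which writes args.output), then add the
# frequency-based top-k token ids to the saved results.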
if __name__ == '__main__':
    args = augments.get_args()
    if args.debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level)
    label_search(args)
    topk_search(args, largest=True)