#!/usr/bin/env python3
# coding=utf-8

import os
import sys
import argparse

from tqdm import trange
import torch
import torch.optim
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

# Make the lab root importable before pulling in the locally modified libraries.
lab_root = os.path.join(os.path.abspath(os.path.dirname(__file__)), '..', '..')
sys.path.insert(1, lab_root)

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
from IPython import embed


def top_k_logits(logits, k, probs=False):
    """
    Masks everything but the top-k entries as -infinity (-1e10).
    Used to mask logits so that e^-infinity -> 0 and the masked entries
    do not contribute to the sum in the softmax denominator.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)


def sample_sequence(model, length, start_token=None, batch_size=None, context=None,
                    temperature=1, top_k=0, device='cuda', sample=True, return_past=False):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    # context.requires_grad_()
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length, ascii=True):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)  # does nothing if k == 0
            log_probs = F.softmax(logits, dim=-1)  # note: despite the name, these are probabilities, not log-probs
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            # prev is the next token; past is shaped like [2, 1, 16, x, 64] where x grows from 1 to length
            # embed()
            # print('sample sequence {}: prev shape {} past shape {}'.format(
            #     i, list(prev[0].size()), list(past[0].size())))
            output = torch.cat((output, prev), dim=1)
            # print(output)
    if return_past:
        return output, past
    return output
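# Illustrative sketch only: this helper is not called anywhere in the script. It shows how
# sample_sequence() is typically driven (mirroring the conditional branch of run_model() below),
# assuming `model`, `enc`, and `device` have been set up the same way run_model() sets them up.
# The prompt, length, and top_k values here are arbitrary examples.
def _example_conditional_sample(model, enc, device, prompt='The meaning of life is', length=40, top_k=40):
    context_tokens = enc.encode(prompt)  # tokenize the prompt to BPE ids
    out = sample_sequence(model=model, length=length,
                          context=context_tokens, start_token=None,
                          batch_size=1, temperature=1.0, top_k=top_k, device=device)
    return enc.decode(out[0].tolist())  # prompt + continuation, decoded back to text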
def sample_from_hidden(model, length, hidden, context=None, past=None, temperature=1, top_k=0,
                       device='cuda', sample=True, noise_level=1e-1):
    output = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0) if context else None
    with torch.no_grad():
        for i in trange(length, ascii=True):
            logits = model.forward_hidden(hidden)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)  # does nothing if k == 0
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            # prev is the next token; past is shaped like [2, 1, 16, x, 64] where x grows from 1 to length
            # embed()
            # print('sample sequence {}: prev shape {} past shape {}'.format(
            #     i, list(prev[0].size()), list(past[0].size())))
            output = prev if output is None else torch.cat((output, prev), dim=1)  # update output
            if i == 0:
                _, past = model(output, past=None)  # update past: feed the whole input context
            else:
                _, past = model(prev, past=past)  # update past: feed only the next token
            hidden = model.hidden_states  # update hidden
            # print('output', output)
            # print('hidden', hidden)
            # perturb the hidden state before the next step
            hidden = modify_hidden(hidden, noise_level)
    return output


def modify_hidden(input_tensor, noise_level=1e-1):
    # input_tensor shape: (1, 1, length)
    length = input_tensor.shape[-1]
    # add uniform noise scaled by noise_level, on the same device as the input
    ret = input_tensor + torch.rand(length, device=input_tensor.device) * noise_level
    return ret


def compute_log_likelihood(model, phrase, tokenizer, device):
    token_ids = tokenizer.encode(phrase)
    batch_size = 1
    context = torch.tensor(token_ids, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    print("Computing LL of phrase \"{}\"".format(phrase))
    print("After encoding, number of tokens: {}".format(len(token_ids)))
    with torch.no_grad():
        logits, past = model(context, past=None)
    # Each position t predicts token t+1, so drop the last logits row and the first token id.
    _idxs = range(len(token_ids) - 1)
    token_ids = token_ids[1:]
    logits = logits[0, :-1]
    probs = F.softmax(logits, dim=-1)
    likelihoods = probs[_idxs, token_ids]
    assert len(list(likelihoods.shape)) == 1
    log_likelihoods = torch.log(likelihoods)
    ll_list = [ls.item() for ls in log_likelihoods]
    for token, llh in zip(token_ids, log_likelihoods):
        print("LL of token {} ('{}') ==> {:.4f}".format(token, tokenizer.decode([token]), llh))
    print("LL of the phrase (sum of the above): {}".format(np.sum(ll_list)))
    return np.sum(ll_list)
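# Illustrative sketch only (run_model() below parses a --get_ll flag but never calls this function);
# assumes `model`, `enc`, and `device` are initialized exactly as in run_model(). The example phrase
# is arbitrary.
def _example_phrase_log_likelihood(model, enc, device, phrase='The quick brown fox jumps over the lazy dog.'):
    # Returns the summed log-likelihood of every token after the first, as printed by compute_log_likelihood().
    return compute_log_likelihood(model, phrase, enc, device)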
def get_embedding_grad(model, enc, context=None, target=(40,), device='cuda', ll_only=False, opt_embed=False):
    # `target` is a sequence of target token ids (a bare int default would break len(target) below).
    assert context is not None, 'Input text is needed'
    # context = Variable(torch.tensor(context, device=device, dtype=torch.float),
    #                    requires_grad=True).unsqueeze(0)  # .repeat(1, 1)
    context = torch.tensor(context, device=device, dtype=torch.float).unsqueeze(0)
    model.zero_grad()
    logits, past = model(context, past=None)
    # make sure it is the same as above
    # logits_1, past_1 = model.forward_embed(model.transformer.i_embeds, past=None)
    logits = logits[:, -1, :]
    log_probs = F.softmax(logits, dim=-1)
    if len(target) > 1:
        nll = sum([-torch.log(log_probs[:, tar]) for tar in target])
    else:
        nll = -torch.log(log_probs[:, target])
    with torch.no_grad():
        # logits = top_k_logits(logits, k=1)  # does nothing if k == 0
        log_probs = F.softmax(logits, dim=-1)
        top1, top1ind = torch.topk(log_probs, k=1, dim=-1)
    print('LL of target : {}'.format(-nll.data.squeeze().cpu().numpy()))
    print('LL of top 1  : {}'.format(torch.log(top1).data.squeeze().cpu().numpy()))

    if ll_only:
        return

    if opt_embed:
        # optimizing in embedding space
        orig_embed = model.transformer.i_embeds.clone()
        embed_vars = Variable(model.transformer.i_embeds, requires_grad=True)
        # optimizer = torch.optim.SGD([embed_vars], lr=0.01, momentum=0.9)
        optimizer = torch.optim.Adam([embed_vars], lr=0.01)
        optimizer.zero_grad()
        for ss in range(50):
            # nll.backward(retain_graph=True)
            nll.backward()
            optimizer.step()
            logits, past = model.forward_embed(embed_vars, past=None)
            logits = logits[:, -1, :]
            log_probs = F.softmax(logits, dim=-1)
            if len(target) > 1:
                nll = sum([-torch.log(log_probs[:, tar]) for tar in target])
            else:
                nll = -torch.log(log_probs[:, target])
            print('LL of target (step {}): {}'.format(ss, -nll.data.squeeze().cpu().numpy()))
            # print('Sanity check: embed_vars sum: {}'.format(embed_vars.sum().cpu().detach().numpy()))

        # map the optimized embeddings back to the nearest token ids (searching in token space)
        output_ids = torch.empty_like(context.long())
        with torch.no_grad():
            all_embeds = model.transformer.wte.weight  # [50257, 1024]
            embed_vars_unbind = torch.unbind(embed_vars, dim=1)
            orig_embed_unbind = torch.unbind(orig_embed, dim=1)
            cc = 0
            for ie_new, ie_orig, orig_id in zip(embed_vars_unbind, orig_embed_unbind, context.squeeze(0)):
                new_id = (all_embeds - ie_new).abs().sum(1).argmin()
                print('emb {}: {} (`{}`) to {} (`{}`)'.format(cc, orig_id.tolist(), enc.decode([orig_id.tolist()]),
                                                              new_id.tolist(), enc.decode([new_id.tolist()])))
                output_ids[0, cc] = new_id
                cc += 1
        output_ids = torch.cat((context.long(), output_ids), dim=1)
        return output_ids

    # # Earlier experiment: searching in token space directly via input-embedding gradients.
    # model.transformer.i_embeds.retain_grad()
    # nll.backward()
    # step = 0.01
    # # with torch.no_grad():
    # if True:
    #     input_grads = model.transformer.i_embeds.grad  # [batch, length, 1024]
    #     # input_grads = input_grads.squeeze(0)  # [length, 1024]
    #     input_embeds = model.transformer.i_embeds  # [batch, length, 1024]
    #     input_embeds_unbind = torch.unbind(input_embeds, dim=1)
    #     all_embeds = model.transformer.wte.weight  # [50257, 1024]
    #
    #     opts = [torch.optim.Adam([Variable(ie, requires_grad=True)], lr=0.01) for ie in input_embeds_unbind]
    #
    #     # for ss in range(50):
    #     #     input_embeds.data.sub_(step * input_grads.data)
    #     #     # input_embeds.data.add_(step * input_grads.data)
    #     #
    #     #     logits, past = model.forward_embed(input_embeds, past=None)
    #     #     logits = logits[:, -1, :]
    #     #     log_probs = F.softmax(logits, dim=-1)
    #     #     if len(target) > 1:
    #     #         nll = sum([-torch.log(log_probs[:, tar]) for tar in target])
    #     #     else:
    #     #         nll = -torch.log(log_probs[:, target])
    #     #
    #     #     print('LL of target (step {}): {}'.format(ss, -nll.data.squeeze().cpu().numpy()))
    #     #
    #     # embed()
    #     search_order = input_grads.sum(-1).squeeze().abs().argsort(descending=True)
    #     output_ids = context.long()
    #     cc = 0
    #     # n_tokens_to_change = 1
    #     for order, orig_id in zip(search_order, context.squeeze(0)[search_order]):
    #         embed()
    #
    #         ie = input_embeds_unbind[order]
    #         orig_id = orig_id.long()
    #         opt = opts[order]
    #         opt.zero_grad()
    #         new_id = abs(all_embeds - ie).sum(1).argmin().data
    #         new_id == orig_id
    #         # if cc < n_tokens_to_change:
    #         #     while new_id == orig_id:
    #         #         # ie.data.sub_(step * ig.data)
    #         #         ie.data.add_(step * ig.data)
    #         for opt_step in range(50):
    #             opt.step()
    #             new_id = abs(all_embeds - ie).sum(1).argmin().data
    #         print('emb {}: {} (`{}`) to {} (`{}`)'.format(order, orig_id.tolist(), enc.decode([orig_id.tolist()]),
    #                                                       new_id.tolist(), enc.decode([new_id.tolist()])))
    #         output_ids[0, order] = new_id
    #         # output_ids = torch.cat((output_ids, new_id.reshape(1,1)), dim=1)
    #         cc += 1
    #     output_ids = torch.cat((context.long(), output_ids), dim=1)

    # print(context.grad)
    # NOTE: output_ids is only defined on the opt_embed path above; reaching this return with
    # opt_embed=False (and the block above still commented out) raises a NameError.
    return output_ids
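# Illustrative sketch only: nothing in run_model() calls get_embedding_grad(), so this shows one
# plausible way to drive it, assuming `model`, `enc`, and `device` are set up as in run_model().
# The prompt and target word are arbitrary examples, not values taken from the original code.
def _example_embedding_grad(model, enc, device, prompt='The weather today is', target_word=' sunny'):
    context_tokens = enc.encode(prompt)      # token ids of the prompt
    target_tokens = enc.encode(target_word)  # token id(s) whose likelihood we want to raise
    return get_embedding_grad(model, enc, context=context_tokens, target=target_tokens,
                              device=device, opt_embed=True)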
def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', '-M', type=str, default='gpt-2_pt_models/774M/',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    parser.add_argument('--nocuda', action='store_true', help='run on CPU even if CUDA is available')
    parser.add_argument('--opt_ll', action='store_true', help='optimize the NLL of a target')
    parser.add_argument('--get_ll', action='store_true', help='compute the log-likelihood of a sentence')
    parser.add_argument('--hidden_playground', action='store_true', help='play around in the hidden representation')
    parser.add_argument("--noise_level", type=float, default=1e-1)
    parser.add_argument("--cond-text", type=str, default='', help='prefix text to condition on')
    parser.add_argument('--output', type=str, default=os.environ.get('GIT_RESULTS_MANAGER_DIR', None),
                        help='output directory')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.nocuda:
        device = torch.device("cpu")
    print('device is {}'.format(device))

    enc = GPT2Tokenizer.from_pretrained(args.model_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    # while True:
    generated = 0
    for _ in range(10):
        context_tokens = []
        if not args.unconditional:
            # raw_text = input("Model prompt >>> ")
            raw_text = args.cond_text
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=context_tokens,
                    start_token=None,
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                # out = out[:, len(context_tokens):].tolist()
                out = out[:, 0:].tolist()  # keep the prompt tokens in the decoded output
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
                    if args.output:
                        filepath = os.path.join(args.output, "generated_{}.txt".format(generated))
                        with open(filepath, "w") as f:
                            f.write(text)
            # print("=" * 80)
        if args.unconditional:
            generated = 0
            for _ in range(args.nsamples // args.batch_size):
                out = sample_sequence(
                    model=model, length=args.length,
                    context=None,
                    start_token=enc.encoder['<|endoftext|>'],
                    batch_size=args.batch_size,
                    temperature=args.temperature, top_k=args.top_k, device=device
                )
                out = out[:, 1:].tolist()  # drop the <|endoftext|> start token
                for i in range(args.batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
                # print("=" * 80)
        if args.unconditional:
            break


if __name__ == '__main__':
    run_model()
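# Example invocations (illustrative: the script file name and model directory layout are assumptions;
# only the flags come from the argument parser above):
#
#   # conditional generation from a fixed prompt (the loop above emits ten samples)
#   python sample.py --model_path gpt-2_pt_models/774M/ --cond-text "Once upon a time" \
#       --length 40 --top_k 40
#
#   # unconditional generation starting from the <|endoftext|> token
#   python sample.py --model_path gpt-2_pt_models/774M/ --unconditional --length 100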