#! /usr/bin/env python3
# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example command with bag of words:
python run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

Example command with discriminator:
python run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""

import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from pplm_classification_head import ClassificationHead
from torch import nn
from tqdm import trange
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.file_utils import cached_path

PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10

BAG_OF_WORDS_ARCHIVE_MAP = {
    "legal": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
    "military": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
    "politics": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
    "religion": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
    "science": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
    "space": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
    "technology": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}

DISCRIMINATOR_MODELS_PARAMS = {
    "clickbait": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"non_clickbait": 0, "clickbait": 1},
        "default_class": 1,
        "pretrained_model": "gpt2-medium",
    },
    "sentiment": {
        "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        "class_size": 5,
        "embed_size": 1024,
        "class_vocab": {"very_positive": 2, "very_negative": 3},
        "default_class": 3,
        "pretrained_model": "gpt2-medium",
    },
}

def top_k_filter(logits, k, probs=False):
    """
    Masks everything but the top k entries as -infinity (-1e10 here), so that
    e^-infinity -> 0 and the masked entries do not contribute to the softmax
    denominator. When probs=True the input is already a probability
    distribution and masked entries are set to 0 instead.
    """
    if k == 0:
        return logits
    else:
        values = torch.topk(logits, k)[0]
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            return torch.where(logits < batch_mins, torch.ones_like(logits) * 0.0, logits)
        return torch.where(logits < batch_mins, torch.ones_like(logits) * -BIG_CONST, logits)
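
# Quick illustration (not part of the original script): with k=2,
#   top_k_filter(torch.tensor([[1.0, 5.0, 3.0, 2.0]]), k=2)
# keeps 5.0 and 3.0 and replaces the remaining entries with -1e10, so they
# vanish after a softmax.
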
def perturb_past(
    past,
    model,
    last,
    unpert_past=None,
    unpert_logits=None,
    accumulated_hidden=None,
    grad_norms=None,
    stepsize=0.01,
    one_hot_bows_vectors=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    num_iterations=3,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    kl_scale=0.01,
    device="cuda",
):
    # Generate initial perturbed past
    grad_accumulator = [(np.zeros(p.shape).astype("float32")) for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    if decay:
        decay_mask = torch.arange(0.0, 1.0 + SMALL_CONST, 1.0 / (window_length))[1:]
    else:
        decay_mask = 1.0

    # Generate a mask so that the gradient perturbation is only applied to the
    # most recent `window_length` positions of the past
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])

        zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])

        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)), dim=-2).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

    # accumulate perturbations for num_iterations
    loss_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        print("Iteration ", i + 1)
        curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
        # make sure p_.grad is not None
        for p_ in curr_perturbation:
            p_.retain_grad()

        # Compute hidden using perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        lm_output = model(last, past_key_values=perturbed_past)
        all_logits, all_hidden = lm_output["logits"], lm_output["hidden_states"]
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(hidden, dim=1).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
        logits = all_logits[:, -1, :]
        probs = nn.functional.softmax(logits, dim=-1)

        loss = 0.0
        loss_list = []
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
                loss_list.append(bow_loss)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if loss_type == PPLM_DISCRIM or loss_type == PPLM_BOW_DISCRIM:
            ce_loss = nn.CrossEntropyLoss()
            # TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            # resize_token_embeddings() without arguments returns the model's input (word) embedding module
            wte = model.resize_token_embeddings()
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                lm_output = model(past_key_values=curr_unpert_past, inputs_embeds=inputs_embeds)
                curr_all_logits, curr_unpert_past, curr_all_hidden = (
                    lm_output["logits"],
                    lm_output["past_key_values"],
                    lm_output["hidden_states"],
                )
                curr_logits = curr_all_logits[:, -1, :]
                curr_probs = nn.functional.softmax(curr_logits, dim=-1)
                curr_probs = torch.unsqueeze(curr_probs, dim=1)
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(curr_hidden, dim=1)

            prediction = classifier(new_accumulated_hidden / (curr_length + 1 + horizon_length))

            label = torch.tensor(prediction.shape[0] * [class_label], device=device, dtype=torch.long)
            discrim_loss = ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)

        kl_loss = 0.0
        if kl_scale > 0.0:
            unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
            unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * ((corrected_probs * (corrected_probs / unpert_probs).log()).sum())
            print(" kl_loss", kl_loss.data.cpu().numpy())
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        print(" pplm_loss", (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()

        # calculate gradient norms
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST) for index, p_ in enumerate(curr_perturbation)
            ]

        # normalize gradients
        grad = [
            -stepsize * (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]

        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()

        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past

    # apply the accumulated perturbations to the past
    grad_accumulator = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
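
# In short, perturb_past runs `num_iterations` steps of gradient descent on the
# key/value cache: each step computes the attribute losses (BoW and/or
# discriminator) plus kl_scale * KL(p_pert || p_unpert), then updates the
# accumulated perturbation with
#   delta <- delta - stepsize * grad / ||grad||**gamma
# where the gradient is restricted to the last `window_length` positions by
# `window_mask`.
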
def get_classifier(
    name: Optional[str], class_label: Union[str, int], device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(class_size=params["class_size"], embed_size=params["embed_size"]).to(device)
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    else:
        raise ValueError("Either url or path has to be specified in the discriminator model parameters")
    classifier.load_state_dict(torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    if isinstance(class_label, str):
        if class_label in params["class_vocab"]:
            label_id = params["class_vocab"][class_label]
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))
    elif isinstance(class_label, int):
        if class_label in set(params["class_vocab"].values()):
            label_id = class_label
        else:
            label_id = params["default_class"]
            print("class_label {} not in class_vocab".format(class_label))
            print("available values are: {}".format(params["class_vocab"]))
            print("using default class {}".format(label_id))
    else:
        label_id = params["default_class"]

    return classifier, label_id
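
# Example (assuming the S3 URL above is still reachable):
#   classifier, label_id = get_classifier("sentiment", "very_positive", "cpu")
# loads the SST classification head and resolves label_id to 2 via class_vocab.
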
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> List[List[List[int]]]:
    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
            filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
        else:
            filepath = id_or_path
        with open(filepath, "r") as f:
            words = f.read().strip().split("\n")
        bow_indices.append([tokenizer.encode(word.strip(), add_prefix_space=True) for word in words])
    return bow_indices
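
# A bag-of-words file is simply a newline-separated word list; for example, a
# custom file with lines such as "planet", "orbit", "galaxy" (illustrative, not
# shipped with this script) can be passed by filepath instead of a BOW id.
# Each word is encoded with a leading space so it matches GPT-2's word-initial
# tokens.
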
def build_bows_one_hot_vectors(bow_indices, tokenizer, device="cuda"):
    if bow_indices is None:
        return None

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
        single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
        single_bow = torch.tensor(single_bow).to(device)
        num_words = single_bow.shape[0]
        one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
        one_hot_bow.scatter_(1, single_bow, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors
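
# Each entry of the returned list is a (num_single_token_words, vocab_size)
# one-hot matrix; words that tokenize to more than one token are dropped by the
# length filter above.
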
def full_text_generation(
    model,
    tokenizer,
    context=None,
    num_samples=1,
    device="cuda",
    bag_of_words=None,
    discrim=None,
    class_label=None,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
    **kwargs,
):
    classifier, class_id = get_classifier(discrim, class_label, device)

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)

    if bag_of_words and classifier:
        print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
        loss_type = PPLM_BOW_DISCRIM
    elif bag_of_words:
        loss_type = PPLM_BOW
        print("Using PPLM-BoW")
    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        print("Using PPLM-Discrim")
    else:
        raise Exception("Specify either a bag of words or a discriminator")

    unpert_gen_tok_text, _, _ = generate_text_pplm(
        model=model,
        tokenizer=tokenizer,
        context=context,
        device=device,
        length=length,
        sample=sample,
        perturb=False,
        repetition_penalty=repetition_penalty,
    )
    if device == "cuda":
        torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []

    for i in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
            model=model,
            tokenizer=tokenizer,
            context=context,
            device=device,
            perturb=True,
            bow_indices=bow_indices,
            classifier=classifier,
            class_label=class_id,
            loss_type=loss_type,
            length=length,
            stepsize=stepsize,
            temperature=temperature,
            top_k=top_k,
            sample=sample,
            num_iterations=num_iterations,
            grad_length=grad_length,
            horizon_length=horizon_length,
            window_length=window_length,
            decay=decay,
            gamma=gamma,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            repetition_penalty=repetition_penalty,
        )
        pert_gen_tok_texts.append(pert_gen_tok_text)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)

    if device == "cuda":
        torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
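
# full_text_generation returns one unperturbed sample plus `num_samples`
# perturbed samples (token tensors), the final discriminator losses (only when
# a discriminator is used), and the per-step loss histories.
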
def generate_text_pplm(
    model,
    tokenizer,
    context=None,
    past=None,
    device="cuda",
    perturb=True,
    bow_indices=None,
    classifier=None,
    class_label=None,
    loss_type=0,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    repetition_penalty=1.0,
):
    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    # collect one hot vectors for bags of words
    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer, device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []
    for i in trange(length, ascii=True):
        # Get past/probs for the current output, except for the last token.
        # GPT-2 takes two inputs: the past and the current token.
        # Run the model forward to obtain the unperturbed past.
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                past = model(output_so_far[:, :-1])["past_key_values"]

        lm_output = model(output_so_far)
        unpert_logits, unpert_past, unpert_all_hidden = (
            lm_output["logits"],
            lm_output["past_key_values"],
            lm_output["hidden_states"],
        )
        unpert_last_hidden = unpert_all_hidden[-1]

        # check if we are above the grad max length
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        # modify the past if necessary
        if not perturb or num_iterations == 0:
            pert_past = past
        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        lm_output = model(last, past_key_values=pert_past)
        pert_logits, past = (
            lm_output["logits"],
            lm_output["past_key_values"],
        )
        pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST

        for token_idx in set(output_so_far[0].tolist()):
            if pert_logits[0, token_idx] < 0:
                pert_logits[0, token_idx] *= repetition_penalty
            else:
                pert_logits[0, token_idx] /= repetition_penalty

        pert_probs = nn.functional.softmax(pert_logits, dim=-1)

        if classifier is not None:
            ce_loss = nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device, dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            print("unperturbed discrim loss", unpert_discrim_loss.data.cpu().numpy())
        else:
            unpert_discrim_loss = 0

        # Fuse the modified model and original model
        if perturb:
            unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)

            pert_probs = (pert_probs**gm_scale) * (unpert_probs ** (1 - gm_scale))  # + SMALL_CONST
            pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)  # + SMALL_CONST

            # rescale
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)
        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)  # + SMALL_CONST
            pert_probs = nn.functional.softmax(pert_logits, dim=-1)

        # sample or greedy
        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)
        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)

        # update context/output_so_far appending the new token
        output_so_far = last if output_so_far is None else torch.cat((output_so_far, last), dim=1)

        print(tokenizer.decode(output_so_far.tolist()[0]))

    return output_so_far, unpert_discrim_loss, loss_in_time
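
# At every step the perturbed and unperturbed distributions are fused with a
# geometric mean, p ~ pert_probs**gm_scale * unpert_probs**(1 - gm_scale), then
# top-k filtered and renormalized before sampling (or greedy argmax when
# sample=False).
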
def set_generic_model_params(discrim_weights, discrim_meta):
    if discrim_weights is None:
        raise ValueError("When using a generic discriminator, discrim_weights needs to be specified")
    if discrim_meta is None:
        raise ValueError("When using a generic discriminator, discrim_meta needs to be specified")

    with open(discrim_meta, "r") as discrim_meta_file:
        meta = json.load(discrim_meta_file)
    meta["path"] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS["generic"] = meta
def run_pplm_example(
    pretrained_model="gpt2-medium",
    cond_text="",
    uncond=False,
    num_samples=1,
    bag_of_words=None,
    discrim=None,
    discrim_weights=None,
    discrim_meta=None,
    class_label=-1,
    length=100,
    stepsize=0.02,
    temperature=1.0,
    top_k=10,
    sample=False,
    num_iterations=3,
    grad_length=10000,
    horizon_length=1,
    window_length=0,
    decay=False,
    gamma=1.5,
    gm_scale=0.9,
    kl_scale=0.01,
    seed=0,
    no_cuda=False,
    colorama=False,
    repetition_penalty=1.0,
):
    # set the random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set the device
    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"

    if discrim == "generic":
        set_generic_model_params(discrim_weights, discrim_meta)

    if discrim is not None:
        pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim]["pretrained_model"]
        print("discrim = {}, pretrained_model set to discriminator's = {}".format(discrim, pretrained_model))

    # load pretrained model
    model = GPT2LMHeadModel.from_pretrained(pretrained_model, output_hidden_states=True)
    model.to(device)
    model.eval()

    # load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

    # Freeze GPT-2 weights
    for param in model.parameters():
        param.requires_grad = False

    # figure out conditioning text
    if uncond:
        tokenized_cond_text = tokenizer.encode([tokenizer.bos_token])
    else:
        raw_text = cond_text
        while not raw_text:
            print("Did you forget to add `--cond_text`? ")
            raw_text = input("Model prompt >>> ")
        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)

    print("= Prefix of sentence =")
    print(tokenizer.decode(tokenized_cond_text))
    print()

    # generate unperturbed and perturbed texts

    # full_text_generation returns:
    # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
    unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        num_samples=num_samples,
        bag_of_words=bag_of_words,
        discrim=discrim,
        class_label=class_label,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        repetition_penalty=repetition_penalty,
    )

    # untokenize unperturbed text
    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])

    print("=" * 80)
    print("= Unperturbed generated text =")
    print(unpert_gen_text)
    print()

    generated_texts = []

    bow_word_ids = set()
    if bag_of_words and colorama:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"), tokenizer)
        for single_bow_list in bow_indices:
            # filter out words that tokenize to more than one token
            filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
            # w[0] is safe because the filter above guarantees w has exactly one item
            bow_word_ids.update(w[0] for w in filtered)

    # iterate through the perturbed texts
    for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
        try:
            # untokenize perturbed text
            if colorama:
                import colorama

                pert_gen_text = ""
                for word_id in pert_gen_tok_text.tolist()[0]:
                    if word_id in bow_word_ids:
                        pert_gen_text += "{}{}{}".format(
                            colorama.Fore.RED,
                            tokenizer.decode([word_id]),
                            colorama.Style.RESET_ALL,
                        )
                    else:
                        pert_gen_text += tokenizer.decode([word_id])
            else:
                pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
            print("= Perturbed generated text {} =".format(i + 1))
            print(pert_gen_text)
            print()
        except Exception as exc:
            print("Ignoring error while generating perturbed text:", exc)

        # keep the prefix, perturbed seq, original seq for each index
        generated_texts.append((tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text))

    return
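
# A minimal programmatic usage sketch (assuming this file is importable as
# `run_pplm`; a GPU is optional):
#
#   from run_pplm import run_pplm_example
#   run_pplm_example(
#       cond_text="The president",
#       bag_of_words="space",
#       length=50,
#       stepsize=0.01,
#       sample=True,
#   )
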
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model",
        "-M",
        type=str,
        default="gpt2-medium",
        help="pretrained model name or path to local checkpoint",
    )
    parser.add_argument("--cond_text", type=str, default="The lake", help="Prefix text to condition on")
    parser.add_argument("--uncond", action="store_true", help="Generate from end-of-text as prefix")
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of samples to generate from the modified latents",
    )
    parser.add_argument(
        "--bag_of_words",
        "-B",
        type=str,
        default=None,
        help=(
            "Bags of words used for PPLM-BoW. "
            "Either a BOW id (see list in code) or a filepath. "
            "Multiple BoWs separated by ;"
        ),
    )
    parser.add_argument(
        "--discrim",
        "-D",
        type=str,
        default=None,
        choices=("clickbait", "sentiment", "generic"),
        help="Discriminator to use",
    )
    parser.add_argument(
        "--discrim_weights",
        type=str,
        default=None,
        help="Weights for the generic discriminator",
    )
    parser.add_argument(
        "--discrim_meta",
        type=str,
        default=None,
        help="Meta information for the generic discriminator",
    )
    parser.add_argument(
        "--class_label",
        type=int,
        default=-1,
        help="Class label used for the discriminator",
    )
    parser.add_argument("--length", type=int, default=100)
    parser.add_argument("--stepsize", type=float, default=0.02)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument(
        "--sample", action="store_true", help="Sample from the distribution instead of greedy decoding"
    )
    parser.add_argument("--num_iterations", type=int, default=3)
    parser.add_argument("--grad_length", type=int, default=10000)
    parser.add_argument(
        "--window_length",
        type=int,
        default=0,
        help="Length of past which is being optimized; 0 corresponds to infinite window length",
    )
    parser.add_argument(
        "--horizon_length",
        type=int,
        default=1,
        help="Length of future to optimize over",
    )
    parser.add_argument("--decay", action="store_true", help="Whether to decay the perturbation over the window")
    parser.add_argument("--gamma", type=float, default=1.5)
    parser.add_argument("--gm_scale", type=float, default=0.9)
    parser.add_argument("--kl_scale", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--colorama", action="store_true", help="Highlight bag-of-words keywords in color")
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.0,
        help="Penalize repetition. More than 1.0 -> less repetition",
    )

    args = parser.parse_args()
    run_pplm_example(**vars(args))