|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Example command with bag of words: |
|
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95 |
|
|
|
Example command with discriminator: |
|
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95 |
|
""" |
|
|
|
import gradio as gr |
|
import argparse |
|
import json |
|
from operator import add |
|
from typing import List, Optional, Tuple, Union |
|
from random import choice, randint |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
from tqdm import trange |
|
from transformers import GPT2Tokenizer |
|
from transformers.file_utils import cached_path |
|
from transformers.modeling_gpt2 import GPT2LMHeadModel |
|
|
|
from pplm_classification_head import ClassificationHead |
|
|
|
# Loss-type selectors compared against ``loss_type`` in the generation code.
PPLM_BOW, PPLM_DISCRIM, PPLM_BOW_DISCRIM = 1, 2, 3

# Numerical guards: a tiny epsilon and a large stand-in for "infinity".
SMALL_CONST = 1e-15
BIG_CONST = 1e10

# Verbosity levels, ordered so they can be compared with ``>=``.
QUIET, REGULAR, VERBOSE, VERY_VERBOSE = range(4)

# Maps the textual verbosity setting to its numeric level.
VERBOSITY_LEVELS = dict(
    quiet=QUIET,
    regular=REGULAR,
    verbose=VERBOSE,
    very_verbose=VERY_VERBOSE,
)
|
|
|
# Built-in bag-of-words topics mapped to the URL of their word list.
# Every list lives under the same prefix and is named ``<topic>.txt``.
BAG_OF_WORDS_ARCHIVE_MAP = {
    topic: "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/{}.txt".format(topic)
    for topic in (
        "legal",
        "military",
        "monsters",
        "politics",
        "positive_words",
        "religion",
        "science",
        "space",
        "technology",
    )
}
|
|
|
# Registry of known discriminators, keyed by name.
#
# Each entry gives the classifier head dimensions (``class_size``,
# ``embed_size``), the label vocabulary, the fallback label and the language
# model the head is paired with.  The head weights come from ``url`` or
# ``path`` when present; entries with neither rely on the ``fp`` argument
# handed to ``get_classifier``.
DISCRIMINATOR_MODELS_PARAMS = {
    'clickbait': {
        'url': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
        'class_size': 2,
        'embed_size': 1024,
        'class_vocab': {'non_clickbait': 0, 'clickbait': 1},
        'default_class': 1,
        'pretrained_model': 'gpt2-medium',
    },
    'sentiment': {
        'url': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
        'class_size': 5,
        'embed_size': 1024,
        'class_vocab': {'very_positive': 2, 'very_negative': 3},
        'default_class': 3,
        'pretrained_model': 'gpt2-medium',
    },
    # PerSoothe heads stored at fixed filesystem paths.
    '3_PerSoothe': {
        'path': "/content/drive/Shareddrives/COS_IW04_ZL/COSIW04/Discriminators/3_class_opt_lowlr_medgpt/3_PerSoothe_classifier_head_epoch_10.pt",
        'class_size': 3,
        'embed_size': 1024,
        'class_vocab': {'soothes': 0, 'neutral': 1, 'worsens': 2},
        'default_class': 2,
        'pretrained_model': 'microsoft/DialoGPT-medium',
    },
    '3_PerSoothe_eot': {
        'path': "/content/drive/Shareddrives/COS_IW04_ZL/COSIW04/Discriminators/3_class_opt_eot_lowlr_medgpt/3_PerSoothe_classifier_head_epoch_10.pt",
        'class_size': 3,
        'embed_size': 1024,
        'class_vocab': {'soothes': 0, 'neutral': 1, 'worsens': 2},
        'default_class': 2,
        'pretrained_model': 'microsoft/DialoGPT-medium',
    },
    # PerSoothe heads without a baked-in location (weights supplied via ``fp``).
    '3_PerSoothe_lrg': {
        'class_size': 3,
        'embed_size': 1280,
        'class_vocab': {'soothes': 0, 'neutral': 1, 'worsens': 2},
        'default_class': 2,
        'pretrained_model': 'microsoft/DialoGPT-large',
    },
    '3_PerSoothe_med': {
        'class_size': 3,
        'embed_size': 1024,
        'class_vocab': {'soothes': 0, 'neutral': 1, 'worsens': 2},
        'default_class': 2,
        'pretrained_model': 'microsoft/DialoGPT-medium',
    },
}
|
|
|
|
|
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
    """Place tensor ``x`` on ``device`` and wrap it in an autograd ``Variable``.

    With ``device == 'cuda'`` the tensor is moved to the GPU only when CUDA
    is available; otherwise it stays where it is.  Any other ``device`` value
    is passed straight to ``Tensor.to``.  ``requires_grad`` and ``volatile``
    are forwarded to ``Variable`` unchanged.
    """
    if device != 'cuda':
        x = x.to(device)
    elif torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad, volatile=volatile)
|
|
|
|
|
def top_k_filter(logits, k, probs=False):
    """Keep the ``k`` largest entries of each row and mask out the rest.

    Entries strictly smaller than a row's k-th largest value are replaced by
    ``0.0`` when ``probs`` is true (the input holds probabilities) or by
    ``-BIG_CONST`` otherwise (the input holds logits, so the masked entries
    vanish under a later softmax).  ``k == 0`` disables filtering and returns
    ``logits`` itself.
    """
    if k == 0:
        return logits

    top_values, _ = torch.topk(logits, k)
    # Smallest surviving value of every row, broadcast across the row.
    kth_value = top_values[:, -1].view(-1, 1).expand_as(logits)
    fill_value = 0.0 if probs else -BIG_CONST
    return torch.where(
        logits < kth_value,
        torch.ones_like(logits) * fill_value,
        logits,
    )
|
|
|
|
|
def perturb_past(
        past,
        model,
        last,
        unpert_past=None,
        unpert_logits=None,
        accumulated_hidden=None,
        grad_norms=None,
        stepsize=0.01,
        one_hot_bows_vectors=None,
        classifier=None,
        class_label=None,
        loss_type=0,
        num_iterations=3,
        horizon_length=1,
        window_length=0,
        decay=False,
        gamma=1.5,
        kl_scale=0.01,
        device='cuda',
        verbosity_level=REGULAR
):
    """Compute an additive perturbation of ``past`` by gradient descent.

    For ``num_iterations`` rounds the model is run on ``last`` with the
    currently perturbed past, a loss is assembled (bag-of-words and/or
    discriminator, selected by ``loss_type``, plus an optional KL term scaled
    by ``kl_scale``) and back-propagated into the perturbation.  Each gradient
    is masked by a window over the sequence axis, divided by its norm raised
    to ``gamma``, scaled by ``-stepsize`` and accumulated.

    Each entry of ``past`` is unpacked as a 5-D tensor whose fourth dimension
    is the current sequence length.

    Returns:
        ``(pert_past, new_accumulated_hidden, grad_norms, loss_per_iter)``:
        the perturbed past, the accumulated last-layer hidden state of the
        final iteration (``None`` when ``num_iterations`` is 0), the gradient
        norms used for normalisation, and the loss value of every iteration
        as numpy scalars.
    """
    # Accumulated perturbation, kept as float32 numpy arrays between rounds.
    grad_accumulator = [np.zeros(p.shape).astype("float32") for p in past]

    if accumulated_hidden is None:
        accumulated_hidden = 0

    decay_mask = 1.0
    if decay:
        # Evenly spaced weights in (0, 1]; the leading 0 is dropped.
        decay_mask = torch.arange(
            0.,
            1.0 + SMALL_CONST,
            1.0 / (window_length)
        )[1:]

    _, _, _, curr_length, _ = past[0].shape

    # Mask selecting which sequence positions receive gradient updates.
    if curr_length > window_length and window_length > 0:
        lead_dims = tuple(past[0].shape[:-2])
        feature_dim = tuple(past[0].shape[-1:])
        ones_key_val_shape = lead_dims + (window_length,) + feature_dim
        zeros_key_val_shape = (
            lead_dims + (curr_length - window_length,) + feature_dim
        )

        ones_mask = torch.ones(ones_key_val_shape)
        # Swap the last two axes so decay_mask broadcasts over positions,
        # then swap them back.
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat(
            (ones_mask, torch.zeros(zeros_key_val_shape)),
            dim=-2
        ).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)

    loss_per_iter = []
    new_accumulated_hidden = None
    for iteration in range(num_iterations):
        if verbosity_level >= VERBOSE:
            print("Iteration ", iteration + 1)

        # Fresh leaf tensors so this round's gradients can be read back.
        curr_perturbation = [
            to_var(torch.from_numpy(acc), requires_grad=True, device=device)
            for acc in grad_accumulator
        ]

        # Forward pass with the perturbed past.
        perturbed_past = [p + d for p, d in zip(past, curr_perturbation)]
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        all_logits, _, all_hidden = model(last, past_key_values=perturbed_past)
        hidden = all_hidden[-1]
        new_accumulated_hidden = (
            accumulated_hidden + torch.sum(hidden, dim=1).detach()
        )

        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)

        loss = 0.0
        if loss_type in (PPLM_BOW, PPLM_BOW_DISCRIM):
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
            if verbosity_level >= VERY_VERBOSE:
                print(" pplm_bow_loss:", loss.data.cpu().numpy())

        if loss_type in (PPLM_DISCRIM, PPLM_BOW_DISCRIM):
            ce_loss = torch.nn.CrossEntropyLoss()

            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
            # Roll the model forward ``horizon_length`` steps, feeding the
            # probability-weighted embedding in place of a sampled token.
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past_key_values=curr_unpert_past,
                    inputs_embeds=inputs_embeds
                )
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    curr_hidden, dim=1)

            # Classify the hidden state averaged over all positions seen.
            prediction = classifier(
                new_accumulated_hidden / (curr_length + 1 + horizon_length)
            )

            label = torch.tensor(prediction.shape[0] * [class_label],
                                 device=device,
                                 dtype=torch.long)
            discrim_loss = ce_loss(prediction, label)
            if verbosity_level >= VERY_VERBOSE:
                print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss

        kl_loss = 0.0
        if kl_scale > 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
            # Add SMALL_CONST to near-zero entries so ratio and log stay finite.
            unpert_probs = (
                unpert_probs + SMALL_CONST *
                (unpert_probs <= SMALL_CONST).float().to(device).detach()
            )
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
                device).detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
            )
            if verbosity_level >= VERY_VERBOSE:
                print(' kl_loss', kl_loss.data.cpu().numpy())
            loss += kl_loss

        loss_per_iter.append(loss.data.cpu().numpy())
        if verbosity_level >= VERBOSE:
            print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

        loss.backward()

        # Norms used to normalise the gradients.  For the pure bag-of-words
        # loss the running maximum over earlier calls is kept.
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index],
                          torch.norm(pert.grad * window_mask))
                for index, pert in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(pert.grad * window_mask) + SMALL_CONST)
                for pert in curr_perturbation
            ]

        # Normalised, masked descent step for every past entry.
        grad = [
            -stepsize *
            (pert.grad * window_mask / grad_norms[index] ** gamma
             ).data.cpu().numpy()
            for index, pert in enumerate(curr_perturbation)
        ]

        grad_accumulator = [
            step + acc for step, acc in zip(grad, grad_accumulator)
        ]

        for pert in curr_perturbation:
            pert.grad.data.zero_()

        # Cut ``past`` loose from the graph before the next round.
        past = [p.detach() for p in past]

    # Apply the accumulated perturbation to the past.
    grad_accumulator = [
        to_var(torch.from_numpy(acc), requires_grad=True, device=device)
        for acc in grad_accumulator
    ]
    pert_past = [p + d for p, d in zip(past, grad_accumulator)]

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
|
|
|
|
|
def get_classifier(
        name: Optional[str],
        class_label: Union[str, int],
        device: str,
        verbosity_level: int = REGULAR,
        fp: Optional[str] = None,
        is_deep: bool = False,
        is_deeper: bool = False,
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
    """Build the named discriminator head and resolve the target class id.

    Args:
        name: Key into ``DISCRIMINATOR_MODELS_PARAMS``; ``None`` disables the
            discriminator and yields ``(None, None)``.
        class_label: Desired class, given either by name (looked up in the
            entry's ``class_vocab``) or by integer id (must be one of the
            vocab's values).
        device: Device the head is moved to and the weights are mapped to.
        verbosity_level: A notice about an unknown ``class_label`` is printed
            when this is at least ``REGULAR``.
        fp: Weights file used only when the registry entry has neither a
            ``url`` nor a ``path``.
        is_deep: Forwarded to ``ClassificationHead``.
        is_deeper: Forwarded to ``ClassificationHead``.

    Returns:
        ``(classifier, label_id)``.  The classifier is in eval mode.  An
        unknown str/int label falls back to the entry's ``default_class``
        (with a notice); any other label type falls back silently.

    Raises:
        KeyError: ``name`` is not a registered discriminator.
        ValueError: no weights location could be determined.
    """
    if name is None:
        return None, None

    params = DISCRIMINATOR_MODELS_PARAMS[name]
    classifier = ClassificationHead(
        class_size=params['class_size'],
        embed_size=params['embed_size'],
        is_deep=is_deep,
        is_deeper=is_deeper
    ).to(device)

    # Weights location, in priority order: registry url, registry path, fp.
    if "url" in params:
        resolved_archive_file = cached_path(params["url"])
    elif "path" in params:
        resolved_archive_file = params["path"]
    elif fp is not None:
        resolved_archive_file = fp
    else:
        raise ValueError("Either url or path have to be specified "
                         "in the discriminator model parameters")

    # NOTE: torch.load unpickles the file; only load weights from sources
    # you trust.
    classifier.load_state_dict(
        torch.load(resolved_archive_file, map_location=device))
    classifier.eval()

    class_vocab = params["class_vocab"]
    default_label = params["default_class"]

    if isinstance(class_label, str):
        is_known = class_label in class_vocab
        label_id = class_vocab[class_label] if is_known else default_label
    elif isinstance(class_label, int):
        is_known = class_label in set(class_vocab.values())
        label_id = class_label if is_known else default_label
    else:
        # Neither a name nor an id: use the default without a notice.
        return classifier, default_label

    if not is_known and verbosity_level >= REGULAR:
        print("class_label {} not in class_vocab".format(class_label))
        print("available values are: {}".format(class_vocab))
        print("using default class {}".format(label_id))

    return classifier, label_id
|
|
|
|
|
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
        List[List[List[int]]]:
    """Tokenize every word of every requested bag of words.

    Each item is either a key of ``BAG_OF_WORDS_ARCHIVE_MAP`` (its word list
    is fetched through ``cached_path``) or a path to a local file with one
    word per line.

    Returns:
        One list per bag, holding one token-id list per word, in file order.
    """

    def _encode(word):
        # Words are encoded with a leading space and no special tokens.
        return tokenizer.encode(word.strip(),
                                add_prefix_space=True,
                                add_special_tokens=False)

    bow_indices = []
    for id_or_path in bag_of_words_ids_or_paths:
        is_builtin = id_or_path in BAG_OF_WORDS_ARCHIVE_MAP
        filepath = (
            cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
            if is_builtin
            else id_or_path
        )
        with open(filepath, "r") as word_file:
            content = word_file.read()
        words = content.strip().split("\n")
        bow_indices.append([_encode(word) for word in words])
    return bow_indices
|
|
|
|
|
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
    """Turn tokenized bags of words into one-hot vocabulary matrices.

    Only words that encode to exactly one token are used.  Words with several
    tokens cannot be expressed as a single one-hot row, and words that encode
    to no token at all (e.g. a blank line in the word list) are skipped.
    The previous ``len(x) <= 1`` filter kept such empty entries, which made
    ``torch.tensor`` fail on the resulting ragged list.

    Args:
        bow_indices: One list per bag, each holding one token-id list per
            word, or ``None``.
        tokenizer: Object exposing ``vocab_size``.
        device: Device on which the matrices are created.

    Returns:
        ``None`` if ``bow_indices`` is ``None``; otherwise one float tensor of
        shape ``(num_single_token_words, tokenizer.vocab_size)`` per bag, with
        a single 1 per row at the word's token id.  A bag with no usable word
        yields a tensor with zero rows.
    """
    if bow_indices is None:
        return None

    one_hot_bows_vectors = []
    for single_bow in bow_indices:
        token_ids = [ids[0] for ids in single_bow if len(ids) == 1]
        # Column vector of indices; explicitly int64 so that an empty bag
        # still produces a valid (0, 1) index for scatter_.
        index = torch.tensor(
            token_ids, dtype=torch.long, device=device
        ).unsqueeze(1)
        one_hot_bow = torch.zeros(
            len(token_ids), tokenizer.vocab_size, device=device
        )
        one_hot_bow.scatter_(1, index, 1)
        one_hot_bows_vectors.append(one_hot_bow)
    return one_hot_bows_vectors
|
|
|
|
|
def full_text_generation(
        model,
        tokenizer,
        context=None,
        num_samples=1,
        device="cuda",
        bag_of_words=None,
        discrim=None,
        class_label=None,
        length=100,
        stepsize=0.02,
        temperature=1.0,
        top_k=10,
        sample=True,
        num_iterations=3,
        grad_length=10000,
        horizon_length=1,
        window_length=0,
        decay=False,
        gamma=1.5,
        gm_scale=0.9,
        kl_scale=0.01,
        verbosity_level=REGULAR,
        fp=None,
        is_deep=False,
        is_deeper=False,
        stop_eot=False,
        **kwargs
):
    """Generate one unperturbed and ``num_samples`` perturbed continuations.

    The attribute model is chosen from the arguments: a bag of words
    (``bag_of_words``, a ``;``-separated list of ids or paths), a
    discriminator (``discrim`` plus ``class_label``), or both.  The remaining
    parameters are forwarded to ``generate_text_pplm``; extra keyword
    arguments are accepted and ignored.

    Returns:
        ``(unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses,
        losses_in_time, perplexities)``: the unperturbed token tensor, then
        one entry per perturbed sample in each list.  ``discrim_losses`` is
        filled only when a discriminator is used.

    Raises:
        Exception: neither a bag of words nor a discriminator was given.
    """
    # Pass the caller's verbosity on.  This used to be the literal REGULAR,
    # so unknown-label notices were printed even in quiet mode.
    classifier, class_id = get_classifier(
        discrim,
        class_label,
        device,
        verbosity_level=verbosity_level,
        fp=fp,
        is_deep=is_deep,
        is_deeper=is_deeper
    )

    bow_indices = []
    if bag_of_words:
        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
                                               tokenizer)

    if bag_of_words and classifier:
        loss_type = PPLM_BOW_DISCRIM
        if verbosity_level >= REGULAR:
            print("Both PPLM-BoW and PPLM-Discrim are on. "
                  "This is not optimized.")

    elif bag_of_words:
        loss_type = PPLM_BOW
        if verbosity_level >= REGULAR:
            print("Using PPLM-BoW")

    elif classifier is not None:
        loss_type = PPLM_DISCRIM
        if verbosity_level >= REGULAR:
            print("Using PPLM-Discrim")

    else:
        raise Exception("Specify either a bag of words or a discriminator")

    # Baseline: same context, no perturbation.  Sampling settings other than
    # ``sample`` are left at generate_text_pplm's defaults here.
    unpert_gen_tok_text, _, _, _ = generate_text_pplm(
        model=model,
        tokenizer=tokenizer,
        context=context,
        device=device,
        length=length,
        sample=sample,
        perturb=False,
        verbosity_level=verbosity_level,
        stop_eot=stop_eot
    )
    if device == 'cuda':
        torch.cuda.empty_cache()

    pert_gen_tok_texts = []
    discrim_losses = []
    losses_in_time = []
    perplexities = []

    for _ in range(num_samples):
        pert_gen_tok_text, discrim_loss, loss_in_time, perplexity = generate_text_pplm(
            model=model,
            tokenizer=tokenizer,
            context=context,
            device=device,
            perturb=True,
            bow_indices=bow_indices,
            classifier=classifier,
            class_label=class_id,
            loss_type=loss_type,
            length=length,
            stepsize=stepsize,
            temperature=temperature,
            top_k=top_k,
            sample=sample,
            num_iterations=num_iterations,
            grad_length=grad_length,
            horizon_length=horizon_length,
            window_length=window_length,
            decay=decay,
            gamma=gamma,
            gm_scale=gm_scale,
            kl_scale=kl_scale,
            verbosity_level=verbosity_level,
            stop_eot=stop_eot
        )
        pert_gen_tok_texts.append(pert_gen_tok_text)
        if classifier is not None:
            discrim_losses.append(discrim_loss.data.cpu().numpy())
        losses_in_time.append(loss_in_time)
        perplexities.append(perplexity)

    if device == 'cuda':
        torch.cuda.empty_cache()

    return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time, perplexities
|
|
|
|
|
def generate_text_pplm(
        model,
        tokenizer,
        context=None,
        past=None,
        device="cuda",
        perturb=True,
        bow_indices=None,
        classifier=None,
        class_label=None,
        loss_type=0,
        length=100,
        stepsize=0.02,
        temperature=1.0,
        top_k=10,
        sample=True,
        num_iterations=3,
        grad_length=10000,
        horizon_length=1,
        window_length=0,
        decay=False,
        gamma=1.5,
        gm_scale=0.9,
        kl_scale=0.01,
        verbosity_level=REGULAR,
        stop_eot=False
):
    """Generate up to ``length`` tokens, optionally steering with PPLM.

    At every step the model is run once on the full output so far
    (unperturbed) and once on the last token with a possibly perturbed past
    (see ``perturb_past``).  When ``perturb`` is true the two distributions
    are fused geometrically with weight ``gm_scale``, top-k filtered, divided
    by a per-token repetition penalty and renormalised before the next token
    is sampled (or taken greedily when ``sample`` is false).

    Returns:
        ``(output_so_far, unpert_discrim_loss, loss_in_time, perplexity)``:
        the context plus generated token ids, the discriminator loss of the
        unperturbed hidden states at the last step (0 without a classifier),
        the per-step loss histories from ``perturb_past``, and
        ``(1 / product of sampled token probabilities) ** (1 / steps)``.
    """
    # Token id used for <|endoftext|> and the size of the penalty table; both
    # values are taken from the original hard-coded constants.
    eot_id = 50256
    vocab_size = 50257

    output_so_far = None
    if context:
        context_t = torch.tensor(context, device=device, dtype=torch.long)
        # Ensure a (batch, sequence) layout.
        while len(context_t.shape) < 2:
            context_t = context_t.unsqueeze(0)
        output_so_far = context_t

    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
                                                      device)

    grad_norms = None
    last = None
    unpert_discrim_loss = 0
    loss_in_time = []

    if verbosity_level >= VERBOSE:
        range_func = trange(length, ascii=True)
    else:
        range_func = range(length)

    pert_total_prob = 1
    pert_times = 0
    # Repetition penalty: probabilities are divided by this table, and a
    # token's entry is multiplied by 8 each time the token is emitted.
    last_reps = torch.ones(vocab_size).to(device)

    for i in range_func:
        # The model takes the past plus the current token, so on the first
        # step the past is built from everything but the last token.
        if past is None and output_so_far is not None:
            last = output_so_far[:, -1:]
            if output_so_far.shape[1] > 1:
                _, past, _ = model(output_so_far[:, :-1])

        unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
        unpert_last_hidden = unpert_all_hidden[-1]

        # Beyond grad_length steps the perturbation step size is zero.
        if i >= grad_length:
            current_stepsize = stepsize * 0
        else:
            current_stepsize = stepsize

        if not perturb or num_iterations == 0:
            pert_past = past
        else:
            accumulated_hidden = unpert_last_hidden[:, :-1, :]
            accumulated_hidden = torch.sum(accumulated_hidden, dim=1)

            if past is not None:
                pert_past, _, grad_norms, loss_this_iter = perturb_past(
                    past,
                    model,
                    last,
                    unpert_past=unpert_past,
                    unpert_logits=unpert_logits,
                    accumulated_hidden=accumulated_hidden,
                    grad_norms=grad_norms,
                    stepsize=current_stepsize,
                    one_hot_bows_vectors=one_hot_bows_vectors,
                    classifier=classifier,
                    class_label=class_label,
                    loss_type=loss_type,
                    num_iterations=num_iterations,
                    horizon_length=horizon_length,
                    window_length=window_length,
                    decay=decay,
                    gamma=gamma,
                    kl_scale=kl_scale,
                    device=device,
                    verbosity_level=verbosity_level
                )
                loss_in_time.append(loss_this_iter)
            else:
                pert_past = past

        pert_logits, past, pert_all_hidden = model(last, past_key_values=pert_past)
        pert_logits = pert_logits[:, -1, :] / temperature
        pert_probs = F.softmax(pert_logits, dim=-1)

        if classifier is not None:
            ce_loss = torch.nn.CrossEntropyLoss()
            prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
            label = torch.tensor([class_label], device=device,
                                 dtype=torch.long)
            unpert_discrim_loss = ce_loss(prediction, label)
            if verbosity_level >= VERBOSE:
                print(
                    "unperturbed discrim loss",
                    unpert_discrim_loss.data.cpu().numpy()
                )
        else:
            unpert_discrim_loss = 0

        if perturb:
            # Fuse the perturbed and unperturbed distributions.
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)

            pert_probs = ((pert_probs ** gm_scale) * (
                    unpert_probs ** (1 - gm_scale)))
            if i < 2:
                # The first two steps keep at least two candidates.
                pert_probs = top_k_filter(pert_probs, k=max(2, top_k), probs=True)
                if i == 0:
                    # Never end the reply on the very first token.
                    pert_probs[0][eot_id] = 0
                if i == 1:
                    # Cap the end-of-text probability at the largest
                    # probability among the other tokens.  clone() is
                    # required: plain indexing returns a view, so zeroing the
                    # entry would also zero the saved value and the cap would
                    # always collapse to 0.
                    eot_prob = pert_probs[0][eot_id].clone()
                    pert_probs[0][eot_id] = 0
                    pert_probs[0][eot_id] = torch.min(
                        torch.max(pert_probs[0]), eot_prob)
            else:
                pert_probs = top_k_filter(pert_probs, k=top_k, probs=True)
            # NOTE(review): the file's indentation was lost; the repetition
            # penalty is assumed to apply on every step, not only inside the
            # ``else`` above - confirm against the original source.
            pert_probs = torch.div(pert_probs, last_reps)

            # Renormalise so the probabilities sum to one again.
            if torch.sum(pert_probs) <= 1:
                pert_probs = pert_probs / torch.sum(pert_probs)
        else:
            pert_logits = top_k_filter(pert_logits, k=top_k)
            pert_probs = F.softmax(pert_logits, dim=-1)

        if sample:
            last = torch.multinomial(pert_probs, num_samples=1)
            pert_total_prob = pert_total_prob * pert_probs[0][last[0][0]]
        else:
            _, last = torch.topk(pert_probs, k=1, dim=-1)
        # NOTE(review): assumed to run for both sampling and greedy decoding
        # (indentation lost) - confirm against the original source.
        last_reps[last[0][0]] = last_reps[last[0][0]] * 8

        output_so_far = (
            last if output_so_far is None
            else torch.cat((output_so_far, last), dim=1)
        )
        if verbosity_level >= REGULAR:
            print(tokenizer.decode(output_so_far.tolist()[0]))
        pert_times += 1
        if last[0][0] == eot_id and stop_eot:
            break

    perplexity = (1 / pert_total_prob) ** (1 / pert_times)
    return output_so_far, unpert_discrim_loss, loss_in_time, perplexity
|
|
|
|
|
def set_generic_model_params(discrim_weights, discrim_meta):
    """Register a user-supplied discriminator under the name ``'generic'``.

    ``discrim_meta`` is the path of a JSON file describing the discriminator;
    its parsed content, extended with ``path = discrim_weights``, is stored in
    ``DISCRIMINATOR_MODELS_PARAMS['generic']``.

    Raises:
        ValueError: either argument is ``None`` (``discrim_weights`` is
            checked first).
    """
    required_args = (
        ("discrim_weights", discrim_weights),
        ("discrim_meta", discrim_meta),
    )
    for arg_name, arg_value in required_args:
        if arg_value is None:
            raise ValueError('When using a generic discriminator, '
                             '{} need to be specified'.format(arg_name))

    with open(discrim_meta, 'r') as discrim_meta_file:
        meta = json.load(discrim_meta_file)
    meta['path'] = discrim_weights
    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
|
|
|
|
|
# ---------------------------------------------------------------------------
# Run configuration (module-level settings read by the setup code and by
# get_reply).
# ---------------------------------------------------------------------------

# Language model; replaced further down by the discriminator's own
# "pretrained_model" entry when the two differ.
pretrained_model = "microsoft/DialoGPT-large"

# Settings that are not referenced elsewhere in the visible code.
cond_text = ""
uncond = False
colorama = False
calc_perplexity = False

# Attribute model: no bag of words, discriminator looked up by name in
# DISCRIMINATOR_MODELS_PARAMS, steering towards class id 0.
bag_of_words = None
discrim = "3_PerSoothe_lrg"
discrim_weights = None
discrim_meta = None
class_label = 0

# Generation / sampling settings handed to full_text_generation.
num_samples = 1
length = 100
temperature = 1.0
top_k = 2
sample = True
stop_eot = True

# Perturbation settings handed to full_text_generation.
stepsize = 2.56
num_iterations = 10
grad_length = 10000
horizon_length = 5
window_length = 0
decay = False
gamma = 1.0
gm_scale = 0.95
kl_scale = 0.01

# Reproducibility, device selection and logging.
seed = 0
no_cuda = False
verbosity = "quiet"

# Weight files: ``fp`` is the classifier-head fallback passed to
# get_classifier, ``model_fp`` is loaded into the language model.
fp = "./paper_code/discrim_models/persoothe_classifier.pt"
model_fp = "./paper_code/discrim_models/persoothe_encoder.pt"

# Classification head variant flags forwarded to ClassificationHead.
is_deep = False
is_deeper = True
|
|
|
|
|
# Seed both random number generators used by this script.
torch.manual_seed(seed)
np.random.seed(seed)

# Unknown verbosity names fall back to REGULAR.
verbosity_level = VERBOSITY_LEVELS.get(verbosity.lower(), REGULAR)

# Use the GPU only when it exists and has not been disabled.
if torch.cuda.is_available() and not no_cuda:
    device = "cuda"
else:
    device = "cpu"

if discrim == 'generic':
    set_generic_model_params(discrim_weights, discrim_meta)

# A discriminator dictates which language model it is paired with.
if discrim is not None:
    discriminator_pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
        "pretrained_model"
    ]
    if pretrained_model != discriminator_pretrained_model:
        pretrained_model = discriminator_pretrained_model
        if verbosity_level >= REGULAR:
            print("discrim = {}, pretrained_model set "
                  "to discriminator's = {}".format(discrim, pretrained_model))

# Language model with hidden states exposed, optionally overlaid with the
# fine-tuned weights stored at ``model_fp``.
model = GPT2LMHeadModel.from_pretrained(
    pretrained_model,
    output_hidden_states=True
)
if model_fp:
    model.load_state_dict(torch.load(model_fp, map_location=device))
model.to(device)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)

# The language model weights stay frozen.
for param in model.parameters():
    param.requires_grad = False

# Separator placed between conversation turns.
eot_token = "<|endoftext|>"
|
|
|
def get_reply(response, history = None, in_stepsize = 2.56, in_horizon_length = 5, in_num_iterations = 10, in_top_k = 2):
    """Produce the chatbot's next turn for the Gradio interface.

    Args:
        response: The user's latest message.
        history: Conversation so far as one string with turns terminated by
            ``<|endoftext|>``, or ``None`` to start a new conversation
            (seeded with "How are you?").
        in_stepsize: PPLM step size.
        in_horizon_length: PPLM horizon length (converted to ``int``).
        in_num_iterations: PPLM iterations per token (converted to ``int``).
        in_top_k: Top-k used when sampling (converted to ``int``).

    Returns:
        ``(html, state)``: the rendered conversation and the history string
        to hand back on the next call.  The state is ``None`` when the
        conversation is restarted, which happens when the message ends in
        "bye" (a few spellings) or when rendering fails.
    """
    # Local import: the HTML fragment below embeds user and model text.
    from html import escape

    stepsize = in_stepsize
    horizon_length = int(in_horizon_length)
    num_iterations = int(in_num_iterations)
    top_k = int(in_top_k)

    if response.endswith(("bye", "Bye", "bye.", "Bye.", "bye!", "Bye!")):
        return "<div class='chatbot'>Chatbot restarted</div>", None

    previous_turns = history if history is not None else "How are you?<|endoftext|>"
    convo_hist = previous_turns + response + eot_token

    tokenized_cond_text = tokenizer.encode(
        eot_token + convo_hist,
        add_special_tokens=False
    )

    _, pert_gen_tok_texts, _, _, _ = full_text_generation(
        model=model,
        tokenizer=tokenizer,
        context=tokenized_cond_text,
        device=device,
        num_samples=1,
        bag_of_words=bag_of_words,
        discrim=discrim,
        class_label=class_label,
        length=length,
        stepsize=stepsize,
        temperature=temperature,
        top_k=top_k,
        sample=sample,
        num_iterations=num_iterations,
        grad_length=grad_length,
        horizon_length=horizon_length,
        window_length=window_length,
        decay=decay,
        gamma=gamma,
        gm_scale=gm_scale,
        kl_scale=kl_scale,
        verbosity_level=verbosity_level,
        fp=fp,
        is_deep=is_deep,
        is_deeper=is_deeper,
        stop_eot=stop_eot
    )

    # Defined up front so the return below is valid even if no sample came
    # back.
    chat_html = "<div class='chatbot'></div>"

    for pert_gen_tok_text in pert_gen_tok_texts:
        try:
            pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
            convo_hist_split = pert_gen_text.split(eot_token)

            # The first and last pieces are skipped; the rest alternate
            # between the two message styles.
            rendered = []
            for m, msg in enumerate(convo_hist_split[1:-1]):
                cls = "user" if m % 2 == 0 else "bot"
                # Escape the text: it comes from the user and the model and
                # must not be interpreted as markup.
                rendered.append(
                    "<div class='msg {}'> {}</div>".format(cls, escape(msg))
                )
            chat_html = "<div class='chatbot'>" + "".join(rendered) + "</div>"

            # Keep only the most recent pieces as context for the next turn.
            if len(convo_hist_split) > 4:
                convo_hist_split = convo_hist_split[-4:]
            convo_hist = eot_token.join(convo_hist_split)

        except Exception:
            # Best effort: any decoding/rendering failure restarts the chat.
            # KeyboardInterrupt and SystemExit are no longer swallowed.
            return "<div class='chatbot'>Error occured, chatbot restarted</div>", None

    return chat_html, convo_hist
|
|
|
# Stylesheet for the HTML fragment returned by get_reply.  The container
# selector must be ``.chatbot`` because that is the class get_reply puts on
# the wrapping <div>.  It was ``.chatbox``, which matched nothing, so the
# flex layout that ``.msg.bot``'s ``align-self`` depends on never applied.
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white}
.msg.bot {background-color:lightgray;align-self:self-end}
.footer {display:none !important}
"""
|
|
|
# Build and launch the web UI.  The inputs map positionally onto get_reply's
# parameters: message, conversation state, step size, horizon length,
# iteration count and top-k.  The outputs are the rendered HTML and the
# updated conversation state.
gr.Interface(
    fn=get_reply,
    theme="default",
    inputs=[
        gr.inputs.Textbox(placeholder="How are you?"),
        "state",
        gr.inputs.Number(default=2.56, label="Step"),
        gr.inputs.Number(default=5, label="Horizon"),
        gr.inputs.Number(default=10, label="Iterations"),
        gr.inputs.Number(default=2, label="Top_k"),
    ],
    outputs=["html", "state"],
    css=css,
).launch()
|
|