# -*- coding: utf-8 -*-
"""message_bottle.ipynb

Automatically generated by Colab.
"""

import argparse
import glob
import logging
import os
import pickle
import random
import sys
from types import SimpleNamespace

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

# These paths must be on sys.path before the imports below.
# NOTE(review): presumably `pytorch_transformers` and `modules` are the copies
# shipped inside the Optimus checkout -- confirm.
sys.path.append('./Optimus/code/examples/big_ae/')
sys.path.append('./Optimus/code/')

from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer
from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from pytorch_transformers import BertForLatentConnector, BertTokenizer
from modules import VAE

DEVICE = 'cpu'

torch.set_float32_matmul_precision('high')


################################################


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The tensor is modified in place and the same object is returned.

    Args:
        logits: logits distribution, shape (vocabulary size,).
        top_k: if > 0, keep only the top k tokens with highest probability
            (top-k filtering).
        top_p: if > 0.0, keep the top tokens with cumulative probability
            >= top_p (nucleus filtering). Nucleus filtering is described in
            Holtzman et al. (http://arxiv.org/abs/1904.09751).
        filter_value: value written over every filtered-out logit.

    Returns:
        ``logits`` with the filtered entries set to ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    # batch size 1 for now - could be updated for more but the code would be less clear
    assert logits.dim() == 1
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
    """Sample tokens from ``model`` one at a time until ``<EOS>`` is drawn.

    Args:
        model: decoder, called as ``model(input_ids=..., past=...)``; the first
            element of its output is indexed as ``[batch, position, vocab]``.
        length: maximum number of tokens to generate. ``None`` removes the cap
            (the loop then only stops on ``<EOS>``).
        context: list of token ids used as the start of the sequence.
        past: forwarded unchanged to ``model`` on every step.
        num_samples: how many times ``context`` is repeated along the batch
            axis. Only the first row's logits are sampled and a single token
            is appended per step, so values other than 1 are not supported.
        temperature: divisor applied to the next-token logits.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        device: device on which the context tensor is created.
        decoder_tokenizer: tokenizer used to look up the ``<EOS>`` id; required
            despite the ``None`` default.

    Returns:
        Tensor of token ids: the context followed by the sampled tokens,
        including the final ``<EOS>`` when one was drawn.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context

    # Loop-invariant: resolve the stop token once. The literal was an empty
    # string, for which encode() returns no ids and `[0]` raised IndexError.
    eos_token_id = decoder_tokenizer.encode('<EOS>')[0]

    num_generated = 0
    with torch.no_grad():
        # `length` used to be ignored; an unlucky sample could loop forever.
        while length is None or num_generated < length:
            inputs = {'input_ids': generated, 'past': past}
            # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            outputs = model(**inputs)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            num_generated += 1

            if next_token.item() == eos_token_id:
                break

    return generated


def latent_code_from_text(text,):  # args):
    """Encode ``text`` into a latent vector with the VAE encoder.

    Relies on the module-level ``tokenizer_encoder`` and ``model_vae``.

    Args:
        text: input string.

    Returns:
        ``(latent_z, coded_length)``: the mean half of the encoder's linear
        head output (the log-variance half is discarded, so the code is
        deterministic), and the number of token ids fed to the encoder
        including the two boundary ids.
    """
    tokenized1 = tokenizer_encoder.encode(text)
    # 101 / 102 are the ids BERT vocabularies conventionally use for
    # [CLS] / [SEP].
    tokenized1 = [101] + tokenized1 + [102]
    # Build the ids directly as int64 instead of going through float32.
    coded1 = torch.tensor([tokenized1], dtype=torch.long)
    with torch.no_grad():
        x0 = coded1.to(DEVICE)
        # Every non-zero id is attended to.
        pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
        mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
        latent_z = mean.squeeze(1)
        coded_length = len(tokenized1)
    return latent_z, coded_length


# args
def text_from_latent_code(latent_z):
    """Decode a latent vector back into text with the VAE decoder.

    Relies on the module-level ``tokenizer_decoder`` and ``model_vae``.

    Args:
        latent_z: latent code, passed to the decoder as ``past``.

    Returns:
        The decoded text with the leading ``<BOS>`` piece, and the trailing
        ``<EOS>`` piece when generation ended on it, removed; whitespace is
        normalised to single spaces.
    """
    past = latent_z
    context_tokens = tokenizer_decoder.encode('<BOS>')

    length = 128  # maximum number of generated tokens
    out = sample_sequence_conditional(
        model=model_vae.decoder,
        context=context_tokens,
        past=past,
        length=length,  # Chunyuan: Fix length; or use to complete a sentence
        temperature=.5,
        top_k=100,
        top_p=.98,
        device=DEVICE,
        decoder_tokenizer=tokenizer_decoder
    )
    token_ids = out[0, :].tolist()
    text_x1 = tokenizer_decoder.decode(token_ids, clean_up_tokenization_spaces=True)
    words = text_x1.split()
    # Drop the <BOS> piece always; drop the last piece only when it really is
    # <EOS>, so hitting the length cap does not eat a real word.
    ended_on_eos = token_ids[-1] == tokenizer_decoder.encode('<EOS>')[0]
    words = words[1:-1] if ended_on_eos else words[1:]
    text_x1 = ' '.join(words)
    return text_x1


################################################
# Load model

MODEL_CLASSES = {
    'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
    'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
}

latent_size = 768
model_path = './checkpoint-31250/checkpoint-full-31250/'
encoder_path = './checkpoint-31250/checkpoint-encoder-31250/'
decoder_path = './checkpoint-31250/checkpoint-decoder-31250/'
block_size = 100

# Load a trained Encoder model and vocabulary that you have fine-tuned
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES['bert']
model_encoder = encoder_model_class.from_pretrained(encoder_path, latent_size=latent_size)
tokenizer_encoder = encoder_tokenizer_class.from_pretrained('bert-base-cased', do_lower_case=True)

model_encoder.to(DEVICE)
if block_size <= 0:
    # Our input block size will be the max possible for the model
    block_size = tokenizer_encoder.max_len_single_sentence
block_size = min(block_size, tokenizer_encoder.max_len_single_sentence)

# Load a trained Decoder model and vocabulary that you have fine-tuned
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES['gpt2']
model_decoder = decoder_model_class.from_pretrained(decoder_path, latent_size=latent_size)
tokenizer_decoder = decoder_tokenizer_class.from_pretrained('gpt2', do_lower_case=False)
model_decoder.to(DEVICE)
if block_size <= 0:
    # Our input block size will be the max possible for the model
    block_size = tokenizer_decoder.max_len_single_sentence
block_size = min(block_size, tokenizer_decoder.max_len_single_sentence)

# Load full model
# NOTE(review): output_full_dir is not used by the code visible in this file.
output_full_dir = '/home/ryn_mote/Misc/generative_recommender/text_space/'
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(os.path.join(model_path, 'training.bin'), map_location=torch.device(DEVICE))

# Chunyuan: Add Padding token to GPT2
# The three literals were empty strings; restored to the markers the Optimus
# reference scripts register.
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens to GPT2')
# Notice: resize_token_embeddings expect to receive the full size of the new
# vocabulary, i.e. the length of the tokenizer.
model_decoder.resize_token_embeddings(len(tokenizer_decoder))
# Compare against the value that was registered rather than a repeated
# literal, so this check cannot drift from the special-token table.
assert tokenizer_decoder.pad_token == special_tokens_dict['pad_token']

# Evaluation
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder,
                SimpleNamespace(**{'latent_size': latent_size, 'device': DEVICE}))
model_vae.load_state_dict(checkpoint['model_state_dict'])
print("Pre-trained Optimus is successfully loaded")
model_vae.to(DEVICE).to(torch.bfloat16)
model_vae = torch.compile(model_vae)
model_vae.encoder = torch.compile(model_vae.encoder)
model_vae.decoder = torch.compile(model_vae.decoder)

# Smoke test: encode one sentence and decode it again.
l = latent_code_from_text('A photo of a mountain.')[0]
t = text_from_latent_code(l)
print(t, l, l.shape)

################################################

import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn import preprocessing
import pandas as pd

import random
import time

dtype = torch.bfloat16
torch.set_grad_enabled(False)

# Unique string entries of the CSV's second column.
prompt_list = [p for p in list(set(
    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if isinstance(p, str)]

start_time = time.time()

####################### Setup Model


# TODO put back
# @spaces.GPU()
def generate(prompt, in_embs=None,):
    """Produce a text and its re-encoded latent from a prompt and/or a latent.

    Args:
        prompt: text to encode. An empty string means "decode ``in_embs``
            only".
        in_embs: optional latent tensor. When given it is rescaled so its
            largest absolute entry is 0.6; with a non-empty ``prompt`` it is
            then added to the prompt's latent.

    Returns:
        ``(text, latent)``: the decoded text and the latent obtained by
        encoding that text again, moved to the CPU.

    Raises:
        ValueError: if ``prompt`` is empty and ``in_embs`` is ``None``.
    """
    if prompt != '':
        print(prompt)
        prompt_emb = latent_code_from_text(prompt)[0]
        # `is not None`: comparing a tensor to None with != is not an identity test.
        if in_embs is not None:
            in_embs = in_embs / in_embs.abs().max() * .6
            in_embs = 1 * in_embs.to(DEVICE) + 1 * prompt_emb
        else:
            in_embs = prompt_emb
    else:
        print('From embeds.')
        if in_embs is None:
            raise ValueError('generate() needs a non-empty prompt or in_embs')
        in_embs = in_embs / in_embs.abs().max() * .6
        in_embs = in_embs.to(DEVICE).to(torch.bfloat16)

    # Debug aid: histogram of the latent that is about to be decoded.
    plt.close('all')
    plt.hist(np.array(in_embs.detach().to('cpu').to(torch.float)).flatten(), bins=5)
    plt.savefig('real_im_emb_plot.jpg')

    # NOTE(review): the search string of replace() is empty, which makes the
    # call a no-op; presumably it once named a special token -- confirm.
    text = ' '.join(text_from_latent_code(in_embs).replace('', '').split())
    in_embs = latent_code_from_text(text)[0]
    print(text)
    return text, in_embs.to('cpu')


#######################

# TODO add to state instead of shared across all
glob_idx = 0


def next_one(embs, ys, calibrate_prompts):
    """Generate the next text to rate.

    While ``calibrate_prompts`` is non-empty, its first prompt is consumed and
    encoded. Afterwards ("roaming") a linear SVM is fitted on the rated
    latents and its coefficient vector is decoded, mixed with a random prompt
    from ``prompt_list`` on every third call.

    Args:
        embs: list of latent tensors generated so far.
        ys: list of ratings (1 = like, 0 = dislike), index-aligned with
            ``embs``.
        calibrate_prompts: list of prompts still to be shown for calibration.

    Returns:
        ``(text, embs, ys, calibrate_prompts)`` with the new latent appended
        to ``embs``.
    """
    global glob_idx
    glob_idx = glob_idx + 1

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            text, img_embs = generate(prompt)
            # Iterating the returned tensor appends each of its rows.
            embs.extend(img_embs)
            print(len(embs))
            return text, embs, ys, calibrate_prompts
        else:
            print('######### Roaming #########')

            # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
            # this ends up adding a rating but losing an embedding, it seems.
            # let's take off a rating if so to continue without indexing errors.
            # This check used to run after both lists were rebuilt from the
            # same indices, where it could never fire.
            while len(ys) > len(embs):
                print('ys are longer than embs; popping latest rating')
                ys.pop(-1)

            # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
            if len(list(set(ys))) <= 1:
                embs.append(.01*torch.randn(latent_size))
                embs.append(.01*torch.randn(latent_size))
                ys.append(0)
                ys.append(1)
            if len(list(ys)) < 10:
                # Three independent draws; `[tensor] * 3` repeated one tensor.
                embs += [.01*torch.randn(latent_size) for _ in range(3)]
                ys += [0] * 3

            pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
            neg_indices = [i for i in range(len(embs)) if ys[i] == 0]

            # the embs & ys stay tied by index but we shuffle to drop randomly
            random.shuffle(pos_indices)
            random.shuffle(neg_indices)

            if len(neg_indices) > 25:
                neg_indices = neg_indices[1:]

            print(len(pos_indices), len(neg_indices))
            indices = pos_indices + neg_indices

            embs = [embs[i] for i in indices]
            ys = [ys[i] for i in indices]
            indices = list(range(len(embs)))

            # Cast to float32 before leaving torch: NumPy has no bfloat16, so
            # converting a stack of encoder outputs alone would fail.
            feature_embs = torch.stack(
                [embs[i].to('cpu').to(torch.float32) for i in indices]).numpy()
            scaler = preprocessing.StandardScaler().fit(feature_embs)
            feature_embs = scaler.transform(feature_embs)
            chosen_y = np.array([ys[i] for i in indices])

            print('Gathering coefficients')
            lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
            coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
            print(coef_.shape, 'COEF')
            print('Gathered')

            rng_prompt = random.choice(prompt_list)
            w = 1  # if len(embs) % 2 == 0 else 0
            im_emb = w * coef_.to(dtype=dtype)

            # Every third call mixes in a random prompt; otherwise decode the
            # SVM direction on its own.
            prompt = '' if glob_idx % 3 != 0 else rng_prompt
            text, im_emb = generate(prompt, im_emb)
            embs.extend(im_emb)

            return text, embs, ys, calibrate_prompts


def start(_, embs, ys, calibrate_prompts):
    """Handle the Start button: produce the first text and enable rating.

    Returns:
        Updated Like / Neither / Dislike / Start buttons followed by
        ``text, embs, ys, calibrate_prompts``.
    """
    text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True),
        gr.Button(value='Dislike (A)', interactive=True),
        gr.Button(value='Start', interactive=False),
        text,
        embs,
        ys,
        calibrate_prompts
    ]


def choose(text, choice, embs, ys, calibrate_prompts):
    """Record the rating for the current text and produce the next one.

    Args:
        text: the text that was rated.
        choice: label of the clicked button.
        embs, ys, calibrate_prompts: session state, see ``next_one``.

    Returns:
        ``(text, embs, ys, calibrate_prompts)`` for the next round.
    """
    if choice == 'Like (L)':
        choice = 1
    elif choice == 'Neither (Space)':
        # No rating is stored, so drop the latent that was just shown.
        embs = embs[:-1]
        text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
        return text, embs, ys, calibrate_prompts
    else:
        choice = 0

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating
    if text is None:
        print('NSFW -- choice is disliked')
        choice = 0

    ys += [choice]*1
    text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    return text, embs, ys, calibrate_prompts


css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% {
      background: var(--bg-color);
    }
    100% {
      background: var(--button-secondary-background-fill);
    }
}
'''
# NOTE(review): this head snippet is blank here; presumably it once held the
# keyboard-shortcut script implied by the button labels -- confirm.
js_head = ''' '''

with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Compass
### Generative Recommenders for Exporation of Text

Explore the latent space without prompting based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
    ''', elem_id="description")
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])

    def l():
        return None

    with gr.Row(elem_id='output-image'):
        text = gr.Textbox(interactive=False, elem_id="text")
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
        # Each button passes itself as `choice`, so choose() receives its label.
        b1.click(
            choose,
            [text, b1, embs, ys, calibrate_prompts],
            [text, embs, ys, calibrate_prompts]
        )
        b2.click(
            choose,
            [text, b2, embs, ys, calibrate_prompts],
            [text, embs, ys, calibrate_prompts]
        )
        b3.click(
            choose,
            [text, b3, embs, ys, calibrate_prompts],
            [text, embs, ys, calibrate_prompts]
        )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, text, embs, ys, calibrate_prompts])
    with gr.Row():
        html = gr.HTML(''' 
You will calibrate for several prompts and then roam.


Note that while the model is unlikely to produce NSFW text, this may still occur, and users should avoid NSFW content when rating.

Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback. ''')

demo.launch(share=True)