Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
"""message_bottle.ipynb | |
Automatically generated by Colab. | |
Original file is located at | |
https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
torch.set_float32_matmul_precision('high') | |
from tqdm import tqdm | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
class BottleneckT5Autoencoder:
    """Text autoencoder wrapper: encode text to a latent and sample text back.

    Holds a Hugging Face tokenizer and a causal-LM model loaded with
    ``trust_remote_code=True``; the model is moved to ``device`` and put in
    eval mode.
    """

    def __init__(self, model_path: str, device='cuda'):
        """Load tokenizer and model from ``model_path`` onto ``device``."""
        self.device = device
        # NOTE(review): `torch_dtype` is passed to the tokenizer loader here,
        # not to the model loader -- confirm whether that was intended.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            model_max_length=512,
            torch_dtype=torch.bfloat16,
        )
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        self.model = loaded_model.to(self.device)
        self.model.eval()
        # self.model = torch.compile(self.model)

    def embed(self, text: str) -> torch.FloatTensor:
        """Return the model's output for ``text`` when called with ``encode_only=True``.

        The decoder is fed the tokenization of the empty string.
        """
        encoder_batch = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        empty_decoder_batch = self.tokenizer('', return_tensors='pt').to(self.device)
        return self.model(
            **encoder_batch,
            decoder_input_ids=empty_decoder_batch['input_ids'],
            encode_only=True,
        )

    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1., top_p=.8, length_penalty=10, min_new_tokens=30) -> str:
        """Sample text for ``latent`` and return it decoded without special tokens.

        The difference between ``latent`` and the embedding of the seed text
        ``'.'`` is stored on the model as ``perturb_vector`` before sampling
        (presumably read by the remote model code during generation -- confirm).
        """
        seed_text = '.'
        seed_latent = self.embed(seed_text)
        self.model.perturb_vector = latent - seed_latent
        seed_ids = self.tokenizer(seed_text, return_tensors='pt').to(self.device).input_ids
        sampling_options = dict(
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            length_penalty=length_penalty,
            min_new_tokens=min_new_tokens,
            # num_beams=8,
        )
        sequences = self.model.generate(input_ids=seed_ids, **sampling_options)
        return self.tokenizer.decode(sequences[0], skip_special_tokens=True)
# Single module-level autoencoder instance used by the generation helpers below.
autoencoder = BottleneckT5Autoencoder(
    model_path='thesephist/contra-bottleneck-t5-xl-wikipedia',
)
import gradio as gr | |
import numpy as np | |
from sklearn.svm import SVC | |
from sklearn.inspection import permutation_importance | |
from sklearn import preprocessing | |
import pandas as pd | |
import random | |
import time | |
# dtype the SVM direction is cast to before it is handed to the generator.
dtype = torch.bfloat16
# Inference only: disable autograd globally.
torch.set_grad_enabled(False)
# De-duplicated string entries from the second column of the prompts CSV;
# non-string cells (e.g. NaN for blanks) are dropped.
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
start_time = time.time()
####################### Setup Model | |
# TODO put back | |
# @spaces.GPU() | |
def generate(prompt, in_embs=None,):
    """Generate text from a prompt, a latent direction, or a blend of both.

    Args:
        prompt: Text to encode. ``''`` means "use ``in_embs`` only".
        in_embs: Optional latent tensor. When given it is rescaled so its
            largest absolute component is 0.15; with a non-empty prompt it is
            blended as ``.9 * in_embs + .5 * embed(prompt)``.

    Returns:
        ``(text, embedding)`` where ``text`` is the sampled string and
        ``embedding`` is ``autoencoder.embed(text)`` moved to the CPU.

    Raises:
        ValueError: if ``prompt`` is empty and ``in_embs`` is ``None``.
    """
    device = autoencoder.device

    if prompt == '' and in_embs is None:
        raise ValueError('generate() needs a non-empty prompt or in_embs')

    if in_embs is not None:
        in_embs = in_embs / in_embs.abs().max() * .15

    if prompt != '':
        print(prompt)
        prompt_embs = autoencoder.embed(prompt).to(device)
        if in_embs is None:
            in_embs = prompt_embs
        else:
            in_embs = .9 * in_embs.to(device) + .5 * prompt_embs
    else:
        print('From embeds.')

    text = autoencoder.generate_from_latent(in_embs.to(device), temperature=.3, top_p=.99, min_new_tokens=5)
    # Embed what was actually shown to the user. Embedding `prompt` here gave
    # the same embed('') vector for every prompt-less sample.
    out_embs = autoencoder.embed(text)
    return text, out_embs.to('cpu')
####################### | |
# TODO add to state instead of shared across all
glob_idx = 0


def next_one(embs, ys, calibrate_prompts):
    """Produce the next text sample and record its embedding.

    While ``calibrate_prompts`` is non-empty, the first prompt is popped and
    generated directly. Afterwards ("roaming") a linear SVM is fit on the
    rated embeddings and its normalised coefficient vector is used as the
    latent for generation; every third call (by the global counter) also mixes
    in a random prompt from ``prompt_list``.

    Args:
        embs: list of embedding tensors, one per sample shown so far.
        ys: list of ratings (1 = like, 0 = dislike), index-aligned with ``embs``.
        calibrate_prompts: list of remaining calibration prompt strings.

    Returns:
        ``(text, embs, ys, calibrate_prompts)``; the lists may be the same
        objects mutated in place or pruned copies.

    Side effects: increments the module-level ``glob_idx``.
    """
    global glob_idx
    glob_idx = glob_idx + 1

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            text, img_embs = generate(prompt)
            # list += tensor appends one entry per row along dim 0
            # (assumes a leading batch dim of 1 -- TODO confirm).
            embs += img_embs
            print(len(embs))
            return text, embs, ys, calibrate_prompts

        print('######### Roaming #########')

        # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
        # this ends up adding a rating but losing an embedding, it seems.
        # Reconcile lengths BEFORE indexing so the two lists stay aligned.
        if len(ys) > len(embs):
            print('ys are longer than embs; popping latest rating')
            del ys[len(embs):]
        elif len(embs) > len(ys):
            print('embs are longer than ys; dropping unrated embeddings')
            embs = embs[:len(ys)]

        # Match the noise size to the stored embeddings; 2048 is the fallback.
        noise_dim = embs[0].shape[-1] if embs else 2048

        # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
        if len(set(ys)) <= 1:
            embs.append(.01*torch.randn(noise_dim))
            embs.append(.01*torch.randn(noise_dim))
            ys.append(0)
            ys.append(1)
        if len(ys) < 10:
            # NOTE: list multiplication repeats the SAME noise tensor three times.
            embs += [.01*torch.randn(noise_dim)] * 3
            ys += [0] * 3

        pos_indices = [i for i, y in enumerate(ys) if y == 1]
        neg_indices = [i for i, y in enumerate(ys) if y == 0]

        # the embs & ys stay tied by index but we shuffle to drop randomly
        random.shuffle(pos_indices)
        random.shuffle(neg_indices)

        # Trim whichever class dominates. The positive branch previously
        # tested `neg - pos`, dropping positives when negatives dominated.
        if len(pos_indices) - len(neg_indices) > 48/16 and len(pos_indices) > 6:
            pos_indices = pos_indices[5:]
        if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 6:
            neg_indices = neg_indices[5:]
        if len(neg_indices) > 25:
            neg_indices = neg_indices[1:]
        print(len(pos_indices), len(neg_indices))

        kept = pos_indices + neg_indices
        embs = [embs[i] for i in kept]
        ys = [ys[i] for i in kept]
        indices = list(range(len(embs)))

        # also add the latest 0 and the latest 1 (duplicated rows up-weight them)
        has_0 = False
        has_1 = False
        for i in reversed(range(len(ys))):
            if ys[i] == 0 and not has_0:
                indices.append(i)
                has_0 = True
            elif ys[i] == 1 and not has_1:
                indices.append(i)
                has_1 = True
            if has_0 and has_1:
                break

        feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices]))
        scaler = preprocessing.StandardScaler().fit(feature_embs)
        feature_embs = scaler.transform(feature_embs)
        chosen_y = np.array([ys[i] for i in indices])

        print('Gathering coefficients')
        lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
        coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
        coef_ = coef_ / coef_.abs().max() * 3
        print(coef_.shape, 'COEF')
        print('Gathered')

        im_emb = coef_.to(dtype=dtype)
        # Every third call blends in a random prompt; otherwise roam on the latent alone.
        prompt = random.choice(prompt_list) if glob_idx % 3 == 0 else ''
        text, im_emb = generate(prompt, im_emb)
        embs += im_emb
        return text, embs, ys, calibrate_prompts
def start(_, embs, ys, calibrate_prompts):
    """Handle the Start button: fetch the first sample and unlock the rating buttons.

    The first argument (the clicked button's value) is ignored. Returns a list
    of updates in the order: Like, Neither, Dislike buttons (enabled), Start
    button (disabled), then text, embs, ys, calibrate_prompts.
    """
    text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    rating_labels = ('Like (L)', 'Neither (Space)', 'Dislike (A)')
    button_updates = [gr.Button(value=label, interactive=True) for label in rating_labels]
    button_updates.append(gr.Button(value='Start', interactive=False))
    return button_updates + [text, embs, ys, calibrate_prompts]
def choose(text, choice, embs, ys, calibrate_prompts):
    """Record the user's rating of the current sample and fetch the next one.

    ``choice`` is the clicked button's label. 'Neither (Space)' discards the
    latest embedding without recording a rating; 'Like (L)' records 1 and any
    other label records 0. Returns ``(text, embs, ys, calibrate_prompts)``
    from ``next_one``.
    """
    if choice == 'Neither (Space)':
        # Unrated sample: drop its embedding so embs and ys stay aligned.
        embs = embs[:-1]
        text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
        return text, embs, ys, calibrate_prompts

    rating = 1 if choice == 'Like (L)' else 0

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating
    if text is None:
        print('NSFW -- choice is disliked')
        rating = 0

    ys.append(rating)
    text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    return text, embs, ys, calibrate_prompts
# Page styling: narrow container, centred description, and the fade animation
# that the script in `js_head` triggers on the rating buttons.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% {
        background: var(--bg-color);
    }
    100% {
        background: var(--button-secondary-background-fill);
    }
}
'''

# Script injected into the page head: keyboard shortcuts (A / Space / L) for
# the rating buttons plus a colour flash on click. Both listeners are bound to
# `document`, which exists while the head is parsed; `document.body` may still
# be null at that point, and click events bubble to `document` anyway.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
    if (event.key === 'a' || event.key === 'A') {
        // Trigger click on 'dislike' if 'A' is pressed
        document.getElementById('dislike').click();
    } else if (event.key === ' ' || event.keyCode === 32) {
        // Trigger click on 'neither' if Spacebar is pressed
        document.getElementById('neither').click();
    } else if (event.key === 'l' || event.key === 'L') {
        // Trigger click on 'like' if 'L' is pressed
        document.getElementById('like').click();
    }
});
function fadeInOut(button, color) {
    button.style.setProperty('--bg-color', color);
    button.classList.remove('fade-in-out');
    void button.offsetWidth; // This line forces a repaint by accessing a DOM property
    button.classList.add('fade-in-out');
    button.addEventListener('animationend', () => {
        button.classList.remove('fade-in-out'); // Reset the animation state
    }, {once: true});
}
document.addEventListener('click', function(event) {
    const target = event.target;
    if (target.id === 'dislike') {
        fadeInOut(target, '#ff1717');
    } else if (target.id === 'like') {
        fadeInOut(target, '#006500');
    } else if (target.id === 'neither') {
        fadeInOut(target, '#cccccc');
    }
});
</script>
'''
# Build the UI: description, per-session state, output textbox, three rating
# buttons (disabled until Start is pressed), the Start button and a footer.
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Compass
### Generative Recommenders for Exploration of Text
Explore the latent space without prompting based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")

    # Per-session state: embeddings shown so far, their ratings, and the
    # calibration prompts still to be shown.
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])

    # Not referenced anywhere in this section; kept so the module namespace is unchanged.
    def l():
        return None

    with gr.Row(elem_id='output-image'):
        text = gr.Textbox(interactive=False, elem_id="text")

    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")

    # All three rating buttons share one handler; each passes its own label
    # as `choice`.
    for rating_button in (b1, b2, b3):
        rating_button.click(
            choose,
            [text, rating_button, embs, ys, calibrate_prompts],
            [text, embs, ys, calibrate_prompts]
        )

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, text, embs, ys, calibrate_prompts])

    with gr.Row():
        # Closing tags are written `</div>`; `</ div>` is not a valid end tag
        # and left the divs unclosed.
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the model is unlikely to produce NSFW text, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
</div>''')

demo.launch(share=True)