# Babel / app.py
# rynmurdock's picture
# init
# c5ca37a
# raw
# history blame
# 12.3 kB
# -*- coding: utf-8 -*-
"""message_bottle.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_float32_matmul_precision('high')
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
class BottleneckT5Autoencoder:
    """Wrapper around a bottleneck T5 model that encodes text into a latent
    vector and samples text back out of a latent vector.

    Args:
        model_path: Hugging Face model id or local path. The model is loaded
            with ``trust_remote_code=True``; it accepts non-standard inputs
            such as ``encode_only`` and reads a ``perturb_vector`` attribute.
        device: device the model and every tokenized input are moved to.
    """

    def __init__(self, model_path: str, device='cuda'):
        self.device = device
        # NOTE(review): torch_dtype is handed to the tokenizer, not the model,
        # so it presumably does not change the model's precision — confirm intent.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            model_max_length=512,
            torch_dtype=torch.bfloat16,
        )
        loaded = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
        self.model = loaded.to(self.device)
        self.model.eval()  # eval mode: disables dropout and similar training-only layers
        # self.model = torch.compile(self.model)

    def embed(self, text: str) -> torch.FloatTensor:
        """Run the model with ``encode_only=True`` on ``text`` and return its output.

        The output is used throughout this file as the latent of ``text``. The
        token ids of an empty string are supplied as ``decoder_input_ids``.
        """
        source = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        blank = self.tokenizer('', return_tensors='pt').to(self.device)
        return self.model(**source, decoder_input_ids=blank['input_ids'], encode_only=True)

    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1., top_p=.8, length_penalty=10, min_new_tokens=30) -> str:
        """Sample a single text from ``latent``.

        The difference between ``latent`` and the embedding of the anchor text
        '.' is stored on ``self.model.perturb_vector``. Generation is then run
        on the anchor's token ids, presumably so the model's remote code shifts
        the anchor's latent onto ``latent`` (confirm against the model code).

        Returns:
            The first sampled sequence, decoded with special tokens removed.
        """
        anchor = '.'
        # Computed before perturb_vector is reassigned, in case the forward pass reads it.
        anchor_latent = self.embed(anchor)
        self.model.perturb_vector = latent - anchor_latent
        anchor_ids = self.tokenizer(anchor, return_tensors='pt').to(self.device).input_ids
        sampling_options = {
            'max_length': max_length,
            'do_sample': True,
            'temperature': temperature,
            'top_p': top_p,
            'num_return_sequences': 1,
            'length_penalty': length_penalty,
            'min_new_tokens': min_new_tokens,
            # 'num_beams': 8,
        }
        sequences = self.model.generate(input_ids=anchor_ids, **sampling_options)
        return self.tokenizer.decode(sequences[0], skip_special_tokens=True)


# Loaded once at import time; the module-level functions below all share this instance.
autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t5-xl-wikipedia')
import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn import preprocessing
import pandas as pd
import random
import time
# dtype the SVM preference direction is cast to before it is handed to generate().
dtype = torch.bfloat16
# Disable autograd for inference. NOTE: PyTorch grad mode is thread-local, so this
# only affects the importing thread; next_one() additionally wraps its work in
# torch.no_grad().
torch.set_grad_enabled(False)
# Unique string prompts from the CSV's second column, in file order
# (dict.fromkeys dedupes deterministically, unlike set()). Non-string cells,
# e.g. NaN from empty cells, are skipped.
prompt_list = [
    p for p in dict.fromkeys(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
start_time = time.time()
####################### Setup Model

def _scale_latent(emb, peak=.15):
    """Rescale ``emb`` so that its largest absolute entry equals ``peak``.

    An all-zero tensor is returned unchanged; dividing by its maximum would
    produce NaNs.
    """
    biggest = emb.abs().max()
    if biggest == 0:
        return emb
    return emb / biggest * peak


# TODO put back
# @spaces.GPU()
def generate(prompt, in_embs=None):
    """Generate text from a seed prompt, a latent direction, or both.

    Args:
        prompt: seed text. When non-empty it is embedded and used as the
            latent, blended as ``.9 * in_embs + .5 * prompt_embedding`` when
            ``in_embs`` is also given. When empty, ``in_embs`` is used alone.
        in_embs: optional latent direction (next_one passes the SVM weights).
            It is first rescaled so its largest absolute entry is 0.15.

    Returns:
        ``(text, emb)``: the sampled text and the embedding of that text on the
        CPU. The caller stores this embedding next to the user's rating.

    Raises:
        ValueError: if ``prompt`` is empty and ``in_embs`` is None.
    """
    device = autoencoder.device
    if in_embs is not None:
        in_embs = _scale_latent(in_embs)
    if prompt != '':
        print(prompt)
        prompt_emb = autoencoder.embed(prompt).to(device)
        in_embs = prompt_emb if in_embs is None else .9 * in_embs.to(device) + .5 * prompt_emb
    else:
        print('From embeds.')
        if in_embs is None:
            raise ValueError('generate() needs a non-empty prompt or in_embs')
    text = autoencoder.generate_from_latent(in_embs.to(device), temperature=.3, top_p=.99, min_new_tokens=5)
    # Embed the text that was actually produced, since that is what the user rates.
    # Embedding `prompt` instead would give every prompt-less roaming step the
    # same vector: the embedding of ''.
    return text, autoencoder.embed(text).to('cpu')
#######################

# TODO add to state instead of shared across all
# Step counter in a module-level global, so it is shared by every session.
# next_one() uses it to seed every third roaming step with a random prompt.
glob_idx = 0

# Length of the random padding vectors mixed into the training data. They are
# stacked with real embeddings, so this presumably must equal the autoencoder's
# latent size.
_PAD_DIM = 2048


def _pad_history(embs, ys):
    """Return new (embs, ys) lists with small random padding samples prepended.

    Padding keeps the SVM trainable:
    - one noise sample per class when the ratings hold fewer than two distinct labels;
    - three noise negatives while there are fewer than 10 samples.
    The padding goes *before* the real history so it is never taken for the
    user's most recent rating.
    """
    pad_embs, pad_ys = [], []
    # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
    if len(set(ys)) <= 1:
        pad_embs += [.01 * torch.randn(_PAD_DIM) for _ in range(2)]
        pad_ys += [0, 1]
    if len(ys) + len(pad_ys) < 10:
        # Three independent noise vectors, not one tensor repeated three times.
        pad_embs += [.01 * torch.randn(_PAD_DIM) for _ in range(3)]
        pad_ys += [0] * 3
    return pad_embs + list(embs), pad_ys + list(ys)


def _subsample_history(embs, ys):
    """Randomly drop samples to bound the history; return new lists in chronological order."""
    pos_indices = [i for i, y in enumerate(ys) if y == 1]
    neg_indices = [i for i, y in enumerate(ys) if y == 0]
    # shuffle so the slices below drop random samples rather than the oldest ones
    random.shuffle(pos_indices)
    random.shuffle(neg_indices)
    #if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
    #    pos_indices = pos_indices[32:]
    # NOTE(review): both conditions test "negatives outnumber positives by more
    # than 48/16 == 3", so the first one drops positives when negatives already
    # dominate. Kept as-is; confirm whether `len(pos) - len(neg)` was intended.
    if len(neg_indices) - len(pos_indices) > 48 / 16 and len(pos_indices) > 6:
        pos_indices = pos_indices[5:]
    if len(neg_indices) - len(pos_indices) > 48 / 16 and len(neg_indices) > 6:
        neg_indices = neg_indices[5:]
    if len(neg_indices) > 25:
        neg_indices = neg_indices[1:]
    print(len(pos_indices), len(neg_indices))
    # Restore chronological order so the tail of the lists is the newest rating.
    kept = sorted(pos_indices + neg_indices)
    return [embs[i] for i in kept], [ys[i] for i in kept]


def _latest_rating_indices(ys):
    """Return the indices of the most recent 0 and the most recent 1 in ``ys``, where present."""
    latest = {}
    for i in reversed(range(len(ys))):
        if ys[i] in (0, 1) and ys[i] not in latest:
            latest[ys[i]] = i
            if len(latest) == 2:
                break
    return list(latest.values())


def _fit_preference_direction(embs, ys):
    """Fit a linear SVM on (embs, ys) and return its weights scaled so max |w| == 3.

    The most recent 0 and the most recent 1 are included twice, giving recent
    feedback extra weight. Returns a float64 tensor shaped like ``SVC.coef_``.
    """
    rows = list(range(len(embs))) + _latest_rating_indices(ys)
    # .float() so NumPy can convert the stack even if some entries are bf16.
    feature_embs = torch.stack([embs[i].to('cpu') for i in rows]).float().numpy()
    feature_embs = preprocessing.StandardScaler().fit_transform(feature_embs)
    chosen_y = np.array([ys[i] for i in rows])
    print('Gathering coefficients')
    lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
    coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
    coef_ = coef_ / coef_.abs().max() * 3
    print(coef_.shape, 'COEF')
    print('Gathered')
    return coef_


def next_one(embs, ys, calibrate_prompts):
    """Generate the next text to show and return the updated session state.

    Calibration: while ``calibrate_prompts`` is non-empty, its first prompt is
    consumed and generated from directly.

    Roaming: afterwards, a linear SVM is fit on the rated embeddings and its
    weight vector is used as the latent to generate from. Every third step (by
    the global ``glob_idx``) that latent is blended with a random prompt from
    ``prompt_list``.

    Args:
        embs: embeddings of previously shown texts (list of tensors).
        ys: ratings aligned with ``embs`` (1 = like, 0 = dislike).
        calibrate_prompts: remaining calibration prompts.

    Returns:
        ``(text, embs, ys, calibrate_prompts)``: the new text and new state lists.
        ``embs`` ends with the embedding generate() returned for ``text``,
        which is not yet rated. The input lists are not modified in place.
    """
    global glob_idx
    glob_idx = glob_idx + 1
    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt, calibrate_prompts = calibrate_prompts[0], calibrate_prompts[1:]
            print(prompt)
            text, img_embs = generate(prompt)
            # Each row of img_embs becomes one entry
            # (assumes a single (1, D) row — TODO confirm).
            embs = [*embs, *img_embs]
            print(len(embs))
            return text, embs, ys, calibrate_prompts

        print('######### Roaming #########')
        # Ratings and embeddings can get out of step (previously attributed to a rare
        # multi-threading issue, https://github.com/huggingface/diffusers/issues/5749).
        # Trim both to their common length so indexing below stays aligned.
        if len(ys) != len(embs):
            if len(ys) > len(embs):
                print('ys are longer than embs; popping latest rating')
            else:
                print('embs are longer than ys; dropping unrated embeddings')
            common = min(len(ys), len(embs))
            embs, ys = embs[:common], ys[:common]
        embs, ys = _pad_history(embs, ys)
        embs, ys = _subsample_history(embs, ys)
        coef_ = _fit_preference_direction(embs, ys)
        prompt = '' if glob_idx % 3 != 0 else random.choice(prompt_list)
        text, im_emb = generate(prompt, coef_.to(dtype=dtype))
        embs = [*embs, *im_emb]
        return text, embs, ys, calibrate_prompts
def start(_, embs, ys, calibrate_prompts):
    """Handle the Start button: fetch the first text and enable the rating buttons.

    The first argument (the Start button's value) is ignored. Returns updates
    for [Like, Neither, Dislike, Start], then the text and the session state
    from next_one. Start is returned with interactive=False.
    """
    first_text, new_embs, new_ys, remaining_prompts = next_one(embs, ys, calibrate_prompts)
    rating_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like (L)', 'Neither (Space)', 'Dislike (A)')
    ]
    start_button = gr.Button(value='Start', interactive=False)
    return rating_buttons + [start_button, first_text, new_embs, new_ys, remaining_prompts]
def choose(text, choice, embs, ys, calibrate_prompts):
    """Record the user's rating of the current text and fetch the next one.

    Args:
        text: the text currently shown. None is treated as a dislike.
        choice: label of the clicked button. 'Like (L)' -> 1; 'Neither (Space)'
            -> no rating; anything else -> 0.
        embs: embeddings, where ``embs[-1]`` is the one generated with ``text``.
        ys: ratings recorded so far.
        calibrate_prompts: remaining calibration prompts.

    Returns:
        ``(text, embs, ys, calibrate_prompts)`` as produced by next_one.
    """
    if choice == 'Neither (Space)':
        # Not rated: drop the current text's embedding instead of recording a label.
        return next_one(embs[:-1], ys, calibrate_prompts)
    rating = 1 if choice == 'Like (L)' else 0
    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating
    if text is None:
        print('NSFW -- choice is disliked')
        rating = 0
    # Build a new list instead of `ys += ...`, so the caller's list is left
    # untouched if next_one() raises. Otherwise a rating could stay recorded
    # without a matching embedding.
    return next_one(embs, ys + [rating], calibrate_prompts)
# Page styles passed to gr.Blocks(css=...). They:
# - cap the app width at 700px;
# - center the #description header;
# - define the `fade-in-out` animation that js_head's fadeInOut() applies to a
#   clicked rating button (from --bg-color to the theme's secondary button fill).
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
0% {
background: var(--bg-color);
}
100% {
background: var(--button-secondary-background-fill);
}
}
'''
# Markup injected into the page <head> via gr.Blocks(head=...). The script:
# - adds keyboard shortcuts (A = dislike, Space = neither, L = like) by clicking
#   the elements with those ids, i.e. the buttons' elem_ids;
# - on click, flashes a colour on a rating button via the `fade-in-out` class
#   defined in css.
# NOTE(review): the keydown handler never calls preventDefault(), so Space may
# also scroll the page or activate whichever button currently has focus — confirm.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
if (event.key === 'a' || event.key === 'A') {
// Trigger click on 'dislike' if 'A' is pressed
document.getElementById('dislike').click();
} else if (event.key === ' ' || event.keyCode === 32) {
// Trigger click on 'neither' if Spacebar is pressed
document.getElementById('neither').click();
} else if (event.key === 'l' || event.key === 'L') {
// Trigger click on 'like' if 'L' is pressed
document.getElementById('like').click();
}
});
function fadeInOut(button, color) {
button.style.setProperty('--bg-color', color);
button.classList.remove('fade-in-out');
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
button.classList.add('fade-in-out');
button.addEventListener('animationend', () => {
button.classList.remove('fade-in-out'); // Reset the animation state
}, {once: true});
}
document.body.addEventListener('click', function(event) {
const target = event.target;
if (target.id === 'dislike') {
fadeInOut(target, '#ff1717');
} else if (target.id === 'like') {
fadeInOut(target, '#006500');
} else if (target.id === 'neither') {
fadeInOut(target, '#cccccc');
}
});
</script>
'''
# Gradio UI layout:
# - a text box showing the current generation;
# - three rating buttons, disabled until Start is clicked, all handled by choose();
# - a Start button handled by start().
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Compass
### Generative Recommenders for Exploration of Text
Explore the latent space without prompting based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")

    # Session state: embeddings of shown texts, their ratings, and the
    # calibration prompts still to be shown.
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])

    with gr.Row(elem_id='output-image'):
        text = gr.Textbox(interactive=False, elem_id="text")
    with gr.Row(equal_height=True):
        # The elem_ids are the targets of js_head's keyboard shortcuts and click flashes.
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
        # Each button passes its own label as `choice`; choose() maps it to a rating.
        for rating_button in (b1, b2, b3):
            rating_button.click(
                choose,
                [text, rating_button, embs, ys, calibrate_prompts],
                [text, embs, ys, calibrate_prompts],
            )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, text, embs, ys, calibrate_prompts])
    with gr.Row():
        # `</div>` (not `</ div>`): HTML parses `</ div>` as a bogus comment,
        # which would leave each div unclosed.
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the model is unlikely to produce NSFW text, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
</div>''')

demo.launch(share=True)