# multimodalart's picture
# Performance PR
# f33c43f verified
# Device that returned image embeddings are parked on between requests
# (predict() moves them back to CUDA before generating).
DEVICE = "cpu"
import gradio as gr
import numpy as np
from sklearn.svm import LinearSVC
from sklearn import preprocessing
import pandas as pd
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel
from diffusers.models import ImageProjection
from patch_sdxl import SDEmb
import torch
import spaces
import random
import time
import torch
from urllib.request import urlopen
from PIL import Image
import requests
from io import BytesIO, StringIO
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Deduplicated pool of text prompts taken from the second CSV column; cells
# that are not strings (e.g. missing values) are skipped. The roaming step
# draws a random prompt from this pool.
prompt_list = [
    p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]
# Import-time timestamp; not referenced elsewhere in the visible code.
start_time = time.time()
####################### Setup Model
# SDXL base pipeline with the SDXL-Lightning 2-step UNet swapped in.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"
# Build the UNet architecture from the base model's config (no weights yet),
# then load the Lightning checkpoint's weights into it.
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
# SDEmb is the project's custom pipeline (imported from patch_sdxl); it is
# loaded in fp16 with the Lightning UNet in place of the base one.
pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
# Swap in the tiny TAESD-XL autoencoder (presumably for cheaper decoding).
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
# Euler scheduler with "trailing" timestep spacing, presumably following the
# SDXL-Lightning usage instructions for few-step sampling.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# Move the pipeline again: the VAE assigned above was created without a device.
pipe.to(device='cuda')
# IP-Adapter weights; predict() feeds embeddings to the pipeline via ip_adapter_emb.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
# Passed to pipe.encode_image(); False presumably selects the pooled image
# embedding rather than hidden states — TODO confirm against diffusers.
output_hidden_state = False
#######################
@spaces.GPU
def predict(
    prompt,
    im_emb=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Render one image for ``prompt`` and return it with its image embedding.

    Args:
        prompt: Text prompt passed to the pipeline.
        im_emb: Embedding tensor passed to the pipeline as ``ip_adapter_emb``
            (moved to CUDA first). When None, an all-zero fp16 tensor of
            shape (1, 1280) is used, i.e. no preference direction yet.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror the pipeline's tqdm progress bars. Not used directly.

    Returns:
        ``(image, im_emb)``: the first image of the pipeline output and the
        embedding ``pipe.encode_image`` computes for it, moved to ``DEVICE``.
    """
    with torch.no_grad():
        # Identity check: ``==`` on a tensor is an element-wise comparison,
        # not a None test.
        if im_emb is None:
            im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')
        image = pipe(
            prompt=prompt,
            ip_adapter_emb=im_emb.to('cuda'),
            height=1024,
            width=1024,
            # Lightning 2-step UNet: two steps and no classifier-free guidance.
            num_inference_steps=2,
            guidance_scale=0,
        ).images[0]
        # Embed the generated image so it can later be rated and used as an
        # SVM training example.
        im_emb, _ = pipe.encode_image(
            image, 'cuda', 1, output_hidden_state
        )
    return image, im_emb.to(DEVICE)
# TODO add to state instead of shared across all
# Call counter shared by every session (module global). Its parity decides
# whether a roaming step uses the generic prompt or a random one.
glob_idx = 0


def _preference_direction(embs, ys):
    """Fit a linear SVM on class-balanced rated embeddings; return its unit normal.

    Liked (``ys[i] == 1``) and disliked (``ys[i] == 0``) examples are randomly
    subsampled to the size of the smaller class, standardised with a
    ``StandardScaler`` and fed to ``LinearSVC``. The SVM weight vector is
    returned L2-normalised as a float64 tensor of shape (1, n_features).

    Args:
        embs: Embedding tensors aligned index-for-index with ``ys``; only
            ``embs[i][0]`` is used (assumed shape (1, 1280) — TODO confirm).
        ys: Labels, 1 for liked and 0 for disliked. Both classes must be
            present, otherwise the SVM has nothing to separate.
    """
    pos_indices = [i for i, y in enumerate(ys) if y == 1]
    neg_indices = [i for i, y in enumerate(ys) if y == 0]
    # Sample only as many examples of each class as the smaller class has.
    lower = min(len(pos_indices), len(neg_indices))
    neg_indices = random.sample(neg_indices, lower)
    pos_indices = random.sample(pos_indices, lower)
    cut_indices = neg_indices + pos_indices
    cut_ys = [ys[i] for i in cut_indices]
    # Cast to float32 so every row has the same dtype. Placeholder embeddings
    # are fp32, while generated ones presumably come back in the pipeline's
    # fp16, which also loses precision inside the scaler.
    feature_embs = torch.stack(
        [embs[i][0].detach().cpu().float() for i in cut_indices]
    )
    scaler = preprocessing.StandardScaler().fit(feature_embs)
    feature_embs = scaler.transform(feature_embs)
    print(np.array(feature_embs).shape, np.array(cut_ys).shape)
    lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(
        np.array(feature_embs), np.array(cut_ys)
    )
    # Work on a copy instead of overwriting the fitted estimator's coef_.
    coef = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
    # Guard against an all-zero weight vector, which would otherwise give NaN.
    coef = coef / coef.norm().clamp_min(1e-12)
    return coef.unsqueeze(0)


def next_image(embs, ys, calibrate_prompts):
    """Render the next image to rate and append its embedding to ``embs``.

    While calibration prompts remain, the next one is popped and rendered
    with ``predict``'s default (all-zero) embedding. Afterwards ("roaming"),
    a linear SVM is fit on the rated embeddings. Its normalised weight vector
    is used as the image embedding for the next generation. The prompt is
    'an image' or a random entry of ``prompt_list``, depending on the parity
    of the shared ``glob_idx`` counter.

    Args:
        embs: Embeddings of shown images; mutated in place.
        ys: Labels for the rated images (1 = like, 0 = dislike); mutated in
            place.
        calibrate_prompts: Remaining calibration prompts; mutated in place.

    Returns:
        ``(image, embs, ys, calibrate_prompts)``.
    """
    global glob_idx
    glob_idx = glob_idx + 1
    # Once calibration is over, the SVM needs both classes. If every rating
    # so far was identical (or there were none), add one tiny random
    # negative and one tiny random positive example.
    if len(calibrate_prompts) == 0 and len(set(ys)) <= 1:
        embs.append(.01 * torch.randn(1, 1280))
        embs.append(.01 * torch.randn(1, 1280))
        ys.append(0)
        ys.append(1)
    with torch.no_grad():
        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_emb = predict(prompt)
            embs.append(img_emb)
            return image, embs, ys, calibrate_prompts

        print('######### Roaming #########')
        direction = _preference_direction(embs, ys)
        rng_prompt = random.choice(prompt_list)
        im_emb = direction.to(device=DEVICE, dtype=torch.float16)
        prompt = 'an image' if glob_idx % 2 == 0 else rng_prompt
        print(prompt)
        image, im_emb = predict(prompt, im_emb)
        embs.append(im_emb)
        return image, embs, ys, calibrate_prompts
def start(_, embs, ys, calibrate_prompts):
    """Handle the Start button: render the first image and unlock voting.

    The first argument (the Start button's own value) is ignored. The return
    list follows the order of the click handler's outputs:
    1. the Like, Neither and Dislike buttons, made interactive;
    2. the Start button, disabled;
    3. the image and the updated session state.
    """
    image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
    vote_buttons = [
        gr.Button(value=label, interactive=True)
        for label in ('Like (L)', 'Neither (Space)', 'Dislike (A)')
    ]
    start_button = gr.Button(value='Start', interactive=False)
    return vote_buttons + [start_button, image, embs, ys, calibrate_prompts]
def _choice_label(choice):
    """Strip a keyboard-shortcut suffix from a vote button label: 'Like (L)' -> 'Like'.

    Labels without a ' (...)' suffix are returned as-is (whitespace-stripped).
    """
    return str(choice).split(' (', 1)[0].strip()


def choose(choice, embs, ys, calibrate_prompts):
    """Record a vote for the current image and render the next one.

    Args:
        choice: Label of the clicked vote button, e.g. 'Like (L)',
            'Neither (Space)' or 'Dislike (A)'. The shortcut suffix is
            optional. Any label other than Like/Neither counts as a dislike.
        embs: Embeddings of shown images; the last entry belongs to the image
            being rated.
        ys: Labels collected so far (1 = like, 0 = dislike).
        calibrate_prompts: Remaining calibration prompts.

    Returns:
        ``(image, embs, ys, calibrate_prompts)`` from ``next_image``.
    """
    # The buttons carry shortcut hints in their labels, so compare on the bare
    # word. An exact 'Like'/'Neither' match never fired and recorded every
    # vote as a dislike.
    label = _choice_label(choice)
    if label == 'Neither':
        # Skipped image: drop its embedding so embs and ys stay aligned.
        embs.pop(-1)
    else:
        ys.append(1 if label == 'Like' else 0)
    return next_image(embs, ys, calibrate_prompts)
# Page styling: narrow centred container and centred description text.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1{display: block}
#description p{margin-top: 0}
'''
# Keyboard shortcuts injected into the page <head>: A = Dislike,
# Space = Neither, L = Like. The ids must match the vote buttons' elem_id.
# Presses are ignored when they are auto-repeats, when they carry a modifier,
# or when they are typed into editable fields. Handled keys suppress the
# browser's default action (e.g. page scroll on Space).
js = '''
<script>
document.addEventListener('keydown', function(event) {
  // Ignore auto-repeat from a held key and shortcuts that use a modifier
  // (e.g. Ctrl+A / Cmd+L must not cast a vote).
  if (event.repeat || event.ctrlKey || event.metaKey || event.altKey) {
    return;
  }
  // Ignore keystrokes typed into editable fields.
  const target = event.target;
  if (target && (target.isContentEditable || ['INPUT', 'TEXTAREA', 'SELECT'].includes(target.tagName))) {
    return;
  }
  let id = null;
  if (event.key === 'a' || event.key === 'A') {
    // Trigger click on 'dislike' if 'A' is pressed
    id = 'dislike';
  } else if (event.key === ' ' || event.code === 'Space') {
    // Trigger click on 'neither' if Spacebar is pressed
    id = 'neither';
  } else if (event.key === 'l' || event.key === 'L') {
    // Trigger click on 'like' if 'L' is pressed
    id = 'like';
  }
  if (id === null) {
    return;
  }
  const button = document.getElementById(id);
  if (!button) {
    return;
  }
  // Suppress the browser's default action for the key (e.g. page scroll on
  // Space) so the keypress only triggers the vote below.
  event.preventDefault();
  button.click();
});
</script>
'''
# UI layout and event wiring. The vote buttons start disabled; the Start
# button enables them and shows the first image.
with gr.Blocks(css=css, head=js) as demo:
    gr.Markdown('''# Generative Recommenders
Explore the latent space without text prompts, based on your preferences. [Learn more on the blog](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/)
''', elem_id="description")
    # Per-session state:
    # - embs: embeddings of shown images;
    # - ys: their labels (1 = like, 0 = dislike);
    # - calibrate_prompts: calibration prompts still to be shown.
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        "4k photo",
        'surrealist art',
        # 'a psychedelic, fractal view',
        'a beautiful collage',
        'abstract art',
        'an eldritch image',
        'a sketch',
        # 'a city full of darkness and graffiti',
        '',
    ])
    # The row gets its own id so it no longer duplicates the image's
    # 'output-image' id in the DOM.
    with gr.Row(elem_id='output-image-row'):
        img = gr.Image(interactive=False, elem_id='output-image', width=700)
    with gr.Row(equal_height=True):
        # elem_ids must match the ids the keyboard-shortcut script clicks.
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
        # Each vote button passes its own label to choose(), which records
        # the vote and renders the next image.
        for vote_button in (b1, b2, b3):
            vote_button.click(
                choose,
                [vote_button, embs, ys, calibrate_prompts],
                [img, embs, ys, calibrate_prompts]
            )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam.</div>''')
# Start the Gradio server.
demo.launch()