Spaces:
Paused
Paused
from io import BytesIO | |
import torch | |
import numpy as np | |
from PIL import Image | |
from einops import rearrange | |
from torch import autocast | |
from contextlib import nullcontext | |
import requests | |
import functools | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.models.diffusion.plms import PLMSSampler | |
from ldm.extras import load_model_from_config, load_training_dir | |
import clip | |
from PIL import Image | |
from fastai.vision.all import * | |
import skimage | |
from huggingface_hub import hf_hub_download | |
# Fetch the Image Mixer checkpoint and its config from the Hugging Face Hub
# (cached under /data/.cache); the checkpoint is downloaded first.
ckpt, config = (
    hf_hub_download(repo_id="lambdalabs/image-mixer", filename=name, cache_dir="/data/.cache")
    for name in ("image-mixer-pruned.ckpt", "image-mixer-config.yaml")
)
device = "cuda:0"
# Diffusion model: built from config + checkpoint, then moved to the GPU in fp16.
model = load_model_from_config(config, ckpt, device=device, verbose=False).to(device).half()
# CLIP ViT-L/14 and its preprocessing transform, used to embed mixer inputs.
clip_model, preprocess = clip.load("ViT-L/14", device=device)
# fastai classifier consulted by is_female(); NOTE: load_learner unpickles the
# file, so only ever point it at a trusted model file.
gender_learn = load_learner('gender_model.pkl')
gender_labels = gender_learn.dls.vocab
# Number of mixer slots; each slot contributes a type, text, image and strength.
n_inputs = 5
torch.cuda.empty_cache()
def get_url_im(t, timeout=30):
    """Download the image at URL *t* and open it with PIL.

    Args:
        t: HTTP(S) URL of the image.
        timeout: seconds ``requests`` may wait for the connection / response
            before giving up. Without a timeout a stalled server can block the
            request handler forever.

    Returns:
        A ``PIL.Image.Image`` opened from the downloaded bytes.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    user_agent = {'User-agent': 'gradio-app'}
    response = requests.get(t, headers=user_agent, timeout=timeout)
    # Fail with a clear HTTP error instead of passing an error page to PIL,
    # which would otherwise fail with an unhelpful "cannot identify image".
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
def get_im_c(im, clip_model):
    """Embed image *im* with CLIP for use as a mixer conditioning.

    Args:
        im: image accepted by the module-level CLIP ``preprocess`` transform
            (PIL images are what the callers in this file pass).
        clip_model: CLIP model exposing ``encode_image``.

    Returns:
        The image embedding for a batch of one, cast to float32.
    """
    # The embedding is only used as conditioning, never back-propagated
    # through; without no_grad the CLIP activations would be kept alive in an
    # autograd graph for as long as the result is referenced.
    with torch.no_grad():
        prompts = preprocess(im).to(device).unsqueeze(0)  # add batch dim
        return clip_model.encode_image(prompts).float()
def get_txt_c(txt, clip_model):
    """Embed text prompt *txt* with CLIP for use as a mixer conditioning.

    Args:
        txt: prompt string, tokenized with ``clip.tokenize``.
        clip_model: CLIP model exposing ``encode_text``.

    Returns:
        The text embedding for a batch of one, cast to float32 so that text
        and image conditionings (``get_im_c`` also returns float32) share a
        dtype when they are concatenated.
    """
    text = clip.tokenize([txt]).to(device)
    # Conditioning only; don't build an autograd graph for the CLIP pass.
    with torch.no_grad():
        return clip_model.encode_text(text).float()
def get_txt_diff(txt1, txt2, clip_model):
    """Return the CLIP text embedding of *txt1* minus that of *txt2*.

    *txt1* is encoded before *txt2*; both go through ``get_txt_c``.
    """
    minuend = get_txt_c(txt1, clip_model)
    subtrahend = get_txt_c(txt2, clip_model)
    return minuend - subtrahend
def to_im_list(x_samples_ddim):
    """Turn a batch of decoded samples into a list of PIL images.

    Args:
        x_samples_ddim: batch tensor whose elements are channel-first
            (C, H, W) images, presumably with values in [-1, 1]; anything
            outside that range is clamped.

    Returns:
        One ``PIL.Image.Image`` per batch element, built from uint8 HWC data.
    """
    # Map [-1, 1] to [0, 1], clipping stray values.
    unit_range = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    pil_images = []
    for chw in unit_range:
        hwc = np.transpose(chw.cpu().numpy(), (1, 2, 0))  # CHW -> HWC
        pixels = (255. * hwc).astype(np.uint8)  # fractional part is truncated
        pil_images.append(Image.fromarray(pixels))
    return pil_images
def sample(sampler, model, c, uc, scale, start_code, h=512, w=512, precision="autocast", ddim_steps=50):
    """Run *sampler* from *start_code* and decode the latents to PIL images.

    Args:
        sampler: object exposing ``sample(...)`` (``DDIMSampler`` in this file).
        model: diffusion model providing ``decode_first_stage``.
        c: conditioning tensor; ``c.shape[0]`` is used as the batch size.
        uc: unconditional conditioning paired with *scale* for guidance.
        scale: value passed as ``unconditional_guidance_scale``.
        start_code: initial latent passed to the sampler as ``x_T``.
        h, w: image size in pixels; the latent shape given to the sampler is
            ``[4, h // 8, w // 8]``.
        precision: ``"autocast"`` runs under ``torch.autocast("cuda")``;
            any other value runs without autocast.
        ddim_steps: number of sampling steps (``S``).

    Returns:
        The list of PIL images produced by ``to_im_list``.
    """
    if precision == "autocast":
        scope = autocast
    else:
        scope = nullcontext
    latent_shape = [4, h // 8, w // 8]
    with scope("cuda"):
        latents, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=c.shape[0],
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=0.0,  # eta is fixed for this app
            x_T=start_code,
        )
        decoded = model.decode_first_stage(latents)
    return to_im_list(decoded)
def _embed_input(b, t, im, s):
    """Return the conditioning row for one mixer slot (type b, text t, image im, strength s)."""
    if b == "Image":
        return s * get_im_c(im, clip_model)
    if b == "Text/URL":
        if t.startswith("http"):
            # URLs are downloaded and embedded as images, not as prompts.
            return s * get_im_c(get_url_im(t), clip_model)
        return s * get_txt_c(t, clip_model)
    # Any other type (e.g. "Nothing") contributes an unscaled zero row;
    # 768 presumably matches the ViT-L/14 embedding width — TODO confirm.
    return torch.zeros((1, 768), device=device)


def run_image_mixer(args):
    """Mix up to ``n_inputs`` image/text inputs into one generated image.

    Args:
        args: flat sequence made of four groups of ``n_inputs`` values
            followed by four sampling settings::

                [types..., texts..., images..., strengths...,
                 scale, n_samples, seed, steps]

            A slot of type ``"Image"`` uses its image; ``"Text/URL"`` uses its
            text (downloaded as an image when it starts with ``"http"``); any
            other type contributes a zero embedding. Each embedding except the
            zero one is multiplied by the slot's strength.

    Returns:
        The first generated ``PIL.Image.Image``.
    """
    inps = [args[i:i + n_inputs] for i in range(0, len(args) - 4, n_inputs)]
    scale, n_samples, seed, steps = args[-4:]
    h = w = 640
    sampler = DDIMSampler(model)
    try:
        # Everything here is inference; no autograd graphs are needed.
        with torch.no_grad():
            torch.manual_seed(seed)
            start_code = torch.randn(n_samples, 4, h // 8, w // 8, device=device)
            conds = []
            for b, t, im, s in zip(*inps):
                print(b, t, im, s)
                conds.append(_embed_input(b, t, im, s))
            # Stack the slot rows into one sequence, repeated per sample.
            conds = torch.cat(conds, dim=0).unsqueeze(0).tile(n_samples, 1, 1)
            # Pass h/w so the latent shape handed to the sampler agrees with
            # start_code (sample() otherwise assumes 512x512).
            ims = sample(sampler, model, conds, 0 * conds, scale, start_code,
                         h=h, w=w, ddim_steps=steps)
    finally:
        # Clear GPU memory cache so less likely to OOM — also after a failure.
        torch.cuda.empty_cache()
    return ims[0]
def is_female(img):
    """Return True when the gender classifier scores class 0 above class 1.

    NOTE(review): this assumes class index 0 of ``gender_learn.dls.vocab``
    (``gender_labels``) is the female class — confirm against the model.
    """
    _label, _index, probabilities = gender_learn.predict(img)
    return float(probabilities[1]) < float(probabilities[0])
import gradio | |
def boutsify(person): | |
female_detected = is_female(person) | |
if female_detected: | |
print("Picture of a female") | |
person_image = Image.fromarray(person) | |
inputs = [ | |
"Image", "Image", "Text/URL", "Image", "Nothing", | |
"","","flowers","","", | |
Image.open("ex2-1.jpeg").convert("RGB"), | |
Image.open("ex2-2.jpeg").convert("RGB"), | |
Image.open("blonder.jpeg").convert("RGB"), | |
person_image, | |
Image.open("blonder.jpeg").convert("RGB"), | |
1,1,1.5,1.4,1, | |
3.0, 1, 0, 40, | |
] | |
return run_image_mixer(inputs) | |
# Web UI: one uploaded portrait in, one generated image out.
gradio_interface = gradio.Interface(
    boutsify,
    inputs="image",
    outputs="image",
    article="© iO Digital",
    description="Turn portraits into a painting in the style of Flemish master Dirck Bouts",
    title="Boutsify images",
)
gradio_interface.launch()