boutsify / app.py
tombio's picture
fix image processing mistake
6e1757f
raw history blame
No virus
5.18 kB
from io import BytesIO
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from torch import autocast
from contextlib import nullcontext
import requests
import functools
import random
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.extras import load_model_from_config, load_training_dir
import clip
from PIL import Image
from fastai.vision.all import *
import skimage
from huggingface_hub import hf_hub_download
# Fetch the Image Mixer checkpoint and its config from the Hugging Face Hub;
# cache_dir keeps the downloads under /data/.cache.
ckpt = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-pruned.ckpt", cache_dir="/data/.cache")
config = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-config.yaml", cache_dir="/data/.cache")
# Everything runs on the first CUDA device (hard-coded; no CPU fallback).
device = "cuda:0"
model = load_model_from_config(config, ckpt, device=device, verbose=False)
# Cast the diffusion model to half precision on the GPU.
model = model.to(device).half()
# CLIP ViT-L/14 encoder plus its image preprocessing transform; used by
# get_im_c / get_txt_c to build the conditioning vectors.
clip_model, preprocess = clip.load("ViT-L/14", device=device)
# fastai classifier used by is_female().
# NOTE(review): the .pkl is presumably a pickled fastai Learner -- only load
# trusted files here.
gender_learn = load_learner('gender_model.pkl')
# Class vocabulary of the classifier.
# NOTE(review): not referenced elsewhere in the visible code; is_female relies
# on index order instead -- confirm index 0 is the "female" class.
gender_labels = gender_learn.dls.vocab
# Number of mixer input slots; run_image_mixer splits its flat argument list
# into groups of this size.
n_inputs = 5
# Release cached GPU memory left over from loading the models.
torch.cuda.empty_cache()
@functools.lru_cache()
def get_url_im(t):
    """Download the image at URL ``t`` and open it with PIL.

    Results are memoised by ``functools.lru_cache`` (default maxsize=128), so
    repeated requests for the same URL are served from memory. Calls that
    raise are not cached, so a failed download is retried next time.

    Args:
        t: HTTP(S) URL of the image.

    Returns:
        A fully decoded ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within the timeout.
        PIL.UnidentifiedImageError / OSError: if the body is not a valid image.
    """
    user_agent = {'User-agent': 'gradio-app'}
    # Without a timeout a dead host would block this worker indefinitely.
    response = requests.get(t, headers=user_agent, timeout=30)
    # Fail with a clear HTTP error instead of letting PIL choke on an HTML
    # error page.
    response.raise_for_status()
    im = Image.open(BytesIO(response.content))
    # Image.open is lazy; decode now so a corrupt/truncated image raises here
    # (and is therefore not cached) rather than later in a shared cached object.
    im.load()
    return im
@torch.no_grad()
def get_im_c(im, clip_model):
    """Encode a single image into a CLIP image embedding.

    The image is run through the CLIP ``preprocess`` transform, given a
    leading batch axis of size 1 and moved to ``device`` before encoding.

    Args:
        im: Image accepted by the CLIP ``preprocess`` transform (PIL image).
        clip_model: CLIP model returned by ``clip.load``.

    Returns:
        The image embedding with a batch axis of 1, cast to float32.
    """
    batch = preprocess(im).unsqueeze(0).to(device)
    embedding = clip_model.encode_image(batch)
    return embedding.float()
@torch.no_grad()
def get_txt_c(txt, clip_model):
    """Encode a text prompt into a CLIP text embedding.

    Args:
        txt: Prompt string.
        clip_model: CLIP model returned by ``clip.load``.

    Returns:
        The text embedding with a batch axis of 1, cast to float32 so it has
        the same dtype as ``get_im_c`` output and the zero "Nothing" slots.
        Callers can then concatenate them without relying on dtype promotion.
    """
    # truncate=True: prompts longer than CLIP's 77-token context are cut off
    # instead of making clip.tokenize raise a RuntimeError.
    tokens = clip.tokenize([txt], truncate=True).to(device)
    return clip_model.encode_text(tokens).float()
def get_txt_diff(txt1, txt2, clip_model):
    """Return the CLIP text-embedding difference ``embed(txt1) - embed(txt2)``.

    Both prompts are encoded with ``get_txt_c`` (``txt1`` first) and the
    second embedding is subtracted from the first.
    """
    minuend = get_txt_c(txt1, clip_model)
    subtrahend = get_txt_c(txt2, clip_model)
    return minuend - subtrahend
def to_im_list(x_samples_ddim):
    """Convert a batch of decoded image tensors into a list of PIL images.

    Values are mapped with ``(x + 1) / 2`` and clamped to [0, 1], which implies
    the decoder output is expected in roughly [-1, 1]. Each sample is then
    moved to the CPU, reordered from channel-first ('c h w') to channel-last
    ('h w c'), scaled by 255 and truncated to uint8.

    Returns:
        One ``PIL.Image.Image`` per entry along the first (batch) axis.
    """
    unit_range = ((x_samples_ddim + 1.0) / 2.0).clamp(min=0.0, max=1.0)
    return [
        Image.fromarray(
            (255. * rearrange(frame.cpu().numpy(), 'c h w -> h w c')).astype(np.uint8)
        )
        for frame in unit_range
    ]
@torch.no_grad()
def sample(sampler, model, c, uc, scale, start_code, h=512, w=512, precision="autocast",ddim_steps=50):
    """Run DDIM sampling with classifier-free guidance and decode to PIL images.

    Args:
        sampler: Sampler exposing ``sample(...)`` (a ``DDIMSampler`` here).
        model: Latent diffusion model; ``decode_first_stage`` maps latents
            back to image space.
        c: Conditioning tensor; its first dimension is used as the batch size.
        uc: Unconditional conditioning used for guidance.
        scale: Unconditional guidance scale.
        start_code: Initial noise passed to the sampler as ``x_T``.
        h, w: Output size in pixels; the latent shape is (4, h // 8, w // 8).
        precision: "autocast" runs under ``torch.autocast("cuda")``; any
            other value runs without a precision context.
        ddim_steps: Number of DDIM steps.

    Returns:
        A list of PIL images produced by ``to_im_list``.
    """
    scope = nullcontext if precision != "autocast" else autocast
    with scope("cuda"):
        latent_shape = [4, h // 8, w // 8]
        # eta = 0.0 makes DDIM sampling deterministic given start_code.
        samples_ddim, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=c.shape[0],
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=0.0,
            x_T=start_code,
        )
        decoded = model.decode_first_stage(samples_ddim)
        return to_im_list(decoded)
def run_image_mixer(args):
    """Mix several image/text inputs into one generated 640x640 image.

    ``args`` is a flat sequence made of consecutive groups of ``n_inputs``
    values, followed by four settings:

    * input kinds ("Image", "Text/URL", or anything else for an empty slot),
    * texts / URLs,
    * images,
    * strengths (multipliers applied to each embedding),
    * then: guidance scale, number of samples, seed, DDIM steps.

    Each active slot is encoded with CLIP and scaled by its strength. A
    "Text/URL" value starting with "http" is downloaded and treated as an
    image. Empty slots contribute a (1, 768) zero vector. The per-slot
    vectors are stacked into one conditioning tensor and tiled for every
    sample. An all-zero tensor of the same shape serves as the unconditional
    guidance.

    Returns:
        The first generated ``PIL.Image.Image``.
    """
    inps = [args[i:i + n_inputs] for i in range(0, len(args) - 4, n_inputs)]
    scale, n_samples, seed, steps = args[-4:]
    # UI widgets may deliver numbers as floats; counts must be integers for
    # torch.randn / tile and the DDIM schedule.
    n_samples = int(n_samples)
    steps = int(steps)
    h = w = 640
    sampler = DDIMSampler(model)
    torch.manual_seed(seed)
    start_code = torch.randn(n_samples, 4, h // 8, w // 8, device=device)
    try:
        conds = []
        for b, t, im, s in zip(*inps):
            print(b, t, im, s)
            if b == "Image":
                this_cond = s * get_im_c(im, clip_model)
            elif b == "Text/URL":
                if t.startswith("http"):
                    im = get_url_im(t)
                    this_cond = s * get_im_c(im, clip_model)
                else:
                    this_cond = s * get_txt_c(t, clip_model)
            else:
                this_cond = torch.zeros((1, 768), device=device)
            conds.append(this_cond)
        conds = torch.cat(conds, dim=0).unsqueeze(0)
        conds = conds.tile(n_samples, 1, 1)
        # Pass h/w explicitly so the latent shape given to the sampler agrees
        # with start_code (640x640 -> 80x80 latents). Without them the
        # sampler was told the 512x512 default.
        ims = sample(sampler, model, conds, 0 * conds, scale, start_code,
                     h=h, w=w, ddim_steps=steps)
    finally:
        # Free cached GPU memory even when sampling fails (e.g. OOM), so the
        # next request is less likely to run out of memory too.
        torch.cuda.empty_cache()
    return ims[0]
def is_female(img):
    """Classify a portrait with the fastai gender learner.

    Returns:
        True when the learner's probability for class index 0 exceeds its
        probability for class index 1.

    NOTE(review): this assumes index 0 of ``gender_learn.dls.vocab`` is the
    "female" class -- confirm against the vocab of gender_model.pkl.
    """
    *_, class_probs = gender_learn.predict(img)
    first, second = float(class_probs[0]), float(class_probs[1])
    return first > second
import gradio
def boutsify(person):
    """Repaint an uploaded portrait in the style of Dirck Bouts.

    The portrait is mixed with three Bouts reference images. The first
    reference is chosen by the gender classifier. The mix uses a fixed
    guidance scale of 3.0, one sample, a random seed and 40 DDIM steps.

    Args:
        person: The uploaded portrait as delivered by the Gradio "image"
            input; it is passed to ``Image.fromarray``, so presumably a
            numpy array.

    Returns:
        The generated image from ``run_image_mixer``.
    """
    female_detected = is_female(person)
    if female_detected:
        print("Picture of a female")
    person_image = Image.fromarray(person)
    reference_portrait = "bouts_f1.jpg" if female_detected else "bouts_m1.jpg"

    # Five mixer slots: four image inputs plus one unused ("Nothing") slot.
    kinds = ["Image", "Image", "Image", "Image", "Nothing"]
    texts = ["", "", "", "", ""]
    images = [
        Image.open(reference_portrait).convert("RGB"),
        Image.open("bouts_banner.png").convert("RGB"),
        Image.open("bouts_city.png").convert("RGB"),
        person_image,
        # Still opened, but its slot kind is "Nothing", so the mixer ignores it.
        Image.open("bouts_m24.jpg").convert("RGB"),
    ]
    strengths = [1.2, 1, 1, 1.4, 1]
    # Guidance scale, number of samples, random seed, DDIM steps.
    settings = [3.0, 1, random.randrange(0, 10000), 40]
    return run_image_mixer(kinds + texts + images + strengths + settings)
# Web UI: a single image input routed through boutsify, a single image output.
gradio_interface = gradio.Interface(
    boutsify,
    inputs="image",
    outputs="image",
    title="Boutsify images",
    description="Turn portraits into a painting in the style of Flemish master Dirck Bouts",
    article="© iO Digital",
)
gradio_interface.launch()