# boutsify / app.py
# Author: tombio
# Commit d8dabba: "add detection for colored persons"
# (Scraped file-view residue: "raw history blame", "No virus", 6.12 kB)
from io import BytesIO
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from torch import autocast
from contextlib import nullcontext
import requests
import functools
import random
import timm
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.extras import load_model_from_config, load_training_dir
import clip
from PIL import Image
from fastai.vision.all import *
import skimage
from huggingface_hub import hf_hub_download
# Hugging Face Hub location of the image-mixer weights and the local download cache.
_MIXER_REPO_ID = "lambdalabs/image-mixer"
_HF_CACHE_DIR = "/data/.cache"

# Fetch (or reuse from the cache directory) the diffusion checkpoint and its config.
ckpt = hf_hub_download(repo_id=_MIXER_REPO_ID, filename="image-mixer-pruned.ckpt", cache_dir=_HF_CACHE_DIR)
config = hf_hub_download(repo_id=_MIXER_REPO_ID, filename="image-mixer-config.yaml", cache_dir=_HF_CACHE_DIR)

# All models live on the first CUDA GPU; there is no CPU fallback.
device = "cuda:0"

# Image-mixer diffusion model, moved to the GPU and cast to half precision.
model = load_model_from_config(config, ckpt, device=device, verbose=False).to(device).half()

# CLIP ViT-L/14 encoder together with its matching image preprocessing transform.
clip_model, preprocess = clip.load("ViT-L/14", device=device)

# fastai classifiers used to choose the reference portrait.
# NOTE: load_learner unpickles these files — only ever point it at trusted local models.
gender_learn = load_learner('gender_model.pkl')
beard_learn = load_learner('facial_hair_model.pkl')
ethnic_learn = load_learner('ethnic_model.pkl')

# Class vocabularies of the classifiers above.
gender_labels = gender_learn.dls.vocab
beard_labels = beard_learn.dls.vocab
ethnic_labels = ethnic_learn.dls.vocab

# Number of mixer slots; run_image_mixer splits its flat argument list into groups of this size.
n_inputs = 5

torch.cuda.empty_cache()
@functools.lru_cache()
def get_url_im(t, timeout=30):
    """Download the image at URL ``t`` and return it as a PIL image.

    Results are memoised with ``functools.lru_cache`` (default maxsize 128),
    so repeated requests for the same URL are answered from memory. Failed
    downloads raise and are therefore not cached.

    Args:
        t: URL of the image to fetch.
        timeout: seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    user_agent = {'User-agent': 'gradio-app'}
    # Without a timeout a stalled server would block this request handler forever.
    response = requests.get(t, headers=user_agent, timeout=timeout)
    # Surface HTTP errors directly instead of letting PIL fail on an error page body.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
@torch.no_grad()
def get_im_c(im, clip_model):
    """Return the CLIP image embedding of ``im`` as a float32 tensor.

    ``im`` is passed through the module-level CLIP ``preprocess`` transform,
    given a leading batch dimension of size 1, moved to ``device`` and encoded.
    """
    batch = preprocess(im).unsqueeze(0).to(device)
    embedding = clip_model.encode_image(batch)
    return embedding.float()
@torch.no_grad()
def get_txt_c(txt, clip_model):
    """Return the CLIP text embedding of ``txt`` as a float32 tensor.

    The embedding is cast with ``.float()`` so it has the same dtype as the
    image embeddings returned by ``get_im_c``; ``run_image_mixer``
    concatenates both kinds into a single conditioning tensor.
    """
    text = clip.tokenize([txt]).to(device)
    # CLIP may run in half precision on the GPU; match get_im_c's float32 output.
    return clip_model.encode_text(text).float()
def get_txt_diff(txt1, txt2, clip_model):
    """Return the difference of CLIP text embeddings, ``embed(txt1) - embed(txt2)``."""
    first = get_txt_c(txt1, clip_model)
    second = get_txt_c(txt2, clip_model)
    return first - second
def to_im_list(x_samples_ddim):
    """Convert a batch of decoded samples into a list of 8-bit PIL images.

    Each sample is mapped with ``(x + 1) / 2`` and clipped to ``[0, 1]``
    (so the input is presumably in ``[-1, 1]`` — confirm against the decoder).
    It is then reordered from channel-first (``c h w``) to ``h w c``, scaled
    to ``0..255`` and truncated to ``uint8``.
    """
    unit_range = ((x_samples_ddim + 1.0) / 2.0).clamp(0.0, 1.0)
    return [
        Image.fromarray(
            (255. * rearrange(single.cpu().numpy(), 'c h w -> h w c')).astype(np.uint8)
        )
        for single in unit_range
    ]
@torch.no_grad()
def sample(sampler, model, c, uc, scale, start_code, h=512, w=512, precision="autocast", ddim_steps=50):
    """Run guided sampling with ``sampler`` and decode the latents to images.

    Args:
        sampler: object exposing ``sample(S=..., conditioning=..., ...)``
            that returns a 2-tuple whose first item is the sampled latents.
        model: model whose ``decode_first_stage`` turns latents into images.
        c: conditioning; its first dimension is used as the batch size.
        uc: unconditional conditioning used for guidance.
        scale: unconditional guidance scale.
        start_code: initial latent passed to the sampler as ``x_T``.
        h, w: image size in pixels; the latent shape given to the sampler
            is ``[4, h // 8, w // 8]``.
        precision: ``"autocast"`` runs under ``torch.autocast("cuda")``;
            any other value runs without autocast.
        ddim_steps: number of sampler steps.

    Returns:
        List of PIL images produced by ``to_im_list``.
    """
    scope = nullcontext if precision != "autocast" else autocast
    latent_shape = [4, h // 8, w // 8]
    with scope("cuda"):
        samples_ddim, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=c.shape[0],
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=0.0,
            x_T=start_code,
        )
        return to_im_list(model.decode_first_stage(samples_ddim))
def _slot_conditioning(kind, text, image, strength):
    """Return one row of CLIP conditioning for a single mixer slot.

    ``kind`` selects the source:
    - ``"Image"`` encodes ``image``;
    - ``"Text/URL"`` downloads and encodes ``text`` as an image when it starts
      with ``"http"``, and otherwise encodes it as text;
    - any other value yields an all-zero row.
    Non-zero conditionings are multiplied by ``strength``.
    """
    if kind == "Image":
        return strength * get_im_c(image, clip_model)
    if kind == "Text/URL":
        if text.startswith("http"):
            return strength * get_im_c(get_url_im(text), clip_model)
        return strength * get_txt_c(text, clip_model)
    # Unused slot (e.g. "Nothing"): contributes a zero conditioning row.
    return torch.zeros((1, 768), device=device)


def run_image_mixer(args):
    """Mix several image/text prompts into one generated image.

    ``args`` is a flat sequence made of:
    - ``n_inputs`` slot kinds;
    - ``n_inputs`` texts;
    - ``n_inputs`` images;
    - ``n_inputs`` strengths;
    - then ``scale, n_samples, seed, steps``.

    Returns:
        The first generated PIL image.
    """
    # Everything before the 4 trailing settings is split into per-field groups of n_inputs.
    inps = [args[i:i + n_inputs] for i in range(0, len(args) - 4, n_inputs)]
    scale, n_samples, seed, steps = args[-4:]
    h = w = 640
    sampler = DDIMSampler(model)
    torch.manual_seed(seed)
    start_code = torch.randn(n_samples, 4, h // 8, w // 8, device=device)
    try:
        conds = []
        for b, t, im, s in zip(*inps):
            print(b, t, im, s)
            conds.append(_slot_conditioning(b, t, im, s))
        # Rows -> (1, n_slots, 768) sequence, repeated once per requested sample.
        conds = torch.cat(conds, dim=0).unsqueeze(0).tile(n_samples, 1, 1)
        # Pass h/w so the sampler's latent shape agrees with start_code.
        ims = sample(sampler, model, conds, torch.zeros_like(conds), scale, start_code,
                     h=h, w=w, ddim_steps=steps)
    finally:
        # Clear GPU memory cache so we are less likely to OOM, even when sampling failed.
        torch.cuda.empty_cache()
    return ims[0]
def is_female(img):
    """Return True when the gender classifier scores class 0 above class 1.

    NOTE(review): assumes vocab index 0 is the female label — confirm
    against ``gender_labels``.
    """
    _, _, probs = gender_learn.predict(img)
    p0, p1 = float(probs[0]), float(probs[1])
    return p0 > p1
def has_beard(img):
    """Return True when the facial-hair classifier scores class 1 above class 0.

    NOTE(review): assumes vocab index 1 is the "has beard" label — confirm
    against ``beard_labels``.
    """
    _, _, probs = beard_learn.predict(img)
    p0, p1 = float(probs[0]), float(probs[1])
    return p1 > p0
def ethnicity(img):
    """Return the label predicted for ``img`` by the ethnicity classifier."""
    label, _, _ = ethnic_learn.predict(img)
    return label
import gradio
def _load_rgb(path):
    """Open the image file at ``path`` and return an RGB copy, closing the file."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _select_portrait(person):
    """Return the filename of the Bouts reference portrait to mix with ``person``.

    Uses the gender classifier and the ethnicity classifier (label
    ``"Black"``). For the remaining male portraits only, it also uses the
    facial-hair classifier.
    """
    female_detected = is_female(person)
    ethnicity_prediction = ethnicity(person)
    if ethnicity_prediction == "Black":
        # NOTE(review): consider rewording this log message to current terminology.
        print("Colored person")
        if female_detected:
            print("This is a female portrait")
            return "bouts_fc1.jpg"
        return "bouts_mc1.jpg"
    if female_detected:
        print("This is a female portrait")
        return "bouts_f1.jpg"
    print("This is a male portrait, checking facial hair")
    if has_beard(person):
        print("The person has a beard")
        return "bouts_mb1.jpg"
    return "bouts_m1.jpg"


def boutsify(person):
    """Render ``person`` as a painting in the style of Dirck Bouts.

    Args:
        person: the uploaded portrait as delivered by the gradio ``"image"``
            input. It is passed to ``Image.fromarray``, so it is presumably a
            numpy array — confirm against the gradio version in use.

    Returns:
        The generated PIL image from ``run_image_mixer``.
    """
    portrait_path = _select_portrait(person)
    person_image = Image.fromarray(person)
    # Flat mixer arguments, in order:
    # - 5 slot kinds, 5 texts, 5 images, 5 strengths;
    # - then guidance scale, sample count, random seed and DDIM steps.
    inputs = [
        "Image", "Image", "Image", "Image", "Nothing",
        "", "", "", "", "",
        _load_rgb(portrait_path),
        _load_rgb("boutsback_o1.png"),
        _load_rgb("boutsback_o2.png"),
        person_image,
        "",
        1.1, 1, 1, 1.4, 1,
        3.0, 1, random.randrange(0, 10000), 50,
    ]
    return run_image_mixer(inputs)
# Web UI: a single uploaded image goes in and the generated painting comes out.
_ui_text = {
    "title": "Boutsify images",
    "description": "Turn portraits into a painting in the style of Flemish master Dirck Bouts",
    "article": "© iO Digital",
}
gradio_interface = gradio.Interface(fn=boutsify, inputs="image", outputs="image", **_ui_text)
gradio_interface.launch()