# boutsify / app.py
# Author: tombio. Last commit: cd308ea ("Update app.py"), 7.51 kB.
from io import BytesIO
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from torch import autocast
from contextlib import nullcontext
import requests
import functools
import random
import timm
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.extras import load_model_from_config, load_training_dir
import clip
from PIL import Image
from fastai.vision.all import *
import skimage
import mediapipe as mp
import numpy as np
#import pathlib
#temp = pathlib.PosixPath
#pathlib.PosixPath = pathlib.WindowsPath
from huggingface_hub import hf_hub_download
# Fetch the Lambda Labs "image-mixer" latent-diffusion checkpoint and its
# config from the Hugging Face Hub; both are handed to load_model_from_config.
ckpt = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-pruned.ckpt")
config = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-config.yaml")
# Hard-coded to the first CUDA GPU; every model and tensor below uses it.
device = "cuda:0"
model = load_model_from_config(config, ckpt, device=device, verbose=False)
# Diffusion model weights are converted to half precision.
model = model.to(device).half()
# CLIP ViT-L/14 encodes the image/text conditions; `preprocess` is the image
# transform returned alongside it and is used by get_im_c.
clip_model, preprocess = clip.load("ViT-L/14", device=device)
# fastai classifiers used by boutsify to pick a matching reference portrait.
# NOTE(review): fastai's load_learner unpickles these files - only ship
# trusted .pkl files. The *_labels vocabularies are not used in the visible
# code; the predictor helpers index probabilities by position instead.
gender_learn = load_learner('gender_model.pkl')
gender_labels = gender_learn.dls.vocab
beard_learn = load_learner('facial_hair_model.pkl')
beard_labels = beard_learn.dls.vocab
ethnic_learn = load_learner('ethnic_model.pkl')
ethnic_labels = ethnic_learn.dls.vocab
# Number of mixer input slots; run_image_mixer splits its flat argument list
# into groups of this size.
n_inputs = 5
torch.cuda.empty_cache()
@functools.lru_cache()
def get_url_im(t, timeout=30):
    """Download the image at URL *t* and return it as a PIL image.

    Results are memoised with ``functools.lru_cache``, so a URL is fetched
    at most once while it stays in the cache. Failed requests raise and are
    therefore not cached.

    Args:
        t: URL of the image (run_image_mixer passes strings starting with
            "http").
        timeout: seconds to wait for the server. Without a timeout
            ``requests`` can block indefinitely on an unresponsive host and
            tie up the app worker.

    Returns:
        A lazily-loaded ``PIL.Image.Image`` backed by the response bytes.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status,
            rather than handing an error page to ``Image.open``.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    user_agent = {'User-agent': 'gradio-app'}
    response = requests.get(t, headers=user_agent, timeout=timeout)
    # Fail loudly on HTTP errors; otherwise Image.open would receive an HTML
    # error body and raise a confusing "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
@torch.no_grad()
def get_im_c(im, clip_model):
    """Encode image *im* with *clip_model* and return a float32 embedding.

    The image is transformed with the module-level CLIP ``preprocess``,
    moved to ``device`` and given a leading batch dimension of size 1
    before encoding. Gradients are disabled.
    """
    pixels = preprocess(im)
    batch = pixels.to(device).unsqueeze(0)
    embedding = clip_model.encode_image(batch)
    return embedding.float()
@torch.no_grad()
def get_txt_c(txt, clip_model):
    """Encode the prompt *txt* with *clip_model* and return a float32 embedding.

    The result is cast to float32 to match ``get_im_c`` and the float32
    zero placeholders used for empty slots. run_image_mixer concatenates
    all three kinds of condition into one tensor, so they should share a
    dtype.
    """
    tokens = clip.tokenize([txt]).to(device)
    return clip_model.encode_text(tokens).float()
def get_txt_diff(txt1, txt2, clip_model):
    """Return the CLIP text embedding of *txt1* minus that of *txt2*."""
    minuend = get_txt_c(txt1, clip_model)
    subtrahend = get_txt_c(txt2, clip_model)
    return minuend - subtrahend
def to_im_list(x_samples_ddim):
    """Convert a batch of decoded channel-first samples to PIL images.

    Values are mapped with ``(x + 1) / 2`` and clamped to [0, 1], so the
    input is assumed to be roughly in [-1, 1]. Each sample is then scaled
    to [0, 255], moved to height-width-channel layout and truncated to
    uint8.
    """
    unit_range = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

    def _to_pil(chw_tensor):
        hwc = rearrange(chw_tensor.cpu().numpy(), 'c h w -> h w c')
        return Image.fromarray((255. * hwc).astype(np.uint8))

    return [_to_pil(one) for one in unit_range]
@torch.no_grad()
def sample(sampler, model, c, uc, scale, start_code, h=512, w=512, precision="autocast", ddim_steps=50):
    """Run deterministic (eta = 0) DDIM sampling and decode the result.

    Args:
        sampler: sampler object exposing ``sample(...)`` (a DDIMSampler in
            this file).
        model: diffusion model whose ``decode_first_stage`` turns latents
            into images.
        c: conditioning; its first dimension is used as the batch size.
        uc: unconditional conditioning for guidance.
        scale: unconditional guidance scale.
        start_code: initial noise passed to the sampler as ``x_T``.
        h, w: image size in pixels; the latent shape is ``[4, h//8, w//8]``.
        precision: "autocast" runs under ``torch.autocast("cuda")``; any
            other value runs without autocast.
        ddim_steps: number of sampler steps.

    Returns:
        List of PIL images produced by ``to_im_list``.
    """
    scope = autocast("cuda") if precision == "autocast" else nullcontext("cuda")
    latent_shape = [4, h // 8, w // 8]
    with scope:
        latents, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=c.shape[0],
            shape=latent_shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=0.0,
            x_T=start_code,
        )
        decoded = model.decode_first_stage(latents)
        return to_im_list(decoded)
def _condition_for(b, t, im, s):
    """Return the single-row CLIP condition for one mixer input slot."""
    if b == "Image":
        return s * get_im_c(im, clip_model)
    if b == "Text/URL":
        if t.startswith("http"):
            return s * get_im_c(get_url_im(t), clip_model)
        return s * get_txt_c(t, clip_model)
    # Unused slot ("Nothing" or any other type): contributes a zero vector.
    return torch.zeros((1, 768), device=device)


def run_image_mixer(args):
    """Generate one image by mixing several image/text conditions.

    *args* is a flat sequence. It holds ``n_inputs`` input types, then
    ``n_inputs`` texts/URLs, ``n_inputs`` images and ``n_inputs`` strengths.
    These are followed by four scalars: guidance scale, number of samples,
    seed and DDIM steps.

    Each slot is encoded according to its type:
    - "Image" encodes the image.
    - "Text/URL" downloads http URLs and encodes them as images, and
      encodes any other text as a prompt.
    - Any other type yields a zero vector.
    Each encoding is scaled by its strength.

    Returns:
        The first generated PIL image.
    """
    # Split the leading part of args into the four per-slot groups.
    inps = [args[i:i + n_inputs] for i in range(0, len(args) - 4, n_inputs)]
    scale, n_samples, seed, steps = args[-4:]
    h = w = 640
    sampler = DDIMSampler(model)
    torch.manual_seed(seed)
    start_code = torch.randn(n_samples, 4, h // 8, w // 8, device=device)

    conds = []
    for b, t, im, s in zip(*inps):
        print(b, t, im, s)
        conds.append(_condition_for(b, t, im, s))
    # (n_slots, dim) -> (1, n_slots, dim) -> one copy per requested sample.
    conds = torch.cat(conds, dim=0).unsqueeze(0).tile(n_samples, 1, 1)

    # Pass h/w explicitly so the latent shape given to the sampler matches
    # start_code (h//8 x w//8) instead of sample()'s 512x512 default.
    ims = sample(sampler, model, conds, 0 * conds, scale, start_code,
                 h=h, w=w, ddim_steps=steps)
    # Clear GPU memory cache so less likely to OOM
    torch.cuda.empty_cache()
    return ims[0]
def is_female(img):
    """Return True when the gender classifier scores class 0 above class 1.

    NOTE(review): assumes index 0 of ``gender_labels`` is the female class;
    confirm against the trained model's vocab.
    """
    _, _, probs = gender_learn.predict(img)
    first_p, second_p = float(probs[0]), float(probs[1])
    return first_p > second_p
def has_beard(img):
    """Return True when the facial-hair classifier scores class 1 above class 0.

    NOTE(review): assumes index 1 of ``beard_labels`` is the "beard" class;
    confirm against the trained model's vocab.
    """
    _, _, probs = beard_learn.predict(img)
    class1_p = float(probs[1])
    class0_p = float(probs[0])
    return class1_p > class0_p
def ethnicity(img):
    """Return the predicted label from the ethnicity classifier for *img*."""
    label, _, _ = ethnic_learn.predict(img)
    return label
def find_face(img):
    """Return the first MediaPipe face detection in *img*, or None.

    *img* is converted with ``np.array`` and its shape is printed for
    debugging. Detection uses ``model_selection=0`` and a minimum detection
    confidence of 0.75.
    """
    pixels = np.array(img)
    print(pixels.shape)
    detector_options = dict(model_selection=0, min_detection_confidence=0.75)
    with mp.solutions.face_detection.FaceDetection(**detector_options) as detector:
        detections = detector.process(pixels).detections
        return None if detections is None else detections[0]
def _crop_face(person_image, detected_face):
    """Crop *person_image* around *detected_face*, with margins added.

    Horizontally, half of a 0.5 x box-width margin is added on each side.
    Vertically, a 0.8 x box-height margin is split unevenly: 1/1.2 of it
    above the face and the remainder below. The box is clamped to the image;
    when the left/top edge is clamped to 0, right/bottom are still measured
    from the clamped edge, so the crop keeps its size and shifts inward.
    """
    bounding_box = detected_face.location_data.relative_bounding_box
    width, height = person_image.size
    width_margin = bounding_box.width * 0.5
    height_margin = bounding_box.height * 0.8
    left = max(0, (bounding_box.xmin - width_margin / 2.0) * width)
    top = max(0, (bounding_box.ymin - height_margin / 1.2) * height)
    right = min(width, left + (bounding_box.width + width_margin) * width)
    bottom = min(height, top + (bounding_box.height + height_margin) * height)
    # crop() returns a new image; the original is left unchanged.
    return person_image.crop((left, top, right, bottom))


def _choose_portrait(face_array):
    """Pick the Bouts reference portrait file matching the classifiers' output."""
    female_detected = is_female(face_array)
    ethnicity_prediction = ethnicity(face_array)
    if ethnicity_prediction == "Black":
        print("Colored person")
        if female_detected:
            print("This is a female portrait")
            return "bouts_fc1.jpg"
        return "bouts_mc1.jpg"
    if female_detected:
        print("This is a female portrait")
        return "bouts_f1.jpg"
    print("This is a male portrait, checking facial hair")
    if has_beard(face_array):
        print("The person has a beard")
        return "bouts_mb1.jpg"
    return "bouts_m1.jpg"


def _load_rgb(path):
    """Open the image at *path* as RGB and release the underlying file handle."""
    with Image.open(path) as im:
        # convert() loads the pixel data before the file is closed.
        return im.convert("RGB")


def boutsify(person):
    """Repaint the portrait in *person* in the style of Dirck Bouts.

    The largest-confidence face is detected and cropped, and a reference
    Bouts portrait is chosen from gender/ethnicity/beard predictions. The
    reference portrait, two background images and the face crop are then
    mixed with run_image_mixer.

    Args:
        person: the submitted image as an array accepted by
            ``Image.fromarray``, or None when nothing was submitted.

    Returns:
        The generated PIL image, or None if no image was given or no face
        was detected.
    """
    if person is None:
        print("No image provided")
        return None
    detected_face = find_face(person)
    if detected_face is None:
        print("Couldn't detect a face")
        return None

    face = _crop_face(Image.fromarray(person), detected_face)
    portrait_path = _choose_portrait(np.array(face))

    # Flat mixer arguments: n_inputs types, texts, images and strengths,
    # then guidance scale, n_samples, random seed and DDIM steps.
    inputs = [
        "Image", "Image", "Image", "Image", "Nothing",
        "", "", "", "", "",
        _load_rgb(portrait_path),
        _load_rgb("boutsback_o1.png"),
        _load_rgb("boutsback_o2.png"),
        face,
        "",
        1.1, 1, 1, 1.4, 1,
        3.0, 1, random.randrange(0, 10000), 50,
    ]
    return run_image_mixer(inputs)
import gradio
# Web UI: a single image input is passed to boutsify, and its returned image
# is shown as the output.
gradio_interface = gradio.Interface(
    fn=boutsify,
    inputs="image",
    outputs="image",
    title="Boutsify images",
    description="Turn portraits into a painting in the style of Flemish master Dirck Bouts",
    article="© iO Digital"
)
# Enable Gradio's request queue, then start the app server.
gradio_interface.queue().launch()