from io import BytesIO
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from torch import autocast
from contextlib import nullcontext
import requests
import functools
import random
import timm
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.extras import load_model_from_config, load_training_dir
import clip
from fastai.vision.all import *
import skimage
import mediapipe as mp
#import pathlib
#temp = pathlib.PosixPath
#pathlib.PosixPath = pathlib.WindowsPath
from huggingface_hub import hf_hub_download

ckpt = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-pruned.ckpt")
config = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-config.yaml")
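# Load the image-mixer model (a Stable Diffusion variant conditioned on CLIP
# image/text embeddings) and keep it on the GPU in half precision.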
device = "cuda:0"
model = load_model_from_config(config, ckpt, device=device, verbose=False)
model = model.to(device).half()

clip_model, preprocess = clip.load("ViT-L/14", device=device)
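# fastai classifiers used to choose the closest Dirck Bouts reference portrait.
# The .pkl files are expected to sit next to this script; prediction
# probabilities follow each learner's dls.vocab order.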
gender_learn = load_learner('gender_model.pkl')
gender_labels = gender_learn.dls.vocab

beard_learn = load_learner('facial_hair_model.pkl')
beard_labels = beard_learn.dls.vocab

ethnic_learn = load_learner('ethnic_model.pkl')
ethnic_labels = ethnic_learn.dls.vocab

n_inputs = 5

torch.cuda.empty_cache()
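# Fetch an image from a URL; used when a "Text/URL" conditioning slot
# actually contains a link.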
def get_url_im(t):
    user_agent = {'User-agent': 'gradio-app'}
    response = requests.get(t, headers=user_agent)
    return Image.open(BytesIO(response.content))
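# CLIP helpers: encode an image or a text prompt into the 768-dim ViT-L/14
# embedding space the mixer is conditioned on.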
def get_im_c(im, clip_model):
    prompts = preprocess(im).to(device).unsqueeze(0)
    return clip_model.encode_image(prompts).float()

def get_txt_c(txt, clip_model):
    text = clip.tokenize([txt,]).to(device)
    return clip_model.encode_text(text)

def get_txt_diff(txt1, txt2, clip_model):
    return get_txt_c(txt1, clip_model) - get_txt_c(txt2, clip_model)
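# Convert a batch of decoded samples in [-1, 1] to a list of PIL images.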
def to_im_list(x_samples_ddim):
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    ims = []
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        ims.append(Image.fromarray(x_sample.astype(np.uint8)))
    return ims
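# Run DDIM sampling in latent space with classifier-free guidance, then decode
# the latents through the first-stage autoencoder back to pixel space.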
def sample(sampler, model, c, uc, scale, start_code, h=512, w=512, precision="autocast", ddim_steps=50):
    ddim_eta = 0.0
    precision_scope = autocast if precision == "autocast" else nullcontext
    with precision_scope("cuda"):
        shape = [4, h // 8, w // 8]
        samples_ddim, _ = sampler.sample(S=ddim_steps,
                                         conditioning=c,
                                         batch_size=c.shape[0],
                                         shape=shape,
                                         verbose=False,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta,
                                         x_T=start_code)
        x_samples_ddim = model.decode_first_stage(samples_ddim)
    return to_im_list(x_samples_ddim)
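# run_image_mixer expects a flat argument list: n_inputs groups of
# (input type, text/URL, image, strength) laid out column-wise (all types,
# then all texts, then all images, then all strengths), followed by the
# guidance scale, number of samples, seed, and DDIM step count.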
def run_image_mixer(args):
    inps = []
    for i in range(0, len(args)-4, n_inputs):
        inps.append(args[i:i+n_inputs])

    scale, n_samples, seed, steps = args[-4:]
    h = w = 640

    sampler = DDIMSampler(model)
    # sampler = PLMSSampler(model)

    torch.manual_seed(seed)
    start_code = torch.randn(n_samples, 4, h//8, w//8, device=device)
    conds = []

    for b, t, im, s in zip(*inps):
        print(b, t, im, s)
        if b == "Image":
            this_cond = s*get_im_c(im, clip_model)
        elif b == "Text/URL":
            if t.startswith("http"):
                im = get_url_im(t)
                this_cond = s*get_im_c(im, clip_model)
            else:
                this_cond = s*get_txt_c(t, clip_model)
        else:
            this_cond = torch.zeros((1, 768), device=device)
        conds.append(this_cond)
    conds = torch.cat(conds, dim=0).unsqueeze(0)
    conds = conds.tile(n_samples, 1, 1)

    ims = sample(sampler, model, conds, 0*conds, scale, start_code, ddim_steps=steps)
    # return make_row(ims)

    # Clear the GPU memory cache so we are less likely to OOM
    torch.cuda.empty_cache()
    return ims[0]
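# Classifier wrappers. The index checks below assume a particular dls.vocab
# order (e.g. probs[0] corresponding to "female" for the gender model); verify
# against gender_labels / beard_labels if the models are retrained.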
def is_female(img):
    pred, pred_idx, probs = gender_learn.predict(img)
    return float(probs[0]) > float(probs[1])

def has_beard(img):
    pred, pred_idx, probs = beard_learn.predict(img)
    return float(probs[1]) > float(probs[0])

def ethnicity(img):
    pred, pred_idx, probs = ethnic_learn.predict(img)
    return pred
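# Detect faces with MediaPipe and return the first detection, or None if
# nothing is found above the confidence threshold.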
def find_face(img):
    mp_face_detection = mp.solutions.face_detection
    img_array = np.array(img)
    print(img_array.shape)
    with mp_face_detection.FaceDetection(
        model_selection=0,
        min_detection_confidence=0.75
    ) as face_detection:
        results = face_detection.process(img_array)

    if results.detections is not None:
        return results.detections[0]
    return None
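# Main pipeline: detect and crop the face, classify it to pick the closest
# Bouts reference portrait, then blend the crop with the reference portrait
# and two background paintings through the image mixer.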
def boutsify(person):
    detected_face = find_face(person)
    if detected_face is None:
        print("Couldn't detect a face")
        # Fall back to returning the input unchanged so the interface still
        # has an image to display.
        return person

    person_image = Image.fromarray(person)
    bounding_box = detected_face.location_data.relative_bounding_box
    width, height = person_image.size
    width_margin = bounding_box.width * 0.5
    height_margin = bounding_box.height * 0.8

    # Set the corner points for the cropped face, clamped to the image bounds
    left = max(0, (bounding_box.xmin - width_margin / 2.0) * width)
    top = max(0, (bounding_box.ymin - height_margin / 1.2) * height)
    right = min(width, left + (bounding_box.width + width_margin) * width)
    bottom = min(height, top + (bounding_box.height + height_margin) * height)

    # Crop the face region (this does not change the original image)
    face = person_image.crop((left, top, right, bottom))
    face_array = np.array(face)

    portrait_path = "bouts_m1.jpg"
    female_detected = is_female(face_array)
    ethnicity_prediction = ethnicity(face_array)

    if ethnicity_prediction == "Black":
        print("Predicted ethnicity: Black")
        if female_detected:
            print("This is a female portrait")
            portrait_path = "bouts_fc1.jpg"
        else:
            portrait_path = "bouts_mc1.jpg"
    else:
        if female_detected:
            print("This is a female portrait")
            portrait_path = "bouts_f1.jpg"
        else:
            print("This is a male portrait, checking facial hair")
            if has_beard(face_array):
                print("The person has a beard")
                portrait_path = "bouts_mb1.jpg"

    inputs = [
        "Image", "Image", "Image", "Image", "Nothing",
        "", "", "", "", "",
        Image.open(portrait_path).convert("RGB"),
        Image.open("boutsback_o1.png").convert("RGB"),
        Image.open("boutsback_o2.png").convert("RGB"),
        face,
        "",
        1.1, 1, 1, 1.4, 1,
        3.0, 1, random.randrange(0, 10000), 50,
    ]
    #return person
    return run_image_mixer(inputs)
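# Minimal Gradio wrapper: a single portrait image in, the "Boutsified" result out.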
import gradio

gradio_interface = gradio.Interface(
    fn=boutsify,
    inputs="image",
    outputs="image",
    title="Boutsify images",
    description="Turn portraits into a painting in the style of the Flemish master Dirck Bouts",
    article="© iO Digital",
)
gradio_interface.queue().launch()