# fastai star import (provides load_learner) comes first so that the explicit
# imports below keep their intended bindings.
from fastai.vision.all import *

from io import BytesIO
import functools
import random
from contextlib import nullcontext

import clip
import gradio
import numpy as np
import requests
import timm  # kept as in the original; the pickled fastai learners may rely on timm architectures
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import autocast

from ldm.extras import load_model_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

# Download the Image Mixer checkpoint and its config from the Hugging Face Hub.
ckpt = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-pruned.ckpt", cache_dir="/data/.cache")
config = hf_hub_download(repo_id="lambdalabs/image-mixer", filename="image-mixer-config.yaml", cache_dir="/data/.cache")

device = "cuda:0"
model = load_model_from_config(config, ckpt, device=device, verbose=False)
model = model.to(device).half()

# CLIP encoder that turns images and text into 768-dim conditioning embeddings.
clip_model, preprocess = clip.load("ViT-L/14", device=device)

# fastai classifiers used to pick a matching Bouts reference portrait.
gender_learn = load_learner('gender_model.pkl')
gender_labels = gender_learn.dls.vocab
beard_learn = load_learner('facial_hair_model.pkl')
beard_labels = beard_learn.dls.vocab
ethnic_learn = load_learner('ethnic_model.pkl')
ethnic_labels = ethnic_learn.dls.vocab

# Number of conditioning slots; each slot has a type, a text/URL, an image and a strength.
n_inputs = 5

torch.cuda.empty_cache()


@functools.lru_cache()
def get_url_im(t):
    """Download an image from a URL (cached per URL)."""
    user_agent = {'User-agent': 'gradio-app'}
    response = requests.get(t, headers=user_agent)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


@torch.no_grad()
def get_im_c(im, clip_model):
    """CLIP image embedding of shape (1, 768), as float32."""
    prompts = preprocess(im).to(device).unsqueeze(0)
    return clip_model.encode_image(prompts).float()


@torch.no_grad()
def get_txt_c(txt, clip_model):
    """CLIP text embedding of shape (1, 768), cast to float32 to match get_im_c."""
    text = clip.tokenize([txt]).to(device)
    return clip_model.encode_text(text).float()


def get_txt_diff(txt1, txt2, clip_model):
    return get_txt_c(txt1, clip_model) - get_txt_c(txt2, clip_model)


def to_im_list(x_samples_ddim):
    """Convert decoded samples in [-1, 1] with layout (N, C, H, W) to a list of PIL images."""
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    ims = []
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        ims.append(Image.fromarray(x_sample.astype(np.uint8)))
    return ims


@torch.no_grad()
def sample(sampler, model, c, uc, scale, start_code, h=512, w=512, precision="autocast", ddim_steps=50):
    ddim_eta = 0.0
    precision_scope = autocast if precision == "autocast" else nullcontext
    with precision_scope("cuda"):
        shape = [4, h // 8, w // 8]  # latent shape: 4 channels at 1/8 of the image resolution
        samples_ddim, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=c.shape[0],
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            x_T=start_code,
        )
        x_samples_ddim = model.decode_first_stage(samples_ddim)
    return to_im_list(x_samples_ddim)


def run_image_mixer(args):
    """Mix up to n_inputs conditioning slots into one generated image.

    `args` is a flat list: n_inputs slot types, n_inputs texts/URLs,
    n_inputs images and n_inputs strengths, followed by
    scale, n_samples, seed and steps.
    """
    inps = []
    for i in range(0, len(args) - 4, n_inputs):
        inps.append(args[i:i + n_inputs])
    scale, n_samples, seed, steps = args[-4:]
    h = w = 640

    sampler = DDIMSampler(model)
    # sampler = PLMSSampler(model)

    torch.manual_seed(seed)
    start_code = torch.randn(n_samples, 4, h // 8, w // 8, device=device)
    conds = []

    for b, t, im, s in zip(*inps):
        print(b, t, im, s)
        if b == "Image":
            this_cond = s * get_im_c(im, clip_model)
        elif b == "Text/URL":
            if t.startswith("http"):
                im = get_url_im(t)
                this_cond = s * get_im_c(im, clip_model)
            else:
                this_cond = s * get_txt_c(t, clip_model)
        else:
            # "Nothing": an all-zero embedding acts as an empty slot
            this_cond = torch.zeros((1, 768), device=device)
        conds.append(this_cond)
    conds = torch.cat(conds, dim=0).unsqueeze(0)
    conds = conds.tile(n_samples, 1, 1)

    # Pass h and w so the sampler's shape matches the 640x640 start_code.
    ims = sample(sampler, model, conds, 0 * conds, scale, start_code, h=h, w=w, ddim_steps=steps)

    # Clear the GPU memory cache to make out-of-memory errors less likely
    torch.cuda.empty_cache()
    return ims[0]


def is_female(img):
    # Assumes class index 0 of gender_labels is the female class.
    pred, pred_idx, probs = gender_learn.predict(img)
    return float(probs[0]) > float(probs[1])


def has_beard(img):
    # Assumes class index 1 of beard_labels is the bearded class.
    pred, pred_idx, probs = beard_learn.predict(img)
    return float(probs[1]) > float(probs[0])


def ethnicity(img):
    pred, pred_idx, probs = ethnic_learn.predict(img)
    return pred


def boutsify(person):
    """Blend the uploaded portrait (a NumPy array from Gradio) with a matching Bouts portrait."""
    portrait_path = "bouts_m1.jpg"  # default: male portrait without a beard
    female_detected = is_female(person)
    ethnicity_prediction = ethnicity(person)

    if ethnicity_prediction == "Black":
        print("Predicted ethnicity: Black")
        if female_detected:
            print("This is a female portrait")
            portrait_path = "bouts_fc1.jpg"
        else:
            portrait_path = "bouts_mc1.jpg"
    else:
        if female_detected:
            print("This is a female portrait")
            portrait_path = "bouts_f1.jpg"
        else:
            print("This is a male portrait, checking facial hair")
            if has_beard(person):
                print("The person has a beard")
                portrait_path = "bouts_mb1.jpg"

    person_image = Image.fromarray(person)

    inputs = [
        # Slot types
        "Image", "Image", "Image", "Image", "Nothing",
        # Text/URL per slot (unused)
        "", "", "", "", "",
        # Image per slot: reference portrait, boutsback_o1, boutsback_o2, uploaded photo, empty
        Image.open(portrait_path).convert("RGB"),
        Image.open("boutsback_o1.png").convert("RGB"),
        Image.open("boutsback_o2.png").convert("RGB"),
        person_image,
        "",
        # Strength per slot
        1.1, 1, 1, 1.4, 1,
        # Guidance scale, number of samples, seed, DDIM steps
        3.0, 1, random.randrange(0, 10000), 50,
    ]
    return run_image_mixer(inputs)


gradio_interface = gradio.Interface(
    fn=boutsify,
    inputs="image",
    outputs="image",
    title="Boutsify images",
    description="Turn portraits into a painting in the style of Flemish master Dirck Bouts",
    article="© iO Digital",
)
gradio_interface.queue().launch()