File size: 5,061 Bytes
9c39b3b
 
 
 
0f9d2e0
9c39b3b
 
edc894e
9c39b3b
 
 
 
 
942ca44
 
edc894e
 
 
72adce4
dd7f377
2fc1ecf
 
edc894e
2fc1ecf
 
 
 
 
 
 
 
 
98a2cec
 
edc894e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98a2cec
 
dd7f377
 
98a2cec
dd7f377
 
 
9a7567c
98a2cec
dd7f377
98a2cec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c39b3b
 
 
 
 
b03281e
5ac0b57
2fc1ecf
9c39b3b
 
 
 
 
 
 
64406bc
2fce01c
 
64406bc
 
 
 
 
 
9c39b3b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import cv2
import tempfile
import inspect
from typing import List, Optional, Union
import os
import numpy as np
import torch
import banana_dev as banana
import PIL
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import gradio as gr
import random
import base64
from io import BytesIO
import os
from PIL import Image
import face_recognition
import whatimage
import pyheif

def decodeImage(bytesIo, filename):
  """Re-encode a HEIC/AVIF upload as JPEG so that PIL can open it.

  Despite its name, ``bytesIo`` is the raw file content; ``main`` passes the
  ``bytes`` returned by ``f.read()``. When ``whatimage`` identifies the data as
  HEIC or AVIF, it is decoded with ``pyheif`` and written back to ``filename``
  as JPEG. Any other format is left untouched.

  Args:
    bytesIo: Raw image bytes.
    filename: Path that receives the JPEG re-encoding (overwritten in place).
  """
  fmt = whatimage.identify_image(bytesIo)
  print(fmt)
  if fmt not in ('heic', 'avif'):
    return
  heif = pyheif.read_heif(bytesIo)
  # pyheif rows may be padded. Use the "raw" decoder with the row stride;
  # otherwise Image.frombytes assumes tightly packed rows and skews the image.
  pi = Image.frombytes(
      heif.mode, heif.size, heif.data, "raw", heif.mode, heif.stride)
  print(pi)
  # JPEG cannot store an alpha channel, and HEIC images with alpha decode as
  # mode "RGBA", so convert before saving.
  pi.convert("RGB").save(filename, format="jpeg")
  
def _encode_jpeg_base64(image):
    """Return *image* serialized as JPEG and base64-encoded as a UTF-8 str."""
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def inpaint(p, init_image, mask_image=None, strength=0.75, guidance_scale=7.5, generator=None, num_samples=1, n_iter=1):
    """Send an image and its mask to the remote banana.dev model; return the result.

    Both images are JPEG-encoded and base64-encoded, then posted with
    ``banana.run`` using the ``API_KEY`` and ``MODEL_KEY`` environment
    variables. The first output image is decoded and returned as a PIL image.

    NOTE(review): the request payload is fixed (prompt, strength=0.65,
    guidance_scale=10, 100 steps). The ``p``, ``strength``,
    ``guidance_scale``, ``generator``, ``num_samples`` and ``n_iter``
    arguments are accepted for signature compatibility but are not used.

    Args:
        init_image: PIL image to inpaint (must be JPEG-encodable, e.g. RGB).
        mask_image: PIL mask image (must be JPEG-encodable). Required.

    Returns:
        The PIL image decoded from the model response.

    Raises:
        ValueError: If ``mask_image`` is None.
        RuntimeError: If the credentials are missing or the response lacks an
            output image.
    """
    if mask_image is None:
        raise ValueError("inpaint() requires a mask_image")
    api_key = os.environ.get("API_KEY")
    model_key = os.environ.get("MODEL_KEY")
    if not api_key or not model_key:
        raise RuntimeError("API_KEY and MODEL_KEY environment variables must be set")

    model_inputs = {
      "prompt": "4K UHD professional profile picture of a person wearing a suit for work and posing for a picture, fine details, realistic shaded.",
      "init_image": _encode_jpeg_base64(init_image),
      "mask_image": _encode_jpeg_base64(mask_image),
      "strength": 0.65,
      "guidance_scale": 10,
      "num_inference_steps": 100
    }
    out = banana.run(api_key, model_key, model_inputs)
    try:
        image_byte_string = out["modelOutputs"][0]["output_image_base64"]
    except (KeyError, IndexError, TypeError) as err:
        raise RuntimeError(f"Unexpected response from banana.run: {out!r}") from err
    image_bytes = BytesIO(base64.b64decode(image_byte_string))
    return Image.open(image_bytes)
    
def identify_face(user_image):
  """Build an inverted mask around the first face detected in the upload.

  Loads the image at ``user_image.name`` and runs face detection. If at
  least one face is found, the function returns a uint8 array with the
  image's height and width. In that array, the first face's bounding box is
  0 and every other pixel is 255. Any faces after the first are ignored.

  Returns:
    The inverted mask array, or None when no face is detected.
  """
  pixels = face_recognition.load_image_file(user_image.name)
  print(pixels.shape)
  faces = face_recognition.face_locations(pixels)
  if not faces:
    return None
  top, right, bottom, left = faces[0]
  face_mask = np.zeros(pixels.shape[:2], dtype="uint8")
  print(face_mask.shape)
  # A thickness of -1 fills the rectangle covering the face's bounding box.
  cv2.rectangle(face_mask, (left, top), (right, bottom), 255, -1)
  return cv2.bitwise_not(face_mask)

def sample_images(init_image, mask_image):
  """Run one inpainting request for *init_image* / *mask_image*; return the image.

  Builds the sampling settings and a randomly seeded torch generator, then
  delegates to ``inpaint``, which performs the request via ``banana.run``.
  """
  p = "4K UHD professional profile picture of a person wearing a suit for work"
  strength = 0.65
  guidance_scale = 10
  num_samples = 1
  n_iter = 1

  # Inference runs remotely, so no local GPU is required.
  # torch.Generator(device="cuda") raises on CPU-only hosts, so fall back to CPU.
  device = "cuda" if torch.cuda.is_available() else "cpu"
  # Change the seed to get different results.
  generator = torch.Generator(device=device).manual_seed(random.randint(0, 1000000))
  return inpaint(p, init_image, mask_image, strength=strength, guidance_scale=guidance_scale, generator=generator, num_samples=num_samples, n_iter=n_iter)

def preprocess_image(image):
    """Convert a PIL image into a float32 tensor in [-1, 1] shaped (1, C, H, W).

    Width and height are first rounded down to multiples of 32, and the image
    is resized with Lanczos resampling. The image is assumed to have a channel
    axis (e.g. RGB). A single-band image yields a 2-D array, and the transpose
    then fails.
    """
    width, height = (side - side % 32 for side in image.size)
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    # (1, H, W, C) -> (1, C, H, W)
    batch = np.transpose(pixels[np.newaxis], (0, 3, 1, 2))
    return torch.from_numpy(batch) * 2.0 - 1.0

def preprocess_mask(mask):
    """Turn a PIL mask into a (1, 4, h/8, w/8) float32 tensor of keep-weights.

    The mask is converted to grayscale, and its size is rounded down to
    multiples of 32 and then divided by 8 (presumably the latent resolution —
    confirm). It is resized with nearest-neighbour sampling, scaled to [0, 1],
    repeated over 4 channels and inverted. White (repaint) becomes 0 and
    black (keep) becomes 1.
    """
    gray = mask.convert("L")
    width, height = (side - side % 32 for side in gray.size)
    small = gray.resize((width // 8, height // 8), resample=PIL.Image.NEAREST)
    values = np.array(small).astype(np.float32) / 255.0
    # Repeat the single band 4 times and add a leading batch axis.
    stacked = np.tile(values, (4, 1, 1))[np.newaxis]
    # Repaint white, keep black.
    return torch.from_numpy(1 - stacked)

def main(user_image):
  """Gradio handler: inpaint everything around the user's face.

  Steps:
    1. Re-encode HEIC/AVIF uploads in place (``decodeImage``).
    2. Resize the image to 512x512 and save it back to the upload path.
    3. Detect the first face and build an inverted mask (``identify_face``).
    4. Send the image and mask for inpainting (``sample_images``).

  Args:
    user_image: Uploaded file object from gradio. Only its ``.name`` path is used
      here; the object itself is passed on to ``identify_face``.

  Returns:
    A PIL image: the inpainted result, or the resized input when no face is
    detected.
  """
  path = user_image.name
  with open(path, 'rb') as f:
    data = f.read()
  decodeImage(data, path)
  # Open by path, like every other step, and close the handle promptly.
  with PIL.Image.open(path) as source:
    init_image = source.convert("RGB")
  init_image = init_image.resize((512, 512))
  # Save with an explicit format: the upload may keep a ".heic"/".avif"
  # extension that PIL cannot map to a writer. PNG is lossless.
  # NOTE(review): this assumes face_recognition.load_image_file detects the
  # format from content (it wraps PIL.Image.open) rather than the extension.
  init_image.save(path, format="PNG")
  inverted_mask = identify_face(user_image)
  # identify_face returns a numpy array or None. `== None` on an array is
  # element-wise, so using it in `if` raises ValueError; compare by identity.
  if inverted_mask is None:
    return init_image
  # Convert the single-channel uint8 mask in memory instead of writing it to
  # a temporary PNG file that was never closed.
  pil_inverted_mask = Image.fromarray(inverted_mask).convert("RGB")
  return sample_images(init_image, pil_inverted_mask)

# Single-input Gradio UI: an uploaded image goes in, the processed image comes out.
demo = gr.Interface(fn=main, inputs=gr.Image(type="file"), outputs="image")
demo.launch(debug=True)