mikonvergence commited on
Commit
aca81a2
1 Parent(s): 21863d4

main src files

Browse files
Files changed (5) hide show
  1. src/__init__.py +0 -0
  2. src/detection.py +54 -0
  3. src/masking.py +89 -0
  4. src/process.py +36 -0
  5. src/synthesis.py +53 -0
src/__init__.py ADDED
File without changes
src/detection.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import kornia as K
4
+ from kornia.core import Tensor
5
+ from kornia.contrib import FaceDetector, FaceDetectorResult, FaceKeypoint
6
+
7
+ print('Loading Face Detector...')
8
+ face_detection = FaceDetector()
9
+ print('DONE')
10
+
11
+ def detect_face(input):
12
+
13
+ # Preprocessing
14
+ img = K.image_to_tensor(np.array(input), keepdim=False)
15
+ img = K.color.bgr_to_rgb(img.float())
16
+
17
+ with torch.no_grad():
18
+ dets = face_detection(img)
19
+
20
+ return [FaceDetectorResult(o) for o in dets[0]]
21
+
22
+ def process_face(dets):
23
+ vis_threshold = 0.8
24
+ faces = []
25
+ hairs = []
26
+
27
+ for b in dets:
28
+ if b.score < vis_threshold:
29
+ continue
30
+
31
+ reye_kpt=b.get_keypoint(FaceKeypoint.EYE_RIGHT).int().tolist()
32
+ leye_kpt=b.get_keypoint(FaceKeypoint.EYE_LEFT).int().tolist()
33
+ rmou_kpt=b.get_keypoint(FaceKeypoint.MOUTH_RIGHT).int().tolist()
34
+ lmou_kpt=b.get_keypoint(FaceKeypoint.MOUTH_LEFT).int().tolist()
35
+ nose_kpt=b.get_keypoint(FaceKeypoint.NOSE).int().tolist()
36
+
37
+ faces.append([nose_kpt,
38
+ rmou_kpt,
39
+ lmou_kpt,
40
+ reye_kpt,
41
+ leye_kpt
42
+ ])
43
+
44
+ # point above
45
+ top=((b.top_right + b.top_left)/2).int().tolist()
46
+ bot=((b.bottom_right + b.bottom_left)/2).int().tolist()
47
+ face_h = np.abs(top[1]-bot[1])
48
+ top_margin=[top[0], top[1]-face_h*0.1]
49
+
50
+ hairs.append([
51
+ top_margin
52
+ ])
53
+
54
+ return faces, hairs
src/masking.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from kornia.morphology import dilation, closing
3
+ import requests
4
+ from transformers import SamModel, SamProcessor
5
+
6
+ print('Loading SAM...')
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
9
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
10
+ print('DONE')
11
+
12
+ def build_mask(image, faces, hairs):
13
+
14
+ # 1. Segmentation
15
+ input_points = faces # 2D location of the face
16
+
17
+ with torch.no_grad():
18
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
19
+ outputs = model(**inputs)
20
+
21
+ masks = processor.image_processor.post_process_masks(
22
+ outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
23
+ )
24
+ scores = outputs.iou_scores
25
+
26
+ input_points = hairs # 2D location of the face
27
+
28
+ with torch.no_grad():
29
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
30
+ outputs = model(**inputs)
31
+
32
+ h_masks = processor.image_processor.post_process_masks(
33
+ outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
34
+ )
35
+ h_scores = outputs.iou_scores
36
+
37
+ # 2. Post-processing
38
+ mask=masks[0][0].all(0) | h_masks[0][0].all(0)
39
+
40
+ # dilation
41
+ tensor = mask[None,None,:,:]
42
+ kernel = torch.ones(3, 3)
43
+ mask = closing(tensor, kernel)[0,0].bool()
44
+
45
+ return mask
46
+
47
+ def build_mask_multi(image, faces, hairs):
48
+
49
+ all_masks = []
50
+
51
+ for face,hair in zip(faces,hairs):
52
+ # 1. Segmentation
53
+ input_points = [face] # 2D location of the face
54
+
55
+ with torch.no_grad():
56
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
57
+ outputs = model(**inputs)
58
+
59
+ masks = processor.image_processor.post_process_masks(
60
+ outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
61
+ )
62
+ scores = outputs.iou_scores
63
+
64
+ input_points = [hair] # 2D location of the face
65
+
66
+ with torch.no_grad():
67
+ inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
68
+ outputs = model(**inputs)
69
+
70
+ h_masks = processor.image_processor.post_process_masks(
71
+ outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
72
+ )
73
+ h_scores = outputs.iou_scores
74
+
75
+ # 2. Post-processing
76
+ mask=masks[0][0].all(0) | h_masks[0][0].all(0)
77
+
78
+ # dilation
79
+ mask_T = mask[None,None,:,:]
80
+ kernel = torch.ones(3, 3)
81
+ mask = closing(mask_T, kernel)[0,0].bool()
82
+
83
+ all_masks.append(mask)
84
+
85
+ mask = all_masks[0]
86
+ for next_mask in all_masks[1:]:
87
+ mask = mask | next_mask
88
+
89
+ return mask
src/process.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torch
6
+
7
+ from .detection import *
8
+ from .masking import *
9
+ from .synthesis import *
10
+
11
+ def forward(image_cam, image_upload, prompt="", n_prompt=None, num_steps=20, seed=0, original_resolution=False):
12
+
13
+ if image_cam is None:
14
+ image = image_upload
15
+ else:
16
+ image = image_cam
17
+
18
+ if not original_resolution:
19
+ w,h = image.size
20
+ ratio = 512/h
21
+ new_size = int(w*ratio), int(h*ratio)
22
+ image = image.resize(new_size)
23
+
24
+ # detect face
25
+ dets = detect_face(image)
26
+
27
+ # segment hair and face
28
+ faces, hairs = process_face(dets)
29
+
30
+ # build mask
31
+ mask = build_mask_multi(image, faces, hairs)
32
+
33
+ # synthesise
34
+ new_image = synthesis(image,mask, prompt, n_prompt, num_steps=num_steps, seed=seed)
35
+
36
+ return new_image
src/synthesis.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from controlnet_aux import OpenposeDetector
2
+ from diffusers import StableDiffusionInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
3
+ from src.ControlNetInpaint.src.pipeline_stable_diffusion_controlnet_inpaint import *
4
+ from kornia.filters import gaussian_blur2d
5
+
6
+ if not 'controlnet' in globals():
7
+ print('Loading ControlNet...')
8
+ controlnet = ControlNetModel.from_pretrained(
9
+ "fusing/stable-diffusion-v1-5-controlnet-openpose", torch_dtype=torch.float16
10
+ )
11
+
12
+ if 'pipe' not in globals():
13
+ print('Loading SD...')
14
+ pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
15
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
16
+ ).to('cuda')
17
+ print('DONE')
18
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
19
+
20
+ if 'openpose' not in globals():
21
+ print('Loading OpenPose...')
22
+ openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
23
+ print('DONE')
24
+
25
+ def synthesis(image, mask, prompt="", n_prompt="", num_steps=20, seed=0, remix=True):
26
+
27
+ # 1. Get pose
28
+ with torch.no_grad():
29
+ pose_image = openpose(image)
30
+ pose_image=pose_image.resize(image.size)
31
+
32
+ # generate image
33
+ generator = torch.manual_seed(seed)
34
+ new_image = pipe(
35
+ prompt,
36
+ negative_prompt = n_prompt,
37
+ generator=generator,
38
+ num_inference_steps=num_steps,
39
+ image=image,
40
+ control_image=pose_image,
41
+ mask_image=(mask==False).float().numpy(),
42
+ ).images
43
+
44
+ if remix:
45
+ for idx in range(len(new_image)):
46
+ mask = gaussian_blur2d(1.0*mask[None,None,:,:],
47
+ kernel_size=(11, 11),
48
+ sigma=(29, 29)
49
+ ).squeeze().clip(0,1)
50
+
51
+ new_image[idx] = (mask[:,:,None]*np.asarray(image) + (1-mask[:,:,None])*np.asarray(new_image[idx].resize(image.size))).int().numpy()
52
+
53
+ return new_image