Raphael committed on
Commit
154c8c5
1 Parent(s): c558ab6

Signed-off-by: Raphael <oOraph@users.noreply.github.com>

Files changed (2)
  1. app.py +141 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,141 @@
+ import logging
+ import os
+ import time
+
+ import cv2
+ from diffusers import StableDiffusionPipeline
+ import gradio as gr
+ import mediapipe as mp
+ import numpy as np
+ import PIL.Image
+ import torch.cuda
+ # from transformers import pipeline
+
+ os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
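+ # The hf_transfer backend must itself be installed for huggingface_hub to
+ # honor this flag; it speeds up large model downloads.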
+
+
+ logging.basicConfig(level=logging.INFO,
+                     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+                     force=True)
+
+ LOG = logging.getLogger(__name__)
+
+ LOG.info("Loading image segmentation model")
+
+ # seg_kwargs = {
+ #     "task": "image-segmentation",
+ #     "model": "nvidia/segformer-b0-finetuned-ade-512-512"
+ # }
+ #
+ # img_segmentation = pipeline(**seg_kwargs)
+
+
+ mp_selfie_segmentation = mp.solutions.selfie_segmentation
+ img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)
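+ # model_selection=0 picks MediaPipe's general-purpose segmentation model;
+ # 1 would pick the landscape variant intended for wider framing.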
+
+
+ LOG.info("Loading diffusion model")
+
+ diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+
+ if torch.cuda.is_available():
+     LOG.info("Moving diffusion model to GPU")
+     diffusion.to('cuda')
+
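+ # Possible memory saving (not applied here): from_pretrained(...,
+ # torch_dtype=torch.float16) roughly halves GPU memory usage.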
+
+ def image_preprocess(image: PIL.Image.Image):
+     LOG.info("Preprocessing image %s", image)
+     start = time.time()
+     # image = PIL.ImageOps.exif_transpose(image)
+     image = image.convert("RGB")
+     image = resize_image(image)
+     image = np.array(image)
+     # Convert RGB to BGR (OpenCV convention)
+     image = image[:, :, ::-1].copy()
+     elapsed = time.time() - start
+     LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
+     return image
+
+
+ def resize_image(image: PIL.Image.Image):
+     width, height = image.size
+     # Scale the longest side to 512 px, then round both sides down to
+     # multiples of 8, as the Stable Diffusion UNet/VAE require
+     ratio = max(width / 512, height / 512)
+     width = int(width / ratio) // 8 * 8
+     height = int(height / ratio) // 8 * 8
+     image = image.resize((width, height))
+     return image
+
+
+ def extract_selfie_mask(threshold, image):
+     LOG.info("Extracting selfie mask")
+     start = time.time()
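+     # Note: MediaPipe's reference examples feed RGB frames; this app passes
+     # the BGR frame produced by image_preprocess as-is.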
+     results = img_segmentation_model.process(image)
+     mask = results.segmentation_mask
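+     # Binarize the soft confidence map, grow the person region slightly, then
+     # feather the edge with a box blur so the composite blends smoothly: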
+     cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
+     cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
+     cv2.blur(mask, (10, 10), dst=mask)
+
+     elapsed = time.time() - start
+     LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
+     return mask
+
+
+ def generate_background(prompt, num_inference_steps, height, width):
+     LOG.info("Generating background")
+     start = time.time()
+     background = diffusion(
+         prompt=prompt,
+         num_inference_steps=int(num_inference_steps),
+         height=height,
+         width=width
+     )
+     nsfw = background.nsfw_content_detected[0]
+     background = background.images[0]
+
+     if nsfw:
+         LOG.info('NSFW detected, skipping')
+         # Black placeholder with three channels, matching the BGR selfie
+         # it will be blended with
+         background = np.zeros((height, width, 3), dtype='uint8')
+     else:
+         background = np.array(background)
+         # Convert RGB to BGR (OpenCV convention)
+         background = background[:, :, ::-1].copy()
+
+     elapsed = time.time() - start
+     LOG.info("Background generated, %.2f seconds elapsed", elapsed)
+     return background
+
+
+ def merge_selfie_and_background(selfie, background, mask):
+     LOG.info("Merging extracted selfie and generated background")
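+     # Per-pixel blend: selfie * mask + background * (1 - mask), written in place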
+     cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
+     selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
+     selfie = PIL.Image.fromarray(selfie)
+     return selfie
+
+
+ def demo(threshold, image, prompt, num_inference_steps):
+     image = image_preprocess(image)
+     mask = extract_selfie_mask(threshold, image)
+     # numpy arrays are (height, width, channels)
+     background = generate_background(prompt, num_inference_steps,
+                                      image.shape[0], image.shape[1])
+     output = merge_selfie_and_background(image, background, mask)
+     return output
+
+
+ iface = gr.Interface(
+     fn=demo,
+     inputs=[
+         gr.Slider(minimum=0.1, maximum=1, step=0.05, label="Selfie segmentation threshold",
+                   value=0.8),
+         gr.Image(type='pil', label="Upload your selfie"),
+         gr.Text(value="a photo of the Eiffel tower on the right side",
+                 label="Background description"),
+         gr.Slider(minimum=5, maximum=100, step=5, label="Diffusion inference steps",
+                   value=50)
+     ],
+     outputs=[
+         gr.Image(label="Invent yourself a life :)")
+     ])
+
+ # iface.launch(server_name="0.0.0.0", server_port=6443)
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio
+ opencv-python
+ pillow
+ timm
+ mediapipe
+ diffusers
+ transformers
+ scipy
+ ftfy
+ accelerate
+ torch
+ numpy