miguelmuzo committed on
Commit 3d5b2b1 · verified · 1 Parent(s): 3de0e37

Update app.py

Files changed (1):
  1. app.py +313 -70
app.py CHANGED
@@ -1,151 +1,394 @@
Previous version of app.py (removed lines marked with -):

  import hashlib
  import os
  from io import BytesIO

  import gradio as gr
- import grpc
  from PIL import Image
  from cachetools import LRUCache

- from inference_pb2 import HairSwapRequest, HairSwapResponse
- from inference_pb2_grpc import HairSwapServiceStub
  from utils.shape_predictor import align_face


  def get_bytes(img):
      if img is None:
          return img
-
      buffered = BytesIO()
      img.save(buffered, format="JPEG")
      return buffered.getvalue()


  def bytes_to_image(image: bytes) -> Image.Image:
      image = Image.open(BytesIO(image))
      return image


  def center_crop(img):
      width, height = img.size
      side = min(width, height)
-
      left = (width - side) / 2
      top = (height - side) / 2
      right = (width + side) / 2
      bottom = (height + side) / 2
-
      img = img.crop((left, top, right, bottom))
      return img


  def resize(name):
      def resize_inner(img, align):
          global align_cache
-
          if name in align:
              img_hash = hashlib.md5(get_bytes(img)).hexdigest()
-
              if img_hash not in align_cache:
-                 img = align_face(img, return_tensors=False)[0]
-                 align_cache[img_hash] = img
              else:
                  img = align_cache[img_hash]
-
          elif img.size != (1024, 1024):
              img = center_crop(img)
              img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
-
          return img
-
      return resize_inner


- def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
      if not face and not shape and not color:
-         return gr.update(visible=False), gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True)
      elif not face:
-         return gr.update(visible=False), gr.update(value="Need to upload a face ❗", visible=True)
      elif not shape and not color:
-         return gr.update(visible=False), gr.update(value="Need to upload at least a shape or color ❗", visible=True)
-
-     face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item), (face, shape, color))
-
-     if shape_bytes is None:
-         shape_bytes = b'face'
-     if color_bytes is None:
-         color_bytes = b'shape'
-
-     with grpc.insecure_channel(os.environ['SERVER']) as channel:
-         stub = HairSwapServiceStub(channel)
-
-         output: HairSwapResponse = stub.swap(
-             HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
-                             poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
          )

-     output = bytes_to_image(output.image)
-     return gr.update(value=output, visible=True), gr.update(visible=False)


  def get_demo():
-     with gr.Blocks() as demo:
-         gr.Markdown("## HairFastGan")
-         gr.Markdown(
-             '<div style="display: flex; align-items: center; gap: 10px;">'
-             '<span>Official HairFastGAN Gradio demo:</span>'
-             '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
-             '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
-             '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
-             '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
-             '</div>'
-         )
          with gr.Row():
              with gr.Column():
-                 source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                  with gr.Row():
-                     shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
-                     color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
-                 with gr.Accordion("Advanced Options", open=False):
-                     blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
-                                         label="Color Encoder version", info="Selects a model for hair color transfer.")
-                     poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
-                                               info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
-                     poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
-                                                 info="Smooths out the blending area.")
-                     align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
-                                              label="Image cropping [recommended]",
-                                              info="Selects which images to crop by face")
-                 btn = gr.Button("Get the haircut")
              with gr.Column():
-                 output = gr.Image(label="Your result")
-                 error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")
-
-         gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
-                               ["input/10.jpg", None, "input/11.jpg"]],
-                     inputs=[source, shape, color], outputs=output)

          source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
          shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
          color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

-         btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
-                   outputs=[output, error_message])

-         gr.Markdown('''To cite the paper by the authors
-         ```
          @article{nikolaev2024hairfastgan,
            title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
            author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
            journal={arXiv preprint arXiv:2404.01094},
            year={2024}
          }
-         ```
          ''')
      return demo


  if __name__ == '__main__':
      align_cache = LRUCache(maxsize=10)
      demo = get_demo()
-     demo.launch(server_name="0.0.0.0", server_port=7860)
Updated version of app.py (added lines marked with +):

  import hashlib
  import os
  from io import BytesIO
+ import base64

  import gradio as gr
  from PIL import Image
  from cachetools import LRUCache
+ import torch
+ import numpy as np
+
+ # Direct HairFast imports (no gRPC needed!)
+ try:
+     from hair_swap import HairFast, get_parser
+     HAIRFAST_AVAILABLE = True
+     print("✅ HairFast successfully imported!")
+ except ImportError as e:
+     print(f"❌ HairFast import failed: {e}")
+     HAIRFAST_AVAILABLE = False

  from utils.shape_predictor import align_face

+ # Global variables
+ hair_fast_model = None
+ align_cache = LRUCache(maxsize=10)
+
+
+ def initialize_hairfast():
+     """Initialize HairFast model"""
+     global hair_fast_model
+
+     if not HAIRFAST_AVAILABLE:
+         print("❌ HairFast not available")
+         return False
+
+     try:
+         print("🔄 Initializing HairFast model...")
+
+         # Get default arguments
+         parser = get_parser()
+         args = parser.parse_args([])  # Use default arguments
+
+         # Override some settings for HF Spaces
+         args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         args.batch_size = 1  # Keep small for HF Spaces
+
+         # Initialize HairFast
+         hair_fast_model = HairFast(args)
+
+         print(f"✅ HairFast initialized successfully on {args.device}!")
+         return True
+
+     except Exception as e:
+         print(f"❌ HairFast initialization failed: {e}")
+         hair_fast_model = None
+         return False
+

  def get_bytes(img):
+     """Convert PIL Image to bytes"""
      if img is None:
          return img
      buffered = BytesIO()
      img.save(buffered, format="JPEG")
      return buffered.getvalue()


  def bytes_to_image(image: bytes) -> Image.Image:
+     """Convert bytes to PIL Image"""
      image = Image.open(BytesIO(image))
      return image


+ def base64_to_image(base64_string):
+     """Convert base64 string to PIL Image"""
+     try:
+         if base64_string.startswith('data:image'):
+             base64_string = base64_string.split(',')[1]
+         image_bytes = base64.b64decode(base64_string)
+         return Image.open(BytesIO(image_bytes))
+     except Exception as e:
+         print(f"Error converting base64 to image: {e}")
+         return None
+
+
+ def image_to_base64(image):
+     """Convert PIL Image to base64 string"""
+     if image is None:
+         return None
+     buffered = BytesIO()
+     image.save(buffered, format="JPEG")
+     img_bytes = buffered.getvalue()
+     img_base64 = base64.b64encode(img_bytes).decode('utf-8')
+     return f"data:image/jpeg;base64,{img_base64}"
+
+
+ def pil_to_tensor(image):
+     """Convert PIL to tensor for HairFast"""
+     if isinstance(image, Image.Image):
+         # Convert to tensor format expected by HairFast
+         image_array = np.array(image)
+         if image_array.max() > 1:
+             image_array = image_array / 255.0
+         tensor = torch.from_numpy(image_array).permute(2, 0, 1).float()
+         return tensor
+     return image
+
+
+ def tensor_to_pil(tensor):
+     """Convert tensor to PIL Image"""
+     if isinstance(tensor, torch.Tensor):
+         if tensor.dim() == 4:
+             tensor = tensor.squeeze(0)
+         if tensor.dim() == 3:
+             tensor = tensor.permute(1, 2, 0)
+         tensor = tensor.detach().cpu().numpy()
+         if tensor.max() <= 1:
+             tensor = (tensor * 255).astype(np.uint8)
+         return Image.fromarray(tensor)
+     return tensor
+
+
  def center_crop(img):
+     """Center crop image to square"""
      width, height = img.size
      side = min(width, height)
      left = (width - side) / 2
      top = (height - side) / 2
      right = (width + side) / 2
      bottom = (height + side) / 2
      img = img.crop((left, top, right, bottom))
      return img


  def resize(name):
+     """Image resize function with face alignment"""
      def resize_inner(img, align):
          global align_cache
+
          if name in align:
              img_hash = hashlib.md5(get_bytes(img)).hexdigest()
              if img_hash not in align_cache:
+                 try:
+                     img = align_face(img, return_tensors=False)[0]
+                     align_cache[img_hash] = img
+                 except Exception as e:
+                     print(f"Face alignment failed: {e}")
+                     img = center_crop(img)
+                     img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
              else:
                  img = align_cache[img_hash]
          elif img.size != (1024, 1024):
              img = center_crop(img)
              img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
+
          return img
      return resize_inner


+ def swap_hair_direct(face, shape, color, blending, poisson_iters, poisson_erosion):
+     """Direct hair swapping using HairFast (no gRPC)"""
+     global hair_fast_model
+
+     # Initialize model if needed
+     if hair_fast_model is None:
+         if not initialize_hairfast():
+             return gr.update(visible=False), gr.update(
+                 value="❌ HairFast model not available. Please check if all model files are uploaded.",
+                 visible=True
+             )
+
+     # Validation
      if not face and not shape and not color:
+         return gr.update(visible=False), gr.update(
+             value="Need to upload a face and at least a shape or color ❗",
+             visible=True
+         )
      elif not face:
+         return gr.update(visible=False), gr.update(
+             value="Need to upload a face ❗",
+             visible=True
+         )
      elif not shape and not color:
+         return gr.update(visible=False), gr.update(
+             value="Need to upload at least a shape or color ❗",
+             visible=True
          )

+     try:
+         print("🔄 Starting hair transfer...")
+
+         # Use shape as color if color is not provided
+         if color is None:
+             color = shape
+         if shape is None:
+             shape = color
+
+         # Direct HairFast inference
+         result_tensor = hair_fast_model.swap(
+             face_img=face,
+             shape_img=shape,
+             color_img=color,
+             benchmark=False,
+             align=True,  # Use face alignment
+             seed=3407
+         )
+
+         # Convert result tensor to PIL Image
+         result_image = tensor_to_pil(result_tensor)
+
+         print("✅ Hair transfer completed successfully!")
+         return gr.update(value=result_image, visible=True), gr.update(visible=False)
+
+     except Exception as e:
+         error_msg = f"❌ Hair transfer failed: {str(e)}"
+         print(error_msg)
+         return gr.update(visible=False), gr.update(value=error_msg, visible=True)
+
+
+ def hair_transfer_api(source_image, shape_image=None, color_image=None,
+                       blending="Article", poisson_iters=0, poisson_erosion=15):
+     """API function for React integration"""
+     global hair_fast_model
+
+     try:
+         # Handle base64 inputs
+         if isinstance(source_image, str):
+             source_image = base64_to_image(source_image)
+         if isinstance(shape_image, str):
+             shape_image = base64_to_image(shape_image)
+         if isinstance(color_image, str):
+             color_image = base64_to_image(color_image)
+
+         # Initialize model if needed
+         if hair_fast_model is None:
+             if not initialize_hairfast():
+                 return None, "❌ HairFast model not available"
+
+         # Validation
+         if source_image is None:
+             return None, "❌ Source image is required"
+
+         # Use source as reference if no references provided
+         if shape_image is None and color_image is None:
+             return None, "❌ At least shape or color reference image is required"
+
+         if color_image is None:
+             color_image = shape_image
+         if shape_image is None:
+             shape_image = color_image
+
+         # Direct HairFast inference
+         result_tensor = hair_fast_model.swap(
+             face_img=source_image,
+             shape_img=shape_image,
+             color_img=color_image,
+             benchmark=False,
+             align=True,
+             seed=3407
+         )
+
+         # Convert to PIL and then base64
+         result_image = tensor_to_pil(result_tensor)
+         result_base64 = image_to_base64(result_image)
+
+         return result_base64, "✅ Hair transfer completed successfully!"
+
+     except Exception as e:
+         error_msg = f"❌ API Error: {str(e)}"
+         print(error_msg)
+         return None, error_msg


  def get_demo():
+     """Create Gradio interface"""
+     with gr.Blocks(
+         title="HairFastGAN Direct API",
+         theme=gr.themes.Soft()
+     ) as demo:
+
+         gr.HTML("""
+         <div style="text-align: center; padding: 20px;">
+             <h1>🎨 HairFastGAN - Direct Model Inference</h1>
+             <p>High-quality hair transfer without gRPC dependency</p>
+         </div>
+         """)
+
          with gr.Row():
              with gr.Column():
+                 gr.HTML("<h3>📤 Input Images</h3>")
+
+                 source = gr.Image(
+                     label="Source Photo (Person's Face)",
+                     type="pil"
+                 )
+
                  with gr.Row():
+                     shape = gr.Image(
+                         label="Hair Shape Reference (Optional)",
+                         type="pil"
+                     )
+                     color = gr.Image(
+                         label="Hair Color Reference (Optional)",
+                         type="pil"
+                     )
+
+                 with gr.Accordion("🔧 Advanced Options", open=False):
+                     blending = gr.Radio(
+                         ["Article", "Alternative_v1", "Alternative_v2"],
+                         value='Article',
+                         label="Color Encoder Version"
+                     )
+                     poisson_iters = gr.Slider(
+                         0, 2500, value=0, step=1,
+                         label="Poisson Iterations",
+                         info="Detail recovery strength"
+                     )
+                     poisson_erosion = gr.Slider(
+                         1, 100, value=15, step=1,
+                         label="Poisson Erosion",
+                         info="Blending smoothness"
+                     )
+                     align = gr.CheckboxGroup(
+                         ["Face", "Shape", "Color"],
+                         value=["Face", "Shape", "Color"],
+                         label="Face Alignment [Recommended]"
+                     )
+
+                 btn = gr.Button("🎨 Transfer Hair Style", variant="primary", size="lg")
+
              with gr.Column():
+                 gr.HTML("<h3>📥 Result</h3>")
+
+                 output = gr.Image(label="Result Image", type="pil")
+                 error_message = gr.Textbox(
+                     label="⚠️ Status",
+                     visible=False,
+                     elem_classes="error-message"
+                 )
+
+         # Example gallery
+         gr.HTML("<h3>💡 Examples</h3>")
+         gr.Examples(
+             examples=[
+                 ["input/0.png", "input/1.png", "input/2.png"],
+                 ["input/6.png", "input/7.png", None],
+                 ["input/10.jpg", None, "input/11.jpg"]
+             ],
+             inputs=[source, shape, color],
+             outputs=output
+         )

+         # Event handlers
          source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
          shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
          color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

+         btn.click(
+             fn=swap_hair_direct,
+             inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
+             outputs=[output, error_message],
+             api_name="predict"  # For React integration
+         )

+         # Citation
+         gr.Markdown('''
+         ### 📖 Citation
+         ```bibtex
          @article{nikolaev2024hairfastgan,
            title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
            author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
            journal={arXiv preprint arXiv:2404.01094},
            year={2024}
          }
+         ```
          ''')
+
      return demo


  if __name__ == '__main__':
+     # Initialize cache
      align_cache = LRUCache(maxsize=10)
+
+     # Create demo
      demo = get_demo()
+
+     # Launch with API enabled
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_api=True,
+         share=False
+     )
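
Not part of the commit, but for context: because the new btn.click handler registers api_name="predict", the Space can be called remotely (e.g. from the React frontend mentioned in the code). Below is a minimal client-side sketch using gradio_client; the Space id and the image file names are placeholders, the argument order mirrors the inputs list wired into btn.click, and it assumes a recent gradio_client that provides handle_file.

# Hypothetical client call to the "predict" endpoint exposed above.
from gradio_client import Client, handle_file

client = Client("your-username/your-hairfast-space")  # placeholder Space id

result_image, status = client.predict(
    handle_file("face.jpg"),   # source photo
    handle_file("shape.jpg"),  # hair shape reference (or None)
    handle_file("color.jpg"),  # hair color reference (or None)
    "Article",                 # Color Encoder Version
    0,                         # Poisson Iterations
    15,                        # Poisson Erosion
    api_name="/predict",
)

print(result_image)  # local path to the downloaded result image
print(status)        # contents of the status/error textbox

The two returned values correspond to the output image and the error_message textbox that btn.click lists as outputs.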