hysts (HF staff) committed
Commit a1b524b
1 parent: 6080ed9
.gitattributes CHANGED
@@ -1,3 +1,4 @@
+ *.jpg filter=lfs diff=lfs merge=lfs -text
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,6 @@
+ [submodule "HairCLIP"]
+ path = HairCLIP
+ url = https://github.com/wty-ustc/HairCLIP
+ [submodule "encoder4editing"]
+ path = encoder4editing
+ url = https://github.com/omertov/encoder4editing
.style.yapf ADDED
@@ -0,0 +1,5 @@
+ [style]
+ based_on_style = pep8
+ blank_line_before_nested_class_or_def = false
+ spaces_before_comment = 2
+ split_before_logical_operator = true
HairCLIP ADDED
@@ -0,0 +1 @@
+ Subproject commit 29290cf5bdca0f21ff27e0ec2e93bdd1ebbe3605
app.py ADDED
@@ -0,0 +1,151 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import pathlib
+ import subprocess
+
+ import gradio as gr
+
+ if os.getenv('SYSTEM') == 'spaces':
+     subprocess.call('git apply ../patch.e4e'.split(), cwd='encoder4editing')
+     subprocess.call('git apply ../patch.hairclip'.split(), cwd='HairCLIP')
+
+ from model import Model
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     return parser.parse_args()
+
+
+ def load_hairstyle_list() -> list[str]:
+     with open('HairCLIP/mapper/hairstyle_list.txt') as f:
+         lines = [line.strip() for line in f.readlines()]
+     lines = [line[:-10] for line in lines]
+     return lines
+
+
+ def set_example_image(example: list) -> dict:
+     return gr.Image.update(value=example[0])
+
+
+ def update_step2_components(choice: str) -> tuple[dict, dict]:
+     return (
+         gr.Dropdown.update(visible=choice in ['hairstyle', 'both']),
+         gr.Textbox.update(visible=choice in ['color', 'both']),
+     )
+
+
+ def main():
+     args = parse_args()
+     model = Model(device=args.device)
+
+     css = '''
+ h1#title {
+   text-align: center;
+ }
+ img#teaser {
+   max-width: 1000px;
+   max-height: 600px;
+ }
+ '''
+
+     with gr.Blocks(theme=args.theme, css=css) as demo:
+         gr.Markdown('''<h1 id="title">HairCLIP</h1>
+
+ This is an unofficial demo for <a href="https://github.com/wty-ustc/HairCLIP">https://github.com/wty-ustc/HairCLIP</a>.
+
+ <center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>
+ ''')
+         with gr.Box():
+             gr.Markdown('## Step 1')
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         input_image = gr.Image(label='Input Image',
+                                                type='file')
+                     with gr.Row():
+                         preprocess_button = gr.Button('Preprocess')
+                 with gr.Column():
+                     aligned_face = gr.Image(label='Aligned Face',
+                                             type='pil',
+                                             interactive=False)
+                 with gr.Column():
+                     reconstructed_face = gr.Image(label='Reconstructed Face',
+                                                   type='numpy')
+                     latent = gr.Variable()
+
+             with gr.Row():
+                 paths = sorted(pathlib.Path('images').glob('*.jpg'))
+                 example_images = gr.Dataset(components=[input_image],
+                                             samples=[[path.as_posix()]
+                                                      for path in paths])
+
+         with gr.Box():
+             gr.Markdown('## Step 2')
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         editing_type = gr.Radio(['hairstyle', 'color', 'both'],
+                                                 value='both',
+                                                 type='value',
+                                                 label='Editing Type')
+                     with gr.Row():
+                         hairstyles = load_hairstyle_list()
+                         hairstyle_index = gr.Dropdown(hairstyles,
+                                                       value='afro',
+                                                       type='index',
+                                                       label='Hairstyle')
+                     with gr.Row():
+                         color_description = gr.Textbox(value='red',
+                                                        label='Color')
+                     with gr.Row():
+                         run_button = gr.Button('Run')
+
+                 with gr.Column():
+                     result = gr.Image(label='Result')
+
+         gr.Markdown(
+             '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.hairclip" alt="visitor badge"/></center>'
+         )
+
+         preprocess_button.click(fn=model.detect_and_align_face,
+                                 inputs=[input_image],
+                                 outputs=[aligned_face])
+         aligned_face.change(fn=model.reconstruct_face,
+                             inputs=[aligned_face],
+                             outputs=[reconstructed_face, latent])
+         editing_type.change(fn=update_step2_components,
+                             inputs=[editing_type],
+                             outputs=[hairstyle_index, color_description])
+         run_button.click(fn=model.generate,
+                          inputs=[
+                              editing_type,
+                              hairstyle_index,
+                              color_description,
+                              latent,
+                          ],
+                          outputs=[result])
+         example_images.click(fn=set_example_image,
+                              inputs=example_images,
+                              outputs=example_images.components)
+
+     demo.launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
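
Note on load_hairstyle_list() above: the slice line[:-10] presumably strips a trailing " hairstyle" suffix (10 characters) from each entry of HairCLIP/mapper/hairstyle_list.txt, leaving only the style name for the dropdown. A tiny illustration under that assumption (not part of the commit):

    # Hypothetical example entry; the real entries live in hairstyle_list.txt.
    line = 'afro hairstyle'
    assert line[:-10] == 'afro'  # drop the 10-character " hairstyle" suffix
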
encoder4editing ADDED
@@ -0,0 +1 @@
+ Subproject commit 99ea50578695d2e8a1cf7259d8ee89b23eea942b
images/95UF6LXe-Lo.jpg ADDED
Git LFS Details
  • SHA256: 9ba751a6519822fa683e062ee3a383e748f15b41d4ca87d14c4fa73f9beed845
  • Pointer size: 131 Bytes
  • Size of remote file: 503 kB
images/ILip77SbmOE.jpg ADDED
Git LFS Details
  • SHA256: 3eed82923bc76a90f067415f148d56239fdfa4a1aca9eef1d459bc6050c9dde8
  • Pointer size: 131 Bytes
  • Size of remote file: 939 kB
images/README.md ADDED
@@ -0,0 +1,7 @@
+ These images are freely-usable ones from [Unsplash](https://unsplash.com/).
+
+ - https://unsplash.com/photos/rDEOVtE7vOs
+ - https://unsplash.com/photos/et_78QkMMQs
+ - https://unsplash.com/photos/ILip77SbmOE
+ - https://unsplash.com/photos/95UF6LXe-Lo
+
images/et_78QkMMQs.jpg ADDED
Git LFS Details
  • SHA256: c63a2e9de5eda3cb28012cfc8e4ba9384daeda8cca7a8989ad90b21a1293cc6f
  • Pointer size: 131 Bytes
  • Size of remote file: 371 kB
images/rDEOVtE7vOs.jpg ADDED
Git LFS Details
  • SHA256: b136bf195fef5599f277a563f0eef79af5301d9352d4ebf82bd7a0a061b7bdc0
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
model.py ADDED
@@ -0,0 +1,151 @@
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import sys
+ from typing import Callable, Union
+
+ import dlib
+ import huggingface_hub
+ import numpy as np
+ import PIL.Image
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as T
+
+ sys.path.insert(0, 'encoder4editing')
+
+ from models.psp import pSp
+ from utils.alignment import align_face
+
+ sys.path.insert(0, 'HairCLIP/')
+ sys.path.insert(0, 'HairCLIP/mapper/')
+
+ from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
+ from mapper.hairclip_mapper import HairCLIPMapper
+
+ TOKEN = os.environ['TOKEN']
+
+
+ class Model:
+     def __init__(self, device: Union[torch.device, str]):
+         self.device = torch.device(device)
+         self.landmark_model = self._create_dlib_landmark_model()
+         self.e4e = self._load_e4e()
+         self.hairclip = self._load_hairclip()
+         self.transform = self._create_transform()
+
+     @staticmethod
+     def _create_dlib_landmark_model():
+         path = huggingface_hub.hf_hub_download(
+             'hysts/dlib_face_landmark_model',
+             'shape_predictor_68_face_landmarks.dat',
+             use_auth_token=TOKEN)
+         return dlib.shape_predictor(path)
+
+     def _load_e4e(self) -> nn.Module:
+         ckpt_path = huggingface_hub.hf_hub_download('hysts/e4e',
+                                                     'e4e_ffhq_encode.pt',
+                                                     use_auth_token=TOKEN)
+         ckpt = torch.load(ckpt_path, map_location='cpu')
+         opts = ckpt['opts']
+         opts['device'] = self.device.type
+         opts['checkpoint_path'] = ckpt_path
+         opts = argparse.Namespace(**opts)
+         model = pSp(opts)
+         model.to(self.device)
+         model.eval()
+         return model
+
+     def _load_hairclip(self) -> nn.Module:
+         ckpt_path = huggingface_hub.hf_hub_download('hysts/HairCLIP',
+                                                     'hairclip.pt',
+                                                     use_auth_token=TOKEN)
+         ckpt = torch.load(ckpt_path, map_location='cpu')
+         opts = ckpt['opts']
+         opts['device'] = self.device.type
+         opts['checkpoint_path'] = ckpt_path
+         opts['editing_type'] = 'both'
+         opts['input_type'] = 'text'
+         opts['hairstyle_description'] = 'HairCLIP/mapper/hairstyle_list.txt'
+         opts['color_description'] = 'red'
+         opts = argparse.Namespace(**opts)
+         model = HairCLIPMapper(opts)
+         model.to(self.device)
+         model.eval()
+         return model
+
+     @staticmethod
+     def _create_transform() -> Callable:
+         transform = T.Compose([
+             T.Resize(256),
+             T.CenterCrop(256),
+             T.ToTensor(),
+             T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+         ])
+         return transform
+
+     def detect_and_align_face(self, image) -> PIL.Image.Image:
+         image = align_face(filepath=image.name, predictor=self.landmark_model)
+         return image
+
+     @staticmethod
+     def denormalize(tensor: torch.Tensor) -> torch.Tensor:
+         return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
+
+     def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
+         tensor = self.denormalize(tensor)
+         return tensor.cpu().numpy().transpose(1, 2, 0)
+
+     @torch.inference_mode()
+     def reconstruct_face(
+             self, image: PIL.Image.Image) -> tuple[np.ndarray, torch.Tensor]:
+         input_data = self.transform(image).unsqueeze(0).to(self.device)
+         reconstructed_images, latents = self.e4e(input_data,
+                                                  randomize_noise=False,
+                                                  return_latents=True)
+         reconstructed = torch.clamp(reconstructed_images[0].detach(), -1, 1)
+         reconstructed = self.postprocess(reconstructed)
+         return reconstructed, latents[0]
+
+     @torch.inference_mode()
+     def generate(self, editing_type: str, hairstyle_index: int,
+                  color_description: str, latent: torch.Tensor) -> np.ndarray:
+         opts = self.hairclip.opts
+         opts.editing_type = editing_type
+         opts.color_description = color_description
+
+         if editing_type == 'color':
+             hairstyle_index = 0
+
+         device = torch.device(opts.device)
+
+         dataset = LatentsDatasetInference(latents=latent.unsqueeze(0).cpu(),
+                                           opts=opts)
+         w, hairstyle_text_inputs_list, color_text_inputs_list = dataset[0][:3]
+
+         w = w.unsqueeze(0).to(device)
+         hairstyle_text_inputs = hairstyle_text_inputs_list[
+             hairstyle_index].unsqueeze(0).to(device)
+         color_text_inputs = color_text_inputs_list[0].unsqueeze(0).to(device)
+
+         hairstyle_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
+         color_tensor_hairmasked = torch.Tensor([0]).unsqueeze(0).to(device)
+
+         w_hat = w + 0.1 * self.hairclip.mapper(
+             w,
+             hairstyle_text_inputs,
+             color_text_inputs,
+             hairstyle_tensor_hairmasked,
+             color_tensor_hairmasked,
+         )
+         x_hat, _ = self.hairclip.decoder(
+             [w_hat],
+             input_is_latent=True,
+             return_latents=True,
+             randomize_noise=False,
+             truncation=1,
+         )
+         res = torch.clamp(x_hat[0].detach(), -1, 1)
+         res = self.postprocess(res)
+         return res
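
For orientation, the Model class above is what the Gradio callbacks in app.py drive: align the face with dlib, invert it into the StyleGAN2 latent space with e4e, then edit the latent with the HairCLIP mapper. A minimal standalone sketch of that flow (not part of the commit; it assumes the submodules are checked out, the patches applied, the TOKEN environment variable set, and the example images present):

    import types

    from model import Model

    model = Model(device='cpu')

    # detect_and_align_face() reads image.name (app.py passes a Gradio file
    # object), so a small stand-in object with a .name attribute suffices here.
    image_file = types.SimpleNamespace(name='images/95UF6LXe-Lo.jpg')
    aligned = model.detect_and_align_face(image_file)        # PIL.Image
    reconstructed, latent = model.reconstruct_face(aligned)  # np.ndarray, latent tensor
    result = model.generate(editing_type='both',
                            hairstyle_index=0,               # row index in hairstyle_list.txt
                            color_description='red',
                            latent=latent)                   # np.ndarray (H, W, 3)
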
packages.txt ADDED
@@ -0,0 +1,2 @@
+ cmake
+ ninja-build
patch.e4e ADDED
@@ -0,0 +1,131 @@
+ diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py
+ index 973a84f..6854b97 100644
+ --- a/models/stylegan2/op/fused_act.py
+ +++ b/models/stylegan2/op/fused_act.py
+ @@ -2,17 +2,18 @@ import os
+
+ import torch
+ from torch import nn
+ +from torch.nn import functional as F
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+ -module_path = os.path.dirname(__file__)
+ -fused = load(
+ -    'fused',
+ -    sources=[
+ -        os.path.join(module_path, 'fused_bias_act.cpp'),
+ -        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+ -    ],
+ -)
+ +#module_path = os.path.dirname(__file__)
+ +#fused = load(
+ +#    'fused',
+ +#    sources=[
+ +#        os.path.join(module_path, 'fused_bias_act.cpp'),
+ +#        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+ +#    ],
+ +#)
+
+
+ class FusedLeakyReLUFunctionBackward(Function):
+ @@ -82,4 +83,18 @@ class FusedLeakyReLU(nn.Module):
+
+
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ -    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+ +    if input.device.type == "cpu":
+ +        if bias is not None:
+ +            rest_dim = [1] * (input.ndim - bias.ndim - 1)
+ +            return (
+ +                F.leaky_relu(
+ +                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
+ +                )
+ +                * scale
+ +            )
+ +
+ +        else:
+ +            return F.leaky_relu(input, negative_slope=0.2) * scale
+ +
+ +    else:
+ +        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+ diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py
+ index 7bc5a1e..5465d1a 100644
+ --- a/models/stylegan2/op/upfirdn2d.py
+ +++ b/models/stylegan2/op/upfirdn2d.py
+ @@ -1,17 +1,18 @@
+ import os
+
+ import torch
+ +from torch.nn import functional as F
+ from torch.autograd import Function
+ from torch.utils.cpp_extension import load
+
+ -module_path = os.path.dirname(__file__)
+ -upfirdn2d_op = load(
+ -    'upfirdn2d',
+ -    sources=[
+ -        os.path.join(module_path, 'upfirdn2d.cpp'),
+ -        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ -    ],
+ -)
+ +#module_path = os.path.dirname(__file__)
+ +#upfirdn2d_op = load(
+ +#    'upfirdn2d',
+ +#    sources=[
+ +#        os.path.join(module_path, 'upfirdn2d.cpp'),
+ +#        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+ +#    ],
+ +#)
+
+
+ class UpFirDn2dBackward(Function):
+ @@ -97,8 +98,8 @@ class UpFirDn2d(Function):
+
+         ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ -        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ -        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ +        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+ +        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
+         ctx.out_size = (out_h, out_w)
+
+         ctx.up = (up_x, up_y)
+ @@ -140,9 +141,13 @@ class UpFirDn2d(Function):
+
+
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ -    out = UpFirDn2d.apply(
+ -        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ -    )
+ +    if input.device.type == "cpu":
+ +        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ +
+ +    else:
+ +        out = UpFirDn2d.apply(
+ +            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+ +        )
+
+     return out
+
+ @@ -150,6 +155,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ def upfirdn2d_native(
+     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+ ):
+ +    _, channel, in_h, in_w = input.shape
+ +    input = input.reshape(-1, in_h, in_w, 1)
+ +
+     _, in_h, in_w, minor = input.shape
+     kernel_h, kernel_w = kernel.shape
+
+ @@ -180,5 +188,9 @@ def upfirdn2d_native(
+         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+     )
+     out = out.permute(0, 2, 3, 1)
+ +    out = out[:, ::down_y, ::down_x, :]
+ +
+ +    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+ +    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
+
+ -    return out[:, ::down_y, ::down_x, :]
+ +    return out.view(-1, channel, out_h, out_w)
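
patch.e4e exists so that encoder4editing can run on CPU-only Spaces hardware: the torch.utils.cpp_extension.load(...) calls that would compile the StyleGAN2 CUDA kernels are commented out, and fused_leaky_relu / upfirdn2d fall back to pure-PyTorch code whenever the input tensor lives on the CPU. A compact restatement of the fused_leaky_relu fallback, as a sketch rather than the patched file itself:

    import torch
    import torch.nn.functional as F

    def fused_leaky_relu_cpu(x, bias=None, negative_slope=0.2, scale=2 ** 0.5):
        # Mirrors the patched CPU branch: add the per-channel bias (broadcast
        # over the remaining dims), apply leaky ReLU, then rescale by sqrt(2).
        if bias is not None:
            rest_dim = [1] * (x.ndim - bias.ndim - 1)
            return F.leaky_relu(x + bias.view(1, bias.shape[0], *rest_dim),
                                negative_slope=0.2) * scale
        return F.leaky_relu(x, negative_slope=0.2) * scale

    out = fused_leaky_relu_cpu(torch.randn(1, 8, 4, 4), torch.zeros(8))
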
patch.hairclip ADDED
@@ -0,0 +1,61 @@
+ diff --git a/mapper/latent_mappers.py b/mapper/latent_mappers.py
+ index 56b9c55..f0dd005 100644
+ --- a/mapper/latent_mappers.py
+ +++ b/mapper/latent_mappers.py
+ @@ -19,7 +19,7 @@ class ModulationModule(Module):
+
+     def forward(self, x, embedding, cut_flag):
+         x = self.fc(x)
+ -        x = self.norm(x)
+ +        x = self.norm(x)
+         if cut_flag == 1:
+             return x
+         gamma = self.gamma_function(embedding.float())
+ @@ -39,20 +39,20 @@ class SubHairMapper(Module):
+     def forward(self, x, embedding, cut_flag=0):
+         x = self.pixelnorm(x)
+         for modulation_module in self.modulation_module_list:
+ -            x = modulation_module(x, embedding, cut_flag)
+ +            x = modulation_module(x, embedding, cut_flag)
+         return x
+
+ -class HairMapper(Module):
+ +class HairMapper(Module):
+     def __init__(self, opts):
+         super(HairMapper, self).__init__()
+         self.opts = opts
+ -        self.clip_model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+ +        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=opts.device)
+         self.transform = transforms.Compose([transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
+         self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
+         self.hairstyle_cut_flag = 0
+         self.color_cut_flag = 0
+
+ -        if not opts.no_coarse_mapper:
+ +        if not opts.no_coarse_mapper:
+             self.course_mapping = SubHairMapper(opts, 4)
+         if not opts.no_medium_mapper:
+             self.medium_mapping = SubHairMapper(opts, 4)
+ @@ -70,13 +70,13 @@ class HairMapper(Module):
+         elif hairstyle_tensor.shape[1] != 1:
+             hairstyle_embedding = self.gen_image_embedding(hairstyle_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
+         else:
+ -            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+ +            hairstyle_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
+         if color_text_inputs.shape[1] != 1:
+             color_embedding = self.clip_model.encode_text(color_text_inputs).unsqueeze(1).repeat(1, 18, 1).detach()
+         elif color_tensor.shape[1] != 1:
+             color_embedding = self.gen_image_embedding(color_tensor, self.clip_model, self.preprocess).unsqueeze(1).repeat(1, 18, 1).detach()
+         else:
+ -            color_embedding = torch.ones(x.shape[0], 18, 512).cuda()
+ +            color_embedding = torch.ones(x.shape[0], 18, 512).to(self.opts.device)
+
+
+         if (hairstyle_text_inputs.shape[1] == 1) and (hairstyle_tensor.shape[1] == 1):
+ @@ -106,4 +106,4 @@ class HairMapper(Module):
+             x_fine = torch.zeros_like(x_fine)
+
+         out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
+ -        return out
+
+ +        return out
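
The functional edits in patch.hairclip are device-related: clip.load(..., device="cuda") becomes clip.load(..., device=opts.device) and the hard-coded .cuda() placeholder embeddings become .to(self.opts.device), so HairMapper also runs on the CPU hardware used by this Space; the remaining +/- pairs appear to differ only in whitespace. A minimal sketch of that pattern with hypothetical names (not part of the commit):

    import torch

    def make_placeholder_embedding(batch_size, device):
        # Device-agnostic replacement for torch.ones(batch_size, 18, 512).cuda():
        # the tensor is created on whatever device the caller configured.
        return torch.ones(batch_size, 18, 512).to(device)

    emb = make_placeholder_embedding(1, 'cpu')  # works without a GPU
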
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ dlib==19.23.0
+ numpy==1.22.3
+ opencv-python-headless==4.5.5.64
+ Pillow==9.1.0
+ scipy==1.8.0
+ torch==1.11.0
+ torchvision==0.12.0
+ git+https://github.com/openai/CLIP.git