foivospar committed
Commit 1c1d081
1 Parent(s): 0259ae0

initial demo

Files changed (5)
  1. app.py +231 -0
  2. arc2face/__init__.py +2 -0
  3. arc2face/models.py +91 -0
  4. arc2face/utils.py +30 -0
  5. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,231 @@
+ import sys
+ sys.path.append('./')
+
+ from diffusers import (
+     StableDiffusionPipeline,
+     UNet2DConditionModel,
+     DPMSolverMultistepScheduler,
+ )
+
+ from arc2face import CLIPTextModelWrapper, project_face_embs
+
+ import torch
+ from insightface.app import FaceAnalysis
+ from PIL import Image
+ import numpy as np
+ import random
+
+ import gradio as gr
+
+ # global variable
+ MAX_SEED = np.iinfo(np.int32).max
+ if torch.cuda.is_available():
+     device = "cuda"
+     dtype = torch.float16
+ else:
+     device = "cpu"
+     dtype = torch.float32
+
+
+ # download models
+ from huggingface_hub import hf_hub_download
+
+ hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arc2face/config.json", local_dir="./models")
+ hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arc2face/diffusion_pytorch_model.safetensors", local_dir="./models")
+ hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="encoder/config.json", local_dir="./models")
+ hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="encoder/pytorch_model.bin", local_dir="./models")
+ hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arcface.onnx", local_dir="./models/antelopev2")
+
+ # Load face detection and recognition package
+ if device == "cuda":
+     app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ else:
+     app = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
+ app.prepare(ctx_id=0, det_size=(640, 640))
+
+ # Load pipeline
+ base_model = 'runwayml/stable-diffusion-v1-5'
+ encoder = CLIPTextModelWrapper.from_pretrained(
+     'models', subfolder="encoder", torch_dtype=dtype
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+     'models', subfolder="arc2face", torch_dtype=dtype
+ )
+ pipeline = StableDiffusionPipeline.from_pretrained(
+     base_model,
+     text_encoder=encoder,
+     unet=unet,
+     torch_dtype=dtype,
+     safety_checker=None
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(device)
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ def get_example():
+     case = [
+         [
+             './assets/examples/freeman.jpg',
+         ],
+         [
+             './assets/examples/lily.png',
+         ],
+         [
+             './assets/examples/joacquin.png',
+         ],
+         [
+             './assets/examples/jackie.png',
+         ],
+         [
+             './assets/examples/freddie.png',
+         ],
+         [
+             './assets/examples/hepburn.png',
+         ],
+     ]
+     return case
+
+ def run_example(img_file):
+     return generate_image(img_file, 25, 3, 23, 2)
+
+
+ def generate_image(image_path, num_steps, guidance_scale, seed, num_images, progress=gr.Progress(track_tqdm=True)):
+
+     if image_path is None:
+         raise gr.Error("Cannot find any input face image! Please upload a face image.")
+
+     img = np.array(Image.open(image_path))[:, :, ::-1]
+
+     # Face detection and ID-embedding extraction
+     faces = app.get(img)
+
+     if len(faces) == 0:
+         raise gr.Error("Face detection failed! Please try another image.")
+
+     faces = sorted(faces, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # select the largest face (if more than one is detected)
+     id_emb = torch.tensor(faces['embedding'], dtype=dtype)[None].to(device)
+     id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True)  # normalize embedding
+     id_emb = project_face_embs(pipeline, id_emb)  # pass through the encoder
+
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     print("Start inference...")
+     images = pipeline(
+         prompt_embeds=id_emb,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=num_images,
+         generator=generator
+     ).images
+
+     return images
+
+ ### Description
+ title = r"""
+ <h1>Arc2Face: A Foundation Model of Human Faces</h1>
+ """
+
+ description = r"""
+ <b>Official 🤗 Gradio demo</b> for <a href='https://arc2face.github.io/' target='_blank'><b>Arc2Face: A Foundation Model of Human Faces</b></a>.<br>
+
+ Steps:<br>
+ 1. Upload an image with a face. If multiple faces are detected, we use the largest one. For images with already tightly cropped faces, detection may fail; try an image with a larger margin.
+ 2. Click <b>Submit</b> to generate new images of the subject.
+ """
+
+ Footer = r"""
+ ---
+ 📝 **Citation**
+ <br>
+ If you find Arc2Face helpful for your research, please consider citing our paper:
+ ```bibtex
+ @misc{paraperas2024arc2face,
+       title={Arc2Face: A Foundation Model of Human Faces},
+       author={Foivos Paraperas Papantoniou and Alexandros Lattas and Stylianos Moschoglou and Jiankang Deng and Bernhard Kainz and Stefanos Zafeiriou},
+       year={2024},
+       eprint={2403.11641},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+ ```
+ """
+
+ css = '''
+ .gradio-container {width: 85% !important}
+ '''
+ with gr.Blocks(css=css) as demo:
+
+     # description
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column():
+
+             # upload face image
+             img_file = gr.Image(label="Upload a photo with a face", type="filepath")
+
+             submit = gr.Button("Submit", variant="primary")
+
+             with gr.Accordion(open=False, label="Advanced Options"):
+                 num_steps = gr.Slider(
+                     label="Number of sample steps",
+                     minimum=20,
+                     maximum=100,
+                     step=1,
+                     value=25,
+                 )
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.1,
+                     maximum=10.0,
+                     step=0.1,
+                     value=3,
+                 )
+                 num_images = gr.Slider(
+                     label="Number of output images",
+                     minimum=1,
+                     maximum=4,
+                     step=1,
+                     value=2,
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+         with gr.Column():
+             gallery = gr.Gallery(label="Generated Images")
+
+         submit.click(
+             fn=randomize_seed_fn,
+             inputs=[seed, randomize_seed],
+             outputs=seed,
+             queue=False,
+             api_name=False,
+         ).then(
+             fn=generate_image,
+             inputs=[img_file, num_steps, guidance_scale, seed, num_images],
+             outputs=[gallery]
+         )
+
+
+     gr.Examples(
+         examples=get_example(),
+         inputs=[img_file],
+         run_on_click=True,
+         fn=run_example,
+         outputs=[gallery],
+     )
+
+     gr.Markdown(Footer)
+
+ demo.launch()
arc2face/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .models import CLIPTextModelWrapper
+ from .utils import project_face_embs
arc2face/models.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ from transformers import CLIPTextModel
+ from typing import Any, Callable, Dict, Optional, Tuple, Union, List
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
+
+
+ class CLIPTextModelWrapper(CLIPTextModel):
+     # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
+     # Modified to accept precomputed token embeddings ("input_token_embs") as input, or to compute them from input_ids and return them.
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         input_token_embs: Optional[torch.Tensor] = None,
+         return_token_embs: Optional[bool] = False,
+     ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
+
+         if return_token_embs:
+             return self.text_model.embeddings.token_embedding(input_ids)
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
+
+         if input_ids is None:
+             raise ValueError("You have to specify input_ids")
+
+         input_shape = input_ids.size()
+         input_ids = input_ids.view(-1, input_shape[-1])
+
+         hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
+
+         # CLIP's text model uses a causal mask; prepare it here.
+         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+         causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
+         # expand attention_mask
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+
+         encoder_outputs = self.text_model.encoder(
+             inputs_embeds=hidden_states,
+             attention_mask=attention_mask,
+             causal_attention_mask=causal_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         last_hidden_state = encoder_outputs[0]
+         last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
+
+         if self.text_model.eos_token_id == 2:
+             # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+             # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+             # ------------------------------------------------------------
+             # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+             # take features from the eot embedding (eot_token is the highest number in each sequence)
+             # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+             pooled_output = last_hidden_state[
+                 torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                 input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+             ]
+         else:
+             # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
+             pooled_output = last_hidden_state[
+                 torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+                 # We need the first position of the `eos_token_id` value (`pad_token_ids` might equal `eos_token_id`)
+                 (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
+                 .int()
+                 .argmax(dim=-1),
+             ]
+
+         if not return_dict:
+             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
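For orientation, here is a minimal sketch of the two extra call modes this wrapper adds on top of `CLIPTextModel`. It assumes the encoder checkpoint downloaded into `./models/encoder` by `app.py` above; the dummy `input_ids` are purely illustrative (in the demo they come from the Stable Diffusion tokenizer).

```python
import torch
from arc2face import CLIPTextModelWrapper

# Sketch only: encoder weights are the ones fetched into ./models by app.py above.
encoder = CLIPTextModelWrapper.from_pretrained("models", subfolder="encoder", torch_dtype=torch.float32)

# Illustrative dummy prompt ids (batch of 1, CLIP context length 77).
input_ids = torch.randint(0, encoder.config.vocab_size, (1, 77))

# Mode 1: return the raw token embeddings instead of running the transformer.
token_embs = encoder(input_ids=input_ids, return_token_embs=True)  # (1, 77, hidden_size)

# ...token_embs can now be edited in place, e.g. swapping one slot for a projected face embedding...

# Mode 2: run the full text model on the (possibly modified) token embeddings.
prompt_embeds = encoder(input_ids=input_ids, input_token_embs=token_embs)[0]  # (1, 77, hidden_size)
```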
arc2face/utils.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torch.nn.functional as F
+
+ @torch.no_grad()
+ def project_face_embs(pipeline, face_embs):
+
+     '''
+     face_embs: (N, 512) normalized ArcFace embeddings
+     '''
+
+     arcface_token_id = pipeline.tokenizer.encode("id", add_special_tokens=False)[0]
+
+     input_ids = pipeline.tokenizer(
+         "photo of a id person",
+         truncation=True,
+         padding="max_length",
+         max_length=pipeline.tokenizer.model_max_length,
+         return_tensors="pt",
+     ).input_ids.to(pipeline.device)
+
+     face_embs_padded = F.pad(face_embs, (0, pipeline.text_encoder.config.hidden_size - 512), "constant", 0)
+     token_embs = pipeline.text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True)
+     token_embs[input_ids == arcface_token_id] = face_embs_padded
+
+     prompt_embeds = pipeline.text_encoder(
+         input_ids=input_ids,
+         input_token_embs=token_embs
+     )[0]
+
+     return prompt_embeds
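This helper zero-pads the 512-d ArcFace embedding to the text encoder's hidden size and substitutes it for the embedding of the placeholder token "id" in the fixed prompt "photo of a id person". The sketch below shows how it plugs into the diffusers pipeline, mirroring `generate_image` in `app.py` above; "face.jpg" is a hypothetical path, and `app`/`pipeline` are the objects built there.

```python
import numpy as np
import torch
from PIL import Image
from arc2face import project_face_embs

# Sketch: `app` (insightface FaceAnalysis) and `pipeline` are set up as in app.py above;
# "face.jpg" stands in for any photo with a detectable face.
img = np.array(Image.open("face.jpg"))[:, :, ::-1]         # RGB -> BGR for insightface
face = app.get(img)[0]                                      # first detected face

id_emb = torch.tensor(face['embedding'], dtype=pipeline.dtype)[None].to(pipeline.device)
id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True)   # the docstring expects normalized embeddings

prompt_embeds = project_face_embs(pipeline, id_emb)         # (1, 77, hidden_size)
images = pipeline(prompt_embeds=prompt_embeds, num_inference_steps=25, guidance_scale=3.0).images
```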
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ numpy<1.24.0
+ torch==2.0.1
+ torchvision==0.15.2
+ diffusers==0.22.0
+ transformers==4.34.1
+ accelerate
+ insightface
+ onnxruntime-gpu
+ gradio