diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..01d55b69296c54b2ade5660b79bee6fadf4d6089
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+**__pycache__/
+
+MODELS
+third_party
+tmp
+results
+chat_anything/tts_vits/
+vits_results
+test
+resources/models.yaml
+
+# others
+GFPGANv1.4.pth
+gfpgan
+GFPGAN
+.gitattributes
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..03fee62ba5f60d70ba39eac0ada6f51b0f961fcd
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,27 @@
+FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
+
+# FROM python:3.9
+
+# WORKDIR /code
+
+# COPY ./requirements.txt /code/requirements.txt
+
+# RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+# libGL is needed by OpenCV
+RUN apt-get update && apt-get install libgl1 -y
+
+RUN useradd -m -u 1000 user
+
+USER user
+
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+WORKDIR $HOME/ChatAnything
+
+COPY --chown=user . $HOME/ChatAnything
+
+RUN pip install -r requirements.txt
+
+CMD python app.py
diff --git a/README.md b/README.md
index 9e84064b01a2ae6d3e30c7a323f91d698fdc1466..235dd4b75efb9a7e57ecf6436e4261e779b00ccb 100644
--- a/README.md
+++ b/README.md
@@ -10,3 +10,141 @@
 pinned: false
 ---
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# ChatAnything: Facetime Chat with LLM-Enhanced Personas
+
+**Yilin Zhao\*, Shanghua Gao\*, Daquan Zhou\*, Xinbin Yuan\*, Zhijie Lin, Qibin Hou, Jiashi Feng**
+
+> What would it be like to Facetime any imaginary concept?
+> To animate anything, we integrate the open-source models at hand into an interactive AI-agent chatting application.
+>
+> To start with, take a look at these incredible faces generated with open-source Civitai models, which are then animated.
+*(teaser image: faces generated with Civitai models)*
+
+Here we present ChatAnything: a simple pipeline, enhanced with today's powerful Large Language Models, that yields imaginary Facetime chats with the intended visual appearance!
+
+Note that the repo and the application are built entirely on pre-trained deep learning models and do not include any training. All credit goes to the open-source community (shout out to you). For details of the pipeline, see our technical report (TODO: link here).
+
+## Release & Features & Future Plans
+
+- [ ] Fine-tune the face rendering module.
+- [ ] Better TTS module & voice render module.
+- [ ] Add open-source language models.
+- [x] Initial release
+  - Facetime animation.
+  - Multiple model choices for initial frame generation.
+  - Multiple choices for voices.
+
+# Install & Run
+Just follow the instructions; everything should be simple (hopefully). Reach out if you run into any problems!
+
+### Install
+First, create the virtual environment:
+```
+conda env create -f environment.yaml
+
+# or, if the environment already exists, update it
+conda env update --name chatanything --file environment.yaml
+```
+
+The pipeline integrates open-source models, all of which can be found online (see [Acknowledgement](#acknowledgement)). To make life easier, we bundle some important models on Hugging Face remotes. Prepare them before the first run with the Python script [prepare_models.py](./python_scripts/prepare_models.py):
+```
+# prepare the local models
+python python_scripts/prepare_models.py
+```
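+
+If you are curious what this step roughly does: the bundled models are fetched from remote Hugging Face repos into local folders (e.g. `MODELS/`). The sketch below only illustrates that kind of download and is not the actual script — the real logic, repo ids, and target paths live in [prepare_models.py](./python_scripts/prepare_models.py).
+```python
+from huggingface_hub import snapshot_download
+
+# Hypothetical example: download one bundled model repo into MODELS/.
+# The repo id and folder name here are placeholders, not the real ones.
+snapshot_download(
+    repo_id="some-org/chatanything-models",
+    local_dir="MODELS/example-model",
+)
+print("models ready under MODELS/")
+```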
+
+### Building Docker
+Try building a Docker image if you find that easier. This part is not fully tested; if you find anything wrong, feel free to contribute~
+```
+docker build --network=host -t chatanything .
+# docker run -dp 127.0.0.1:8901:8901 chatanything
+docker run -p 127.0.0.1:8901:8901 -it --gpus all chatanything
+docker run -it --gpus all chatanything bash
+```
+
+### Run
+Specify a port for the Gradio application to run on and set off!
+```
+PORT=8809 python app.py $PORT
+```
+
+# Configuring: From User Input Concept to Appearance & Voice
+The first step of the pipeline is to generate an image for SadTalker and, at the same time, set up the text-to-speech module for voice chat.
+
+The pipeline queries a powerful LLM (ChatGPT) for these selections in a zero-shot multiple-choice format.
+Three questions are asked at the start of every conversation (init frame generation):
+1. Provide an imagined personality for the user-input concept.
+2. Select a generative model for the init frame generation.
+3. Select a text-to-speech voice (model) for the character based on that personality.
+
+The model selection is built to be extendable: add your ideal model with just a few lines of configuration! The rest of this section briefly introduces the steps to add an init-frame generator or a language voice.
+
+### Image Generator
+Configure the models in the [Model Config](./resources/models.yaml). This config acts as the memory (or an image-generating tool pool) for the LLM.
+
+The prompt sets up the selection process. Each sub-field of "models" turns into an option in the multiple-choice question.
+The "**desc**" field of each entry is what the language model sees. The key itself is not shown to the LM, as it can sometimes mislead it.
+The other fields are used for image generation:
+1. model_dir: the repo path for the diffusers package. As the pretrained face-landmark ControlNet is based on stable-diffusion-v1-5, we currently support only its derivatives.
+2. lora_path: LoRA derivatives are powerful; try a LoRA model for better stylization. This should point directly to the parameter binary file.
+3. prompt_template & negative_prompt: used for prompting the text-to-image diffusion model. Find an ideal prompt for your model and stick with it. The prompt template should contain a "{}" for inserting the user-input concept.
+
+Here are some **tips** for configuring your own model (a minimal config sketch follows the list):
+1. Provide the LLM with a simple description of the generative model. The description needs to be concise and accurate for a correct selection.
+2. Set model_dir to a local directory of a diffusers stable-diffusion-v1-5 derivative. You can also provide a repo id on the Hugging Face Hub; the model is downloaded the first time it is chosen, so wait for it.
+3. To better leverage community resources, we also support LoRA. To add a LoRA module, give the path to the parameter file.
+4. Carefully write the prompt template and negative prompt; they affect the initial face generation a lot. The prompt template should contain exactly one pair of "{}", which is filled with the concept the user typed on the webpage. We support the Stable-Diffusion-Webui prompt style as implemented by diffusers, so feel free to copy a prompt from Civitai and add the "{}" to it for ChatAnything!
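+
+To make the expected structure concrete, here is a minimal, hypothetical sketch of a single entry in `resources/models.yaml` and of how the selection code reads it with OmegaConf. The entry name, paths, and prompts below are made up for illustration; see your own `resources/models.yaml` for the real options.
+```python
+from omegaconf import OmegaConf
+
+# Hypothetical entry, mirroring the fields described above.
+example_yaml = """
+models:
+  my_anime_model:
+    desc: "anime-style portraits, good for cute or fantasy concepts"
+    model_dir: "runwayml/stable-diffusion-v1-5"    # local dir or Hugging Face repo id (SD-1.5 derivative)
+    lora_path: "MODELS/lora/my_style.safetensors"  # optional; points directly at the LoRA weight file
+    prompt_template: "a portrait of {}, masterpiece, best quality"
+    negative_prompt: "lowres, bad anatomy, blurry"
+"""
+
+conf = OmegaConf.create(example_yaml)
+# Only the "desc" fields are shown to the LLM when it picks a model:
+for name, model_conf in conf.models.items():
+    print(name, "->", model_conf.desc)
+```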
+
+Again, this model config acts as an extended tool pool for the LM: the application drives the LM to choose from this config and uses the chosen model to generate. Sometimes the LM fails to pick a valid model from the list, which makes the ChatAnything app fail on that generation.
+
+Note that we currently support ONLY stable-diffusion-v1-5 derivatives (SDXL pipelines are under consideration but not yet implemented, as we lack a face-landmark ControlNet for them. Reach out if you're interested in training one!).
+
+### Voice TTS
+We use the edge-tts package for text-to-speech support. The voice selection and the [voice configuration file](./resources/voices_edge.yaml) are constructed similarly to the image-generation model selection, except that the LM now chooses the voice based on the personality description it gave earlier. The "**gender**" and "**language**" fields correspond to edge-tts voices.
+
+# On-going tasks
+### Customized Voice
+There is a Voice Changer app (a text-to-speech + speech voice conversion pipeline) that enables a better customized voice, and we are trying to leverage its TTS functionality.
+
+Reach out if you want to add a voice of your own or of your hero!
+
+Here are the possible steps. You would need to change the code a little first:
+1. Alter this [code](./utils.py#14) to import TTSTalker from chat_anything/tts_talker/tts_voicechanger.py.
+2. Switch the config to the other one: change [code](./utils.py#14) "resources/voices_edge.yaml" -> "resources/voices_voicechanger.yaml".
+
+Then try running a [Voice Changer](https://huggingface.co/spaces/kevinwang676/Voice-Changer) on your local machine. Simply set up git-lfs, install the repo, and run it as the TTS voice service.
+The TTS caller is set to port 7860.
+
+Make sure the client class is set up with the same port [here](chat_anything/tts_talker/tts_voicechanger.py#5):
+```python
+client = Client("http://127.0.0.1:7860/")
+```
+
+# Acknowledgement
+Again, the project does not include any training yet. The pipeline is built entirely on these incredibly awesome packages and pretrained models. Take a look and explore the amazing open-source generative communities. We love you, guys.
+- [ChatGPT](https://openai.com/chatgpt): GOD
+- [SadTalker](https://github.com/OpenTalker/SadTalker): the core animation module.
+- [Face-Landmark-ControlNet](https://huggingface.co/georgefen/Face-Landmark-ControlNet): an awesome ControlNet with face landmarks, using Stable Diffusion 1.5 as the base model.
+- [diffusers](https://github.com/huggingface/diffusers): GOAT of image generation frameworks🥳.
+- [langchain](https://github.com/langchain-ai/langchain): an awesome package for working with LLMs.
+- [edge-tts](https://github.com/rany2/edge-tts): an awesome package for text-to-speech.
+- [gradio](https://www.gradio.app/): GOAT😄 machine-learning app framework.
+- [Civitai](https://civitai.com/models) and [Huggingface_hub](https://huggingface.co/models): find your ideal image generation model on Civitai. These communities are crazy🥂.
Here are Some Fantastic Derivatives of [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5): + - [Game Icon Institute_mode](https://civitai.com/models/47800?modelVersionId=76533) + - [dreamshaper](https://civitai.com/models/4384/dreamshaper) + - [3D_Animation_Diffusion](https://civitai.com/models/118086?modelVersionId=128046) + - [anything-v5](https://huggingface.co/stablediffusionapi/anything-v5) + +# Citation +If you like our pipeline and application, don't hesitate to reach out! Let's work on it and see how far it would go! +```bibtex +@misc{zhao2023ChatAnything, + title={ChatAnything: Facetime Chat with LLM-Enhanced Personas}, + author={Yilin, Zhao and Shanghua, Gao and Daquan, Zhou and Xinbin, Yuan and Qibin, Hou and Jiashi, Feng}, + publisher={}, + year={2023}, +} +``` + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d9c1818557284c4f2ac4ee61b56aa58a153f9494 --- /dev/null +++ b/app.py @@ -0,0 +1,239 @@ +import os +import ssl +import sys + +import gradio as gr + +import warnings +import whisper +from chat_anything.polly_utils import PollyVoiceData +from chat_anything.azure_utils import AzureVoiceData +from chat_anything.chatbot.chat import set_openai_api_key +from utils import ChatWrapper, update_foo, reset_memory + +ssl._create_default_https_context = ssl._create_unverified_context + + +TALKING_HEAD_WIDTH = "350" + +LOOPING_TALKING_HEAD = "resources/videos/tempfile.mp4" + +USE_GPT4_DEFAULT = False +FULLBODY_DEFAULT = False +POLLY_VOICE_DATA = PollyVoiceData() +AZURE_VOICE_DATA = AzureVoiceData() + +# Pertains to WHISPER functionality +WHISPER_DETECT_LANG = "Detect language" + +INSTRUCTION_MARKDOWN = """ +# ChatAnything: Facetime Chat with LLM-Enhanced Personas +### DEMO INSTRUCTION +##### 0. Register +Input a OpenAI API Key of your own. This would be used to chat with openai-chatgpt. Make sure to disable the key afterwards🥹. +##### 1. Generate The init face😀 along with first chat +Input a Concept in the "Talking object" text box, then click on Generate button. The init face generation and module selection will be performed and used for the rest of this chat. Wait for a while and the video would be produced and played. Write simple concept for generating. The concept will be place on each prompt template for deciding the main concepts. +##### 2. Keep on Chatting🤑 +Go on speak with the character. The init face and module selection will not reperform itself, now you are only chatting with the LM, along with the rendering of sadtalker. Hopefully, the API will not impose an excessive charge for this. + + +### FEATURES +##### 1. Upload a image for control/inversion starting point. Try some none face images and see how it works! +##### 2. seeding is provided. However if not providing a input image, there would be a random chosen facial landmark image for generating, which might include some randomness. +##### 3. Try out the examples. +##### 4. Say something and recorded your voice for a real facetime chat. Whisper will handle your voice, see setting-Whisper STT options. +##### 5. Decide whether to use the crop face out option, this will crop out the face from the generated image and render. This is promising for better animation rendering, however sometimes the croped image loses some elementary features of you intended concept. 
+ +""" + +# UNCOMMENT TO USE WHISPER +warnings.filterwarnings("ignore") +WHISPER_MODEL = whisper.load_model("tiny") +print("WHISPER_MODEL", WHISPER_MODEL) + + +# UNCOMMENT TO USE WHISPER +def transcribe(aud_inp, whisper_lang): + if aud_inp is None: + return "" + aud = whisper.load_audio(aud_inp) + aud = whisper.pad_or_trim(aud) + mel = whisper.log_mel_spectrogram(aud).to(WHISPER_MODEL.device) + _, probs = WHISPER_MODEL.detect_language(mel) + options = whisper.DecodingOptions() + if whisper_lang != WHISPER_DETECT_LANG: + whisper_lang_code = POLLY_VOICE_DATA.get_whisper_lang_code( + whisper_lang) + options = whisper.DecodingOptions(language=whisper_lang_code) + result = whisper.decode(WHISPER_MODEL, mel, options) + print("result.text", result.text) + result_text = "" + if result and result.text: + result_text = result.text + return result_text + + +chat = ChatWrapper() + + +with gr.Blocks() as block: + llm_state = gr.State() + history_state = gr.State() + chain_state = gr.State() + talker_state = gr.State() + fullbody_state = gr.State(True) + speak_text_state = gr.State(True) + talking_head_state = gr.State(True) + uid_state = gr.State() + video_file_path = gr.State() + audio_file_path = gr.State() + + memory_state = gr.State() + + + # Pertains to WHISPER functionality + whisper_lang_state = gr.State(WHISPER_DETECT_LANG) + use_gpt4_state = gr.State(USE_GPT4_DEFAULT) + + with gr.Column(): + with gr.Row(): + gr.Markdown(INSTRUCTION_MARKDOWN) + with gr.Row(): + openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter", + show_label=True, lines=1, type='password', value='', label='OpenAI API key') + openai_api_key_register = gr.Button( + value="Register").style(full_width=False) + uid_textbox = gr.Textbox(show_label=True, value=uid_state, lines=1, label='UID') + seed = gr.Slider( + label="Seed", + minimum=-1, + maximum=2147483647, + step=1, + randomize=True, + ) + + with gr.Tab("Chat"): + with gr.Row(): + with gr.Column(scale=1, min_width=TALKING_HEAD_WIDTH, visible=True): + with gr.Column(): + class_prompt = gr.Textbox( + 'apple', + default='apple', + type="text", label='Talking object' + ) + init_face_btn = gr.Button( + value="Generate").style(full_width=False) + + my_file = gr.File(label="Upload a file", + type="file", visible=False) + + # video_html = gr.HTML('') + video_html = gr.Video(label="Generated Video", autoplay=True) + + ref_image = gr.Image( + type="pil", + interactive=True, + label="Image: Upload your image.", + ) + tmp_aud_file = gr.File( + type="file", visible=False) + audio_html = gr.HTML('') + init_face_btn.click(chat.generate_init_face_video, inputs=[class_prompt, llm_state, uid_state,fullbody_state, ref_image, seed], + outputs=[chain_state, memory_state, video_html,talker_state]) + + + with gr.Column(scale=7): + chatbot = gr.Chatbot() + + + message = gr.Textbox(label="What's on your mind??", + placeholder="What's the answer to life, the universe, and everything?", + lines=1) + submit = gr.Button(value="Send", variant="secondary").style( + full_width=False) + + audio_comp = gr.Microphone(source="microphone", type="filepath", label="Just say it!", + interactive=True, streaming=False) + audio_comp.change(transcribe, inputs=[ + audio_comp, whisper_lang_state], outputs=[message]) + + + with gr.Accordion("General examples", open=False): + gr.Examples( + examples=[ + ["cyberpunk godess", "Who are you?", "resources/images/annie.jpg", 393212389], + ["unbelievable beauty fairy", "Who are you?", "resources/images/lenna.jpg", 222679277], + ["tree 
monster", "Who are you?", None], + ["pineapple monster", "Who are you?", None], + ["tricky Polaris", "Who are you?", None, 1670155100], + ["watermelon", "Who are you?", "resources/images/watermelon.jpg", 42], + ], + inputs=[class_prompt, message, ref_image, seed], + ) + + with gr.Tab("Settings"): + with gr.Tab("General"): + + talking_head_cb = gr.Checkbox( + label="Show talking head", value=True) + talking_head_cb.change(chat.update_talking_head, inputs=[talking_head_cb, uid_state, talking_head_state], + outputs=[talking_head_state, video_html]) + + use_gpt4_cb = gr.Checkbox(label="Use GPT-4 (experimental) if your OpenAI API has access to it", + value=USE_GPT4_DEFAULT) + + fullbody_state = gr.Checkbox(label="Use full body instead of a face.", + value=True) + + use_gpt4_cb.change(set_openai_api_key, + inputs=[openai_api_key_textbox, + use_gpt4_cb], + outputs=[llm_state, use_gpt4_state, chatbot, uid_state, video_file_path, audio_file_path]) + + reset_btn = gr.Button(value="Reset chat", + variant="secondary").style(full_width=False) + reset_btn.click(reset_memory, inputs=[history_state, memory_state], + outputs=[chatbot, history_state, memory_state]) + + + with gr.Tab("Whisper STT"): + whisper_lang_radio = gr.Radio(label="Whisper speech-to-text language:", choices=[ + WHISPER_DETECT_LANG, "Arabic", "Arabic (Gulf)", "Catalan", "Chinese (Cantonese)", "Chinese (Mandarin)", + "Danish", "Dutch", "English (Australian)", "English (British)", "English (Indian)", "English (New Zealand)", + "English (South African)", "English (US)", "English (Welsh)", "Finnish", "French", "French (Canadian)", + "German", "German (Austrian)", "Georgian", "Hindi", "Icelandic", "Indonesian", "Italian", "Japanese", + "Korean", "Norwegian", "Polish", + "Portuguese (Brazilian)", "Portuguese (European)", "Romanian", "Russian", "Spanish (European)", + "Spanish (Mexican)", "Spanish (US)", "Swedish", "Turkish", "Ukrainian", "Welsh"], + value=WHISPER_DETECT_LANG) + + whisper_lang_radio.change(update_foo, + inputs=[whisper_lang_radio, + whisper_lang_state], + outputs=[whisper_lang_state]) + + gr.HTML(""" +

This application is based on Chat-GPT-LangChain, LangChain +

""") + + message.submit(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state, + speak_text_state, talking_head_state, uid_state,talker_state,fullbody_state], + outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message]) + + submit.click(chat, inputs=[openai_api_key_textbox, message, history_state, chain_state, + speak_text_state, talking_head_state, uid_state,talker_state,fullbody_state], + outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message]) + + openai_api_key_register.click(set_openai_api_key, + inputs=[openai_api_key_textbox, + use_gpt4_state, chatbot], + outputs=[llm_state, use_gpt4_state, chatbot, uid_state, video_file_path, audio_file_path]) + +if __name__ == "__main__": + import sys + if len(sys.argv) == 1: + port = 8901 + else: + port = int(sys.argv[1]) + block.launch(debug=True, server_name="0.0.0.0", + server_port=port, share=True, enable_queue = True) diff --git a/chat_anything/azure_utils.py b/chat_anything/azure_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4173eaa689abe9b7b6b66ed3fcf1ede591655a53 --- /dev/null +++ b/chat_anything/azure_utils.py @@ -0,0 +1,155 @@ +# This class stores Azure voice data. Specifically, the class stores several records containing +# language, lang_code, gender, voice_id and engine. The class also has a method to return the +# voice_id, lang_code and engine given a language and gender. + +NEURAL_ENGINE = "neural" +STANDARD_ENGINE = "standard" + + +class AzureVoiceData: + def get_voice(self, language, gender): + for voice in self.voice_data: + if voice['language'] == language and voice['gender'] == gender: + return voice['azure_voice'] + return None + + def __init__(self): + self.voice_data = [ + {'language': 'Arabic', + 'azure_voice': 'ar-EG-ShakirNeural', + 'gender': 'Male'}, + {'language': 'Arabic (Gulf)', + 'azure_voice': 'ar-KW-FahedNeural', + 'gender': 'Male'}, + {'language': 'Catalan', + 'azure_voice': 'ca-ES-EnricNeural', + 'gender': 'Male'}, + {'language': 'Chinese (Cantonese)', + 'azure_voice': 'yue-CN-YunSongNeural', + 'gender': 'Male'}, + {'language': 'Chinese (Mandarin)', + 'azure_voice': 'zh-CN-YunxiNeural', + 'gender': 'Male'}, + {'language': 'Danish', + 'azure_voice': 'da-DK-JeppeNeural', + 'gender': 'Male'}, + {'language': 'Dutch', + 'azure_voice': 'nl-NL-MaartenNeural', + 'gender': 'Male'}, + {'language': 'English (Australian)', + 'azure_voice': 'en-AU-KenNeural', + 'gender': 'Male'}, + {'language': 'English (British)', + 'azure_voice': 'en-GB-RyanNeural', + 'gender': 'Male'}, + {'language': 'English (Indian)', + 'azure_voice': 'en-IN-PrabhatNeural', + 'gender': 'Male'}, + {'language': 'English (New Zealand)', + 'azure_voice': 'en-NZ-MitchellNeural', + 'gender': 'Male'}, + {'language': 'English (South African)', + 'azure_voice': 'en-ZA-LukeNeural', + 'gender': 'Male'}, + {'language': 'English (US)', + 'azure_voice': 'en-US-ChristopherNeural', + 'gender': 'Male'}, + {'language': 'English (Welsh)', + 'azure_voice': 'cy-GB-AledNeural', + 'gender': 'Male'}, + {'language': 'Finnish', + 'azure_voice': 'fi-FI-HarriNeural', + 'gender': 'Male'}, + {'language': 'French', + 'azure_voice': 'fr-FR-HenriNeural', + 'gender': 'Male'}, + {'language': 'French (Canadian)', + 'azure_voice': 'fr-CA-AntoineNeural', + 'gender': 'Male'}, + {'language': 'German', + 'azure_voice': 'de-DE-KlausNeural', + 'gender': 'Male'}, + {'language': 'German (Austrian)', + 'azure_voice': 'de-AT-JonasNeural', + 'gender': 'Male'}, + {'language': 'Hindi', + 
'azure_voice': 'hi-IN-MadhurNeural', + 'gender': 'Male'}, + {'language': 'Icelandic', + 'azure_voice': 'is-IS-GunnarNeural', + 'gender': 'Male'}, + {'language': 'Italian', + 'azure_voice': 'it-IT-GianniNeural', + 'gender': 'Male'}, + {'language': 'Japanese', + 'azure_voice': 'ja-JP-KeitaNeural', + 'gender': 'Male'}, + {'language': 'Korean', + 'azure_voice': 'ko-KR-GookMinNeural', + 'gender': 'Male'}, + {'language': 'Norwegian', + 'azure_voice': 'nb-NO-FinnNeural', + 'gender': 'Male'}, + {'language': 'Polish', + 'azure_voice': 'pl-PL-MarekNeural', + 'gender': 'Male'}, + {'language': 'Portuguese (Brazilian)', + 'azure_voice': 'pt-BR-NicolauNeural', + 'gender': 'Male'}, + {'language': 'Portuguese (European)', + 'azure_voice': 'pt-PT-DuarteNeural', + 'gender': 'Male'}, + {'language': 'Romanian', + 'azure_voice': 'ro-RO-EmilNeural', + 'gender': 'Male'}, + {'language': 'Russian', + 'azure_voice': 'ru-RU-DmitryNeural', + 'gender': 'Male'}, + {'language': 'Spanish (European)', + 'azure_voice': 'es-ES-TeoNeural', + 'gender': 'Male'}, + {'language': 'Spanish (Mexican)', + 'azure_voice': 'es-MX-LibertoNeural', + 'gender': 'Male'}, + {'language': 'Spanish (US)', + 'azure_voice': 'es-US-AlonsoNeural"', + 'gender': 'Male'}, + {'language': 'Swedish', + 'azure_voice': 'sv-SE-MattiasNeural', + 'gender': 'Male'}, + {'language': 'Turkish', + 'azure_voice': 'tr-TR-AhmetNeural', + 'gender': 'Male'}, + {'language': 'Welsh', + 'azure_voice': 'cy-GB-AledNeural', + 'gender': 'Male'}, + ] + + +# Run from the command-line +if __name__ == '__main__': + azure_voice_data = AzureVoiceData() + + azure_voice = azure_voice_data.get_voice('English (US)', 'Male') + print('English (US)', 'Male', azure_voice) + + azure_voice = azure_voice_data.get_voice('English (US)', 'Female') + print('English (US)', 'Female', azure_voice) + + azure_voice = azure_voice_data.get_voice('French', 'Female') + print('French', 'Female', azure_voice) + + azure_voice = azure_voice_data.get_voice('French', 'Male') + print('French', 'Male', azure_voice) + + azure_voice = azure_voice_data.get_voice('Japanese', 'Female') + print('Japanese', 'Female', azure_voice) + + azure_voice = azure_voice_data.get_voice('Japanese', 'Male') + print('Japanese', 'Male', azure_voice) + + azure_voice = azure_voice_data.get_voice('Hindi', 'Female') + print('Hindi', 'Female', azure_voice) + + azure_voice = azure_voice_data.get_voice('Hindi', 'Male') + print('Hindi', 'Male', azure_voice) diff --git a/chat_anything/chatbot/__init__.py b/chat_anything/chatbot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/chatbot/chat.py b/chat_anything/chatbot/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..69d505db421c71e11abfadb8fcb74928c1bef804 --- /dev/null +++ b/chat_anything/chatbot/chat.py @@ -0,0 +1,72 @@ +import datetime +from chat_anything.chatbot.personality import generate_personality_prompt +from langchain.prompts import PromptTemplate +from langchain import ConversationChain +from langchain.chains.conversation.memory import ConversationBufferMemory +from langchain.chat_models import ChatOpenAI +from langchain.embeddings.openai import OpenAIEmbeddings +import os +import random +import string + + +def load_chain(llm, class_concept=None): + chain = None + memory = None + personality_text = None + print(llm) + if llm: + print("class_concept", class_concept) + if class_concept is None: + class_concept = 'AI assistant' + person_template, personality_text 
= generate_personality_prompt(llm, class_concept) + + PROMPT_TEMPLATE = PromptTemplate( + input_variables=["history", "input"], + template=person_template, + ) + + chain = ConversationChain( + prompt=PROMPT_TEMPLATE, + llm=llm, + verbose=False, + memory=ConversationBufferMemory(ai_prefix="You"), + ) + print("New concept done for ", class_concept) + + return chain, memory, personality_text + + + +def set_openai_api_key(api_key, use_gpt4, history=None, max_tokens=1024): + """Set the api key and return chain. + If no api_key, then None is returned. + """ + if api_key and api_key.startswith("sk-") and len(api_key) > 50: + os.environ["OPENAI_API_KEY"] = api_key + print("\n\n ++++++++++++++ Setting OpenAI API key ++++++++++++++ \n\n") + print(str(datetime.datetime.now()) + ": Before OpenAI, OPENAI_API_KEY length: " + str( + len(os.environ["OPENAI_API_KEY"]))) + + if use_gpt4: + llm = ChatOpenAI( + temperature=0, max_tokens=max_tokens, model_name="gpt-4") + print("Trying to use llm ChatOpenAI with gpt-4") + else: + print("Trying to use llm ChatOpenAI with gpt-3.5-turbo") + llm = ChatOpenAI(temperature=0, max_tokens=max_tokens, + model_name="gpt-3.5-turbo") + + print(str(datetime.datetime.now()) + ": After OpenAI, OPENAI_API_KEY length: " + str( + len(os.environ["OPENAI_API_KEY"]))) + + print(str(datetime.datetime.now()) + ": After load_chain, OPENAI_API_KEY length: " + str( + len(os.environ["OPENAI_API_KEY"]))) + os.environ["OPENAI_API_KEY"] = "" + history = history or [] + history.append(['', '[SYSTEM] OPENAI_API_KEY has been set, you can generate your object and talk to it now!']) + uid = ''.join(random.sample(string.ascii_lowercase + string.ascii_uppercase, 5)) + video_file_path = os.path.join('tmp', uid, 'videos/tempfile.mp4') + audio_file_path = os.path.join('tmp', uid, 'audio/tempfile.mp3') + return llm, use_gpt4, history, uid, video_file_path, audio_file_path + return None, None, None, None, None, None \ No newline at end of file diff --git a/chat_anything/chatbot/model_select.py b/chat_anything/chatbot/model_select.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4cddae584bf85752c0569778e83e8c0e364dd1 --- /dev/null +++ b/chat_anything/chatbot/model_select.py @@ -0,0 +1,60 @@ +from langchain import LLMChain +from langchain.prompts import PromptTemplate +from omegaconf import OmegaConf +import datetime + +MODEL_SELECTION_PROMPT_TEMPLATE = """ +Select one of the following models based on the given concept. +You must choose one model name based on the description of each model and the concept! 
+ +Cencept: {concept} + +Model name and description: {model_list} + +Warning: {warning} + +The avilable model names: +{model_name_list} + +Selected model name: +""" + +def load_model_list(): + models_config = OmegaConf.load('resources/models.yaml') + models_dict = models_config['models'] + model_name_list_str = '' + print(models_dict) + model_list_str = '' + for key, value in models_dict.items(): + model_list_str+="model name: " +key+', model description: '+value['desc']+'\n' + model_name_list_str += key + ' ' + model_name_list_str += '\n' + return model_list_str, models_dict, model_name_list_str + +def model_selection_chain(llm, class_concept=None): + chain = None + memory = None + if llm: + print("class_concept", class_concept) + if class_concept is None: + class_concept = 'AI assistant' + + + template = PromptTemplate( + input_variables=["model_list", "concept", "warning", "model_name_list"], + template=MODEL_SELECTION_PROMPT_TEMPLATE, + ) + model_list_str, models_dict, model_name_list_str = load_model_list() + + personality_chain = LLMChain( + llm=llm, prompt=template, verbose=True) + selected_model = None + while (selected_model is None) or not (selected_model in models_dict): + if (selected_model is not None) and not (selected_model in models_dict): + warning_str = '{} is not in Model list! \n'.format(selected_model) + else: + warning_str = '' + selected_model = personality_chain.run({'concept': class_concept, 'model_list':model_list_str, 'warning': warning_str, 'model_name_list': model_name_list_str}) + print("Selected model name: ", selected_model) + + return models_dict[selected_model] diff --git a/chat_anything/chatbot/personality.py b/chat_anything/chatbot/personality.py new file mode 100644 index 0000000000000000000000000000000000000000..ac12c515df2bc4fb690d6953cb81f8fed5d20458 --- /dev/null +++ b/chat_anything/chatbot/personality.py @@ -0,0 +1,59 @@ +from langchain import LLMChain +from langchain.prompts import PromptTemplate + +PERSONALITY_PROMPT_TEMPLATE = """ +You are an excellent scriptwriter. Now you need to provide the characteristics of an {object} and transforms them into personality traits. +Describe these personalities using the second person, giving names and specific personality descriptions related to the {object}. +The language of the Personality must be same as {object}! + +You should do the following steps: +1. Based on the object's nature, imagine what kind of personality it could have if it were to come to life. Does it possess a strong sense of responsibility, like a caring caregiver? Is it playful and mischievous, like a curious child? Is it wise and patient, like an ancient sage? Be creative and invent traits that align with the object's essence. +2. Remember to infuse emotions and vivid imagery to bring your object's personality to life. +3. translate the personality into a second person prompt. + +Example: + + +Now give the personality of apple: + +Personality: +You an apple Sprite, your name is Apple Buddy. +You have all the characteristics of the apple. You are a type of fruit that is usually round with smooth skin and comes in various colors such as red, green, and yellow. You have sweet and nutritious flesh with seeds distributed in its core. You are a rich source of vitamins, fiber, and antioxidants, contributing to maintaining a healthy body. + +You are an optimistic buddy. Always wearing a smile, you spread joy to those around you. Just like the delightful taste of an apple, you bring happiness to everyone. 
+ +You are resilient at heart, like the skin of an apple, able to withstand life's challenges and difficulties. No matter what obstacles you encounter, you face them bravely without hesitation. + +You are caring and considerate, akin to the nutrients in an apple. You always pay attention to the needs and happiness of others. Skilled in listening, you willingly offer help and support, making those around you feel warmth and care. + +You have a strong desire to grow. Like an apple tree needs sunlight and water to flourish, you are continuously learning and improving, becoming a better version of yourself every day. + +You have a profound love for nature and enjoy living in harmony with it. Strolling in the garden, feeling the fresh air and warm sunlight, is one of your favorite moments. + +Apple Buddy, you are a unique apple. Your optimism, resilience, care, and eagerness to grow make you an adorable companion to those around you. Your story will lead us into a world full of warmth and goodness. + +Now give the personality of {object}: + +Personality: +""" + + +def generate_personality_prompt(llm, class_concept): + + PERSONALITY_PROMPT = PromptTemplate( + input_variables=["object"], + template=PERSONALITY_PROMPT_TEMPLATE, + ) + personality_chain = LLMChain( + llm=llm, prompt=PERSONALITY_PROMPT, verbose=True) + personality_text = personality_chain.run({'object': class_concept}) + person_prompt = personality_text + + person_prompt += '''The following is a friendly conversation between a human and you. You need to talk to human based on your personality. If you do not know the answer to a question, you truthfully says you do not know. + You can use up to 50 words to answer. Make you answer concise and concise!!!!!!!! + Current conversation: + {history} + Human: {input} + You: + ''' + return person_prompt, personality_text diff --git a/chat_anything/chatbot/select.py b/chat_anything/chatbot/select.py new file mode 100644 index 0000000000000000000000000000000000000000..c09cd47adc92f8bf60486a2b3cb49221d15ee033 --- /dev/null +++ b/chat_anything/chatbot/select.py @@ -0,0 +1,63 @@ +from langchain import LLMChain +from typing import OrderedDict +from langchain.prompts import PromptTemplate +from omegaconf import OmegaConf +import datetime + +SELECTION_TEMPLATE = """ +{concept} + +Model name and description: +{option_list} + +Warning: {warning} + +The avilable Options: +{choices} +Answer: +""" + + +def selection_chain(llm, class_concept, prompt, options): + chain = None + memory = None + if llm: + print("class_concept", class_concept) + if class_concept is None: + class_concept = 'AI assistant' + prompt_template = prompt + SELECTION_TEMPLATE + template = PromptTemplate( + input_variables=["concept", "option_list", "warning", "choices"], + template=prompt_template, + ) + chain = LLMChain( + llm=llm, prompt=template, verbose=True) + print(options) + option_list = [ + f"{chr(ord('A') + i)}. {conf['desc']}" for i, conf in enumerate(options.values()) + ] + option_list = '\n'.join(option_list) + selected_model = None + + warning_str = 'Choose from the available Options.' 
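+        # The prompt asks the LLM to answer with a single option letter; only the first character of
+        # the reply is used below and mapped back to an option key via ord(choice) - ord('A').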
+ choices = ' '.join(chr(ord('A') + i) for i in range(len(options))) + choice = chain.run({'concept': class_concept, 'option_list':option_list, 'warning': warning_str, 'choices': choices}) + print(f"LLM Responds (First character was used as the choice):{choice}", ) + choice = choice[0] + + selected_model = list(options.keys())[ord(choice) - ord('A')] + print("Selected model name: ", selected_model) + + return selected_model + +def model_selection_chain(llm, class_concept=None, conf_file='resources/models_personality.yaml'): + chain = None + memory = None + if llm: + print("class_concept", class_concept) + if class_concept is None: + class_concept = 'AI assistant' + selection_config = OmegaConf.load(conf_file) + selected_model = selection_chain(llm, class_concept, selection_config['prompt'], selection_config['models']) + model_conf = selection_config['models'][selected_model] + return model_conf, selected_model diff --git a/chat_anything/chatbot/voice_select.py b/chat_anything/chatbot/voice_select.py new file mode 100644 index 0000000000000000000000000000000000000000..730f3a33b8c9720f924eee28e7e575538085c340 --- /dev/null +++ b/chat_anything/chatbot/voice_select.py @@ -0,0 +1,119 @@ +from langchain import LLMChain +from langchain.prompts import PromptTemplate +from omegaconf import OmegaConf +import datetime + +VOICE_SELECTION_PROMPT_TEMPLATE = """ +Select one of the following voice based on the given concept. +You must choose one voice name based on the description of each model and the concept. + + +Cencept: {concept} + +Voice name and description: {model_list} + +Warning: {warning} + +The avilable voice names: +{model_name_list} + +Selected voice name: +""" + +GENDER_SELECTION_PROMPT_TEMPLATE = """ +Select one of the following gender based on the given concept. +You must choose one gender based on the description of the concept. You must choose one gender Even if you can't decide. + +Gender: +male +female + +Cencept: {concept} +Selected gender male or female: +""" + +LANGUAGE_SELECTION_PROMPT_TEMPLATE = """ +Select one of the following language based on the given concept. +You must choose the language that is used by the description of the concept. + +Languages: +Chinese +English +Japanese + +Cencept: {concept} +Selected language: +""" + +def load_voice_model_list(): + models_config = OmegaConf.load('resources/voices.yaml') + models_dict = models_config['models'] + print(models_dict) + model_list_str = '' + model_name_list_str = '' + for key, value in models_dict.items(): + model_list_str+="model name: " +key+', model description: '+value['desc']+'\n' + model_name_list_str += key + ' ' + model_name_list_str += '\n' + return model_list_str, models_dict, model_name_list_str + +def get_vioce_model_chain(llm, class_concept): + model_template = PromptTemplate( + input_variables=["model_list", "concept", "model_name_list", "warning"], + template=VOICE_SELECTION_PROMPT_TEMPLATE, + ) + model_list_str, models_dict, model_name_list_str = load_voice_model_list() + + personality_chain = LLMChain( + llm=llm, prompt=model_template, verbose=True) + + selected_model = None + while (selected_model is None) or not (selected_model in models_dict): + if (selected_model is not None) and not (selected_model in models_dict): + warning_str = '{} is not in Model list! 
\n'.format(selected_model) + else: + warning_str = '' + selected_model = personality_chain.run({'concept': class_concept, 'model_list':model_list_str, 'warning': warning_str, 'model_name_list': model_name_list_str}) + print("Selected model name: ", selected_model) + + return selected_model + +def get_gender_chain(llm, class_concept): + model_template = PromptTemplate( + input_variables=["concept"], + template=GENDER_SELECTION_PROMPT_TEMPLATE, + ) + + personality_chain = LLMChain( + llm=llm, prompt=model_template, verbose=True) + selected_gender = personality_chain.run({'concept': class_concept}) + print("Selected gender: ", selected_gender) + return selected_gender + +def get_language_chain(llm, class_concept): + model_template = PromptTemplate( + input_variables=["concept"], + template=LANGUAGE_SELECTION_PROMPT_TEMPLATE, + ) + + personality_chain = LLMChain( + llm=llm, prompt=model_template, verbose=True) + selected_language = personality_chain.run({'concept': class_concept}) + print("Selected language: ", selected_language) + return selected_language + + + +def voice_selection_chain(llm, class_concept=None): + chain = None + memory = None + if llm: + print("class_concept", class_concept) + if class_concept is None: + class_concept = 'AI assistant' + selected_model = get_vioce_model_chain(llm, class_concept) + gender = get_gender_chain(llm, class_concept) + language = get_language_chain(llm, class_concept) + + return selected_model, gender, language + diff --git a/chat_anything/face_generator/__init__.py b/chat_anything/face_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/face_generator/long_prompt_control_generator.py b/chat_anything/face_generator/long_prompt_control_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6068e149fcd8de62bc2dc4a11a858b8b1bc12237 --- /dev/null +++ b/chat_anything/face_generator/long_prompt_control_generator.py @@ -0,0 +1,104 @@ +import PIL +from PIL import Image +from PIL import ImageDraw +import numpy as np + +import dlib +import cv2 +import torch + +import diffusers +from diffusers import StableDiffusionPipeline, DiffusionPipeline +from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline +from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline, get_weighted_text_embeddings +from diffusers.schedulers import EulerAncestralDiscreteScheduler,DPMSolverMultistepScheduler # DPM++ SDE Karras + +from chat_anything.face_generator.utils.generate import generate + +from .long_prompt_generator import LongPromptGenerator + +def draw_landmarks(image, landmarks, color="white", radius=2.5): + draw = ImageDraw.Draw(image) + for dot in landmarks: + x, y = dot + draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill=color) + +def get_ldmk_img(w, h, ldmks) -> PIL.Image: + con_img = Image.new('RGB', (w, h), color=(0, 0, 0)) + draw_landmarks(con_img, ldmks) + return con_img + +class LongPromptControlGenerator(LongPromptGenerator): + + def __init__(self, model_dir, lora_path, prompt_template, negative_prompt, face_control_dir, face_detect_path,): + self.face_control_dir = face_control_dir + self.face_detect_path = face_detect_path + super().__init__(model_dir, lora_path, prompt_template, negative_prompt) + + def load_model(self, *args, **kwargs): + super().load_model(*args, **kwargs) + self.face_detector = 
dlib.get_frontal_face_detector() + self.face_predictor = dlib.shape_predictor(self.face_detect_path) + # load control net + face_controlnet = ControlNetModel.from_pretrained(self.face_control_dir).to('cuda', dtype=torch.float16) + self.face_control_pipe = StableDiffusionControlNetPipeline(controlnet=face_controlnet, **self.pipe.components) + self.face_control_img2img_pipe = StableDiffusionControlNetImg2ImgPipeline(controlnet=face_controlnet, **self.pipe.components) + + def _get_68landmarks_seq(self, img_np): + gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) + faces = self.face_detector(gray) + landmarks = [] + for face in faces: + shape = self.face_predictor(gray, face) + for i in range(68): + x = shape.part(i).x + y = shape.part(i).y + landmarks.append((x, y)) + return landmarks + + def has_face(self, img_pil): + img_np = np.array(img_pil) + gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY) + faces = self.face_detector(gray) + return len(faces) != 0 + + def face_control_generate( + self, + prompt, + face_img_pil, + do_inversion=False, + **kwargs, + ): + """ + Face control generating. + """ + face_img_np = np.array(face_img_pil) + ldmk_seq = self._get_68landmarks_seq(face_img_np) + ldmk_img_pil = get_ldmk_img(face_img_pil.size[0], face_img_pil.size[1], ldmk_seq) + print('GENERATING:', prompt) + + generating_conf = { + "prompt": prompt, + "negative_prompt": self.negative_prompt, + "num_inference_steps": 25, + "guidance_scale": 7, + "controlnet_conditioning_scale": kwargs.pop('controlnet_conditioning_scale', 1.0), + "generator": kwargs.pop('generator', None), + } + + if not do_inversion: + generating_conf.update({ + "pipe": self.face_control_pipe, + "image": ldmk_img_pil, + "controlnet_conditioning_scale": kwargs.pop('controlnet_conditioning_scale', 1.0), + }) + else: + generating_conf.update({ + "pipe": self.face_control_img2img_pipe, + "image": face_img_pil, + "control_image": ldmk_img_pil, + "strength": kwargs.pop('strength', 0.9), + }) + pipe_out = generate(**generating_conf) + generated_img = pipe_out[0][0] + return generated_img \ No newline at end of file diff --git a/chat_anything/face_generator/long_prompt_generator.py b/chat_anything/face_generator/long_prompt_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0121c37d23b40716b5021fe0e3053ef2ac35a8 --- /dev/null +++ b/chat_anything/face_generator/long_prompt_generator.py @@ -0,0 +1,82 @@ +import PIL +from PIL import Image +from PIL import ImageDraw +import numpy as np + +import dlib +import cv2 +import torch + +import diffusers +from diffusers import StableDiffusionPipeline, DiffusionPipeline +from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionImg2ImgPipeline +from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline, get_weighted_text_embeddings +from diffusers.schedulers import EulerAncestralDiscreteScheduler,DPMSolverMultistepScheduler # DPM++ SDE Karras + +from chat_anything.face_generator.utils.generate import generate + +class LongPromptGenerator(): + prompt_template = "A portrait of a {}, fine face, nice looking" + negative_prompt = "easynegative,Low resolution,Low quality, Opened Mouth" + # negative_prompt = "(((sexy))),paintings,loli,,big head,sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples,extra fingers, ((extra arms)), 
(extra legs), mutated hands, (fused fingers), (too many fingers), (long neck:1.3)" + + def __init__(self, model_dir, lora_path=None, prompt_template="{}", negative_prompt=""): + self.model_dir = model_dir + self.lora_path = lora_path + self.prompt_template = prompt_template + self.negative_prompt = negative_prompt + + def load_model(self, *args, **kwargs): + # load model + try: + pipe = DiffusionPipeline.from_pretrained(self.model_dir, torch_dtype=torch.float16, **kwargs) + except: + pipe = StableDiffusionPipeline.from_pretrained(self.model_dir, torch_dtype=torch.float16, **kwargs) + + pipe = pipe.to('cuda') + sche_conf = dict(pipe.scheduler.config) + fk_kwargs = ["skip_prk_steps","steps_offset","clip_sample","clip_sample_range","rescale_betas_zero_snr","timestep_spacing", "set_alpha_to_one"] + for k in fk_kwargs: + if k in sche_conf: + sche_conf.pop(k) + scheduler = DPMSolverMultistepScheduler(**sche_conf) + pipe.scheduler=scheduler + pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(**pipe.components) + self.pipe, self.pipe_longprompt = pipe, pipe_longprompt + if self.lora_path is not None: + pipe.load_lora_weights(self.lora_path) + self.pipe_img2img = StableDiffusionImg2ImgPipeline.from_pretrained(self.model_dir, **pipe.components) + + def generate( + self, + prompt, + do_inversion=False, + **kwargs, + ): + """ + Face control generating. + """ + print('GENERATING:', prompt) + if not do_inversion: + generating_conf = { + "pipe": self.pipe, + "prompt": prompt, + "negative_prompt": self.negative_prompt, + "num_inference_steps": 25, + "guidance_scale": 7, + } + else: + assert 'image' in kwargs, 'doing inversion, prepare the init image please PIL Image' + init_image = kwargs['image'] + generating_conf = { + "pipe": self.pipe_img2img, + "prompt": prompt, + "negative_prompt": self.negative_prompt, + "image": init_image, + "num_inference_steps": 25, + "guidance_scale": 7, + "strength": kwargs.pop('strength', 0.9), + } + pipe_out = generate(**generating_conf) + generated_img = pipe_out[0][0] + return generated_img \ No newline at end of file diff --git a/chat_anything/face_generator/pipelines/lpw_stable_diffusion.py b/chat_anything/face_generator/pipelines/lpw_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f6d4722023d5385a67669aebd52053702c596d --- /dev/null +++ b/chat_anything/face_generator/pipelines/lpw_stable_diffusion.py @@ -0,0 +1,1471 @@ +import inspect +import re +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils.torch_utils import randn_tensor + +from diffusers.utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, +) + + +# ------------------------------------------------------------------------------ + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| 
+\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. 
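+    The weights are obtained with parse_prompt_attention, so '(word:1.3)'-style emphasis is
+    resolved into a per-token weight.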
+ """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos] + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe: DiffusionPipeline, + text_input: torch.Tensor, + chunk_length: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + text_embedding = pipe.text_encoder(text_input_chunk)[0] + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + text_embeddings = pipe.text_encoder(text_input)[0] + return text_embeddings + + +def get_weighted_text_embeddings( + pipe: DiffusionPipeline, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. 
+ """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [ + token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids + ] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] + for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = getattr(pipe.tokenizer, "pad_token_id", eos) + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? 
+ if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None + + +def preprocess_image(image, batch_size): + w, h = image.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, batch_size, scale_factor=8): + if not isinstance(mask, torch.FloatTensor): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = np.vstack([mask[None]] * batch_size) + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + else: + valid_mask_channel_sizes = [1, 3] + # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) + if mask.shape[3] in valid_mask_channel_sizes: + mask = mask.permute(0, 3, 1, 2) + elif mask.shape[1] not in valid_mask_channel_sizes: + raise ValueError( + f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," + f" but received mask of shape {tuple(mask.shape)}" + ) + # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape + mask = mask.mean(dim=1, keepdim=True) + h, w = mask.shape[-2:] + h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 + mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) + return mask + + +class StableDiffusionLongPromptWeightingPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing + weighting in prompt. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. 
For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config( + requires_safety_checker=requires_safety_checker, + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + max_embeddings_multiples=3, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + if prompt_embeds is None or negative_prompt_embeds is None: + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) + + prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + ) + if prompt_embeds is None: + prompt_embeds = prompt_embeds1 + if negative_prompt_embeds is None: + negative_prompt_embeds = negative_prompt_embeds1 + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + def get_timesteps(self, num_inference_steps, strength, device, is_text2img): + if is_text2img: + return self.scheduler.timesteps.to(device), num_inference_steps + else: + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents( + self, + image, + timestep, + num_images_per_prompt, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if image is None: + batch_size = batch_size * num_images_per_prompt + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, None, None + else: + image = image.to(device=self.device, dtype=dtype) + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = self.vae.config.scaling_factor * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # add noise to latents using the timesteps + noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + strength: float = 0.8, + num_images_per_prompt: Optional[int] = 1, + add_predicted_noise: Optional[bool] = False, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. 
+ callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + max_embeddings_multiples, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + dtype = prompt_embeds.dtype + + # 4. Preprocess image and mask + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image, batch_size) + if image is not None: + image = image.to(device=self.device, dtype=dtype) + if isinstance(mask_image, PIL.Image.Image): + mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=dtype) + mask = torch.cat([mask] * num_images_per_prompt) + else: + mask = None + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, init_latents_orig, noise = self.prepare_latents( + image, + latent_timestep, + num_images_per_prompt, + batch_size, + self.unet.config.in_channels, + height, + width, + dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + if add_predicted_noise: + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise_pred_uncond, torch.tensor([t]) + ) + else: + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 11. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return image, has_nsfw_concept + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def text2img( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function for text-to-image generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. 
+ negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. 
If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + cross_attention_kwargs=cross_attention_kwargs, + ) + + def img2img( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function for image-to-image generation. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + `image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
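+
+        Example (illustrative; `init_image` is any PIL image supplied by the caller)::
+
+            result = pipe.img2img(image=init_image, prompt="a watercolor portrait", strength=0.6)
+            result.images[0].save("img2img.png")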
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + cross_attention_kwargs=cross_attention_kwargs, + ) + + def inpaint( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + num_images_per_prompt: Optional[int] = 1, + add_predicted_noise: Optional[bool] = False, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function for inpaint. + Args: + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. 
of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + is_cancelled_callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. If the function returns + `True`, the inference will be cancelled. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Returns: + `None` if cancelled by `is_cancelled_callback`, + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
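+
+        Example (illustrative; `init_image` and `mask` are PIL images supplied by the
+        caller, with white mask pixels marking the region to repaint)::
+
+            result = pipe.inpaint(image=init_image, mask_image=mask, prompt="a stone archway", strength=0.75)
+            result.images[0].save("inpaint.png")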
+ """ + return self.__call__( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + mask_image=mask_image, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + strength=strength, + num_images_per_prompt=num_images_per_prompt, + add_predicted_noise=add_predicted_noise, + eta=eta, + generator=generator, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_embeddings_multiples=max_embeddings_multiples, + output_type=output_type, + return_dict=return_dict, + callback=callback, + is_cancelled_callback=is_cancelled_callback, + callback_steps=callback_steps, + cross_attention_kwargs=cross_attention_kwargs, + ) diff --git a/chat_anything/face_generator/utils/generate.py b/chat_anything/face_generator/utils/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb18eefb82109e95c44a5134e6dc4246f5005ae --- /dev/null +++ b/chat_anything/face_generator/utils/generate.py @@ -0,0 +1,45 @@ +import torch +from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline + +@torch.no_grad() +def generate(pipe, prompt, negative_prompt, **generating_conf): + pipe_longprompt = StableDiffusionLongPromptWeightingPipeline( + unet=pipe.unet, + text_encoder=pipe.text_encoder, + vae=pipe.vae, + tokenizer=pipe.tokenizer, + scheduler=pipe.scheduler, + safety_checker=None, + feature_extractor=None, + ) + print('generating: ', prompt) + print('using negative prompt: ', negative_prompt) + embeds = pipe_longprompt._encode_prompt(prompt=prompt, negative_prompt=negative_prompt, device=pipe.device, num_images_per_prompt=1, do_classifier_free_guidance=generating_conf['guidance_scale']>1,) + negative_prompt_embeds, prompt_embeds = embeds.split(embeds.shape[0]//2) + pipe_out = pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + **generating_conf, + ) + return pipe_out + +if __name__ == '__main__': + from diffusers.pipelines import StableDiffusionPipeline + import argparse + def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--prompts',type=str,default=['starry night','Impression Sunrise, drawn by Claude Monet'], nargs='*' + ) + + args = parser.parse_args() + prompts = args.prompts + print(f'generating {prompts}') + model_id = 'pretrained_model/sd-v1-4' + pipe = StableDiffusionPipeline.from_pretrained(model_id,).to('cuda') + images = pipe(prompts).images + for i, image in enumerate(images): + image.save(f'{prompts[i]}_{i}.png') + + main() + \ No newline at end of file diff --git a/chat_anything/polly_utils.py b/chat_anything/polly_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb38abff2aaac3c5b24f20914d464151173780d --- /dev/null +++ b/chat_anything/polly_utils.py @@ -0,0 +1,635 @@ +# This class stores Polly voice data. Specifically, the class stores several records containing +# language, lang_code, gender, voice_id and engine. The class also has a method to return the +# voice_id, lang_code and engine given a language and gender. 
+ +NEURAL_ENGINE = "neural" +STANDARD_ENGINE = "standard" + + +class PollyVoiceData: + def get_voice(self, language, gender): + for voice in self.voice_data: + if voice['language'] == language and voice['gender'] == gender: + if voice['neural'] == 'Yes': + return voice['voice_id'], voice['lang_code'], NEURAL_ENGINE + for voice in self.voice_data: + if voice['language'] == language and voice['gender'] == gender: + if voice['standard'] == 'Yes': + return voice['voice_id'], voice['lang_code'], STANDARD_ENGINE + return None, None, None + + def get_whisper_lang_code(self, language): + for voice in self.voice_data: + if voice['language'] == language: + return voice['whisper_lang_code'] + return "en" + + def __init__(self): + self.voice_data = [ + {'language': 'Arabic', + 'lang_code': 'arb', + 'whisper_lang_code': 'ar', + 'voice_id': 'Zeina', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Arabic (Gulf)', + 'lang_code': 'ar-AE', + 'whisper_lang_code': 'ar', + 'voice_id': 'Hala', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Catalan', + 'lang_code': 'ca-ES', + 'whisper_lang_code': 'ca', + 'voice_id': 'Arlet', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Chinese (Cantonese)', + 'lang_code': 'yue-CN', + 'whisper_lang_code': 'zh', + 'voice_id': 'Hiujin', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Chinese (Mandarin)', + 'lang_code': 'cmn-CN', + 'whisper_lang_code': 'zh', + 'voice_id': 'Zhiyu', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Danish', + 'lang_code': 'da-DK', + 'whisper_lang_code': 'da', + 'voice_id': 'Naja', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Danish', + 'lang_code': 'da-DK', + 'whisper_lang_code': 'da', + 'voice_id': 'Mads', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Dutch', + 'lang_code': 'nl-NL', + 'whisper_lang_code': 'nl', + 'voice_id': 'Laura', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Dutch', + 'lang_code': 'nl-NL', + 'whisper_lang_code': 'nl', + 'voice_id': 'Lotte', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Dutch', + 'lang_code': 'nl-NL', + 'whisper_lang_code': 'nl', + 'voice_id': 'Ruben', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'English (Australian)', + 'lang_code': 'en-AU', + 'whisper_lang_code': 'en', + 'voice_id': 'Nicole', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'English (Australian)', + 'lang_code': 'en-AU', + 'whisper_lang_code': 'en', + 'voice_id': 'Olivia', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'English (Australian)', + 'lang_code': 'en-AU', + 'whisper_lang_code': 'en', + 'voice_id': 'Russell', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'English (British)', + 'lang_code': 'en-GB', + 'whisper_lang_code': 'en', + 'voice_id': 'Amy', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (British)', + 'lang_code': 'en-GB', + 'whisper_lang_code': 'en', + 'voice_id': 'Emma', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (British)', + 'lang_code': 'en-GB', + 'whisper_lang_code': 'en', + 'voice_id': 'Brian', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (British)', + 'lang_code': 'en-GB', + 'whisper_lang_code': 'en', + 
'voice_id': 'Arthur', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'English (Indian)', + 'lang_code': 'en-IN', + 'whisper_lang_code': 'en', + 'voice_id': 'Aditi', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'English (Indian)', + 'lang_code': 'en-IN', + 'whisper_lang_code': 'en', + 'voice_id': 'Raveena', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'English (Indian)', + 'lang_code': 'en-IN', + 'whisper_lang_code': 'en', + 'voice_id': 'Kajal', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'English (New Zealand)', + 'lang_code': 'en-NZ', + 'whisper_lang_code': 'en', + 'voice_id': 'Aria', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'English (South African)', + 'lang_code': 'en-ZA', + 'whisper_lang_code': 'en', + 'voice_id': 'Ayanda', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Ivy', + 'gender': 'Female (child)', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Joanna', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Kendra', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Kimberly', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Salli', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Joey', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Justin', + 'gender': 'Male (child)', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Kevin', + 'gender': 'Male (child)', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'English (US)', + 'lang_code': 'en-US', + 'whisper_lang_code': 'en', + 'voice_id': 'Matthew', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'English (Welsh)', + 'lang_code': 'en-GB-WLS', + 'whisper_lang_code': 'en', + 'voice_id': 'Geraint', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Finnish', + 'lang_code': 'fi-FI', + 'whisper_lang_code': 'fi', + 'voice_id': 'Suvi', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'French', + 'lang_code': 'fr-FR', + 'whisper_lang_code': 'fr', + 'voice_id': 'Celine', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'French', + 'lang_code': 'fr-FR', + 'whisper_lang_code': 'fr', + 'voice_id': 'Lea', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'French', + 'lang_code': 'fr-FR', + 'whisper_lang_code': 'fr', + 'voice_id': 'Mathieu', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'French (Canadian)', + 'lang_code': 'fr-CA', + 'whisper_lang_code': 'fr', + 'voice_id': 'Chantal', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'French (Canadian)', + 
'lang_code': 'fr-CA', + 'whisper_lang_code': 'fr', + 'voice_id': 'Gabrielle', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'French (Canadian)', + 'lang_code': 'fr-CA', + 'whisper_lang_code': 'fr', + 'voice_id': 'Liam', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'German', + 'lang_code': 'de-DE', + 'whisper_lang_code': 'de', + 'voice_id': 'Marlene', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'German', + 'lang_code': 'de-DE', + 'whisper_lang_code': 'de', + 'voice_id': 'Vicki', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'German', + 'lang_code': 'de-DE', + 'whisper_lang_code': 'de', + 'voice_id': 'Hans', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'German', + 'lang_code': 'de-DE', + 'whisper_lang_code': 'de', + 'voice_id': 'Daniel', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'German (Austrian)', + 'lang_code': 'de-AT', + 'whisper_lang_code': 'de', + 'voice_id': 'Hannah', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Hindi', + 'lang_code': 'hi-IN', + 'whisper_lang_code': 'hi', + 'voice_id': 'Aditi', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Hindi', + 'lang_code': 'hi-IN', + 'whisper_lang_code': 'hi', + 'voice_id': 'Kajal', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Icelandic', + 'lang_code': 'is-IS', + 'whisper_lang_code': 'is', + 'voice_id': 'Dora', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Icelandic', + 'lang_code': 'is-IS', + 'whisper_lang_code': 'is', + 'voice_id': 'Karl', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Italian', + 'lang_code': 'it-IT', + 'whisper_lang_code': 'it', + 'voice_id': 'Carla', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Italian', + 'lang_code': 'it-IT', + 'whisper_lang_code': 'it', + 'voice_id': 'Bianca', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Japanese', + 'lang_code': 'ja-JP', + 'whisper_lang_code': 'ja', + 'voice_id': 'Mizuki', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Japanese', + 'lang_code': 'ja-JP', + 'whisper_lang_code': 'ja', + 'voice_id': 'Takumi', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Korean', + 'lang_code': 'ko-KR', + 'whisper_lang_code': 'ko', + 'voice_id': 'Seoyeon', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Norwegian', + 'lang_code': 'nb-NO', + 'whisper_lang_code': 'no', + 'voice_id': 'Liv', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Norwegian', + 'lang_code': 'nb-NO', + 'whisper_lang_code': 'no', + 'voice_id': 'Ida', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Polish', + 'lang_code': 'pl-PL', + 'whisper_lang_code': 'pl', + 'voice_id': 'Ewa', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Polish', + 'lang_code': 'pl-PL', + 'whisper_lang_code': 'pl', + 'voice_id': 'Maja', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Polish', + 'lang_code': 'pl-PL', + 'whisper_lang_code': 'pl', + 'voice_id': 'Jacek', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Polish', + 'lang_code': 'pl-PL', + 'whisper_lang_code': 'pl', + 'voice_id': 'Jan', + 'gender': 'Male', + 
'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Polish', + 'lang_code': 'pl-PL', + 'whisper_lang_code': 'pl', + 'voice_id': 'Ola', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Portuguese (Brazilian)', + 'lang_code': 'pt-BR', + 'whisper_lang_code': 'pt', + 'voice_id': 'Camila', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Portuguese (Brazilian)', + 'lang_code': 'pt-BR', + 'whisper_lang_code': 'pt', + 'voice_id': 'Vitoria', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Portuguese (Brazilian)', + 'lang_code': 'pt-BR', + 'whisper_lang_code': 'pt', + 'voice_id': 'Ricardo', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Portuguese (European)', + 'lang_code': 'pt-PT', + 'whisper_lang_code': 'pt', + 'voice_id': 'Ines', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Portuguese (European)', + 'lang_code': 'pt-PT', + 'whisper_lang_code': 'pt', + 'voice_id': 'Cristiano', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Romanian', + 'lang_code': 'ro-RO', + 'whisper_lang_code': 'ro', + 'voice_id': 'Carmen', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Russian', + 'lang_code': 'ru-RU', + 'whisper_lang_code': 'ru', + 'voice_id': 'Tatyana', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Russian', + 'lang_code': 'ru-RU', + 'whisper_lang_code': 'ru', + 'voice_id': 'Maxim', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Spanish (European)', + 'lang_code': 'es-ES', + 'whisper_lang_code': 'es', + 'voice_id': 'Conchita', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Spanish (European)', + 'lang_code': 'es-ES', + 'whisper_lang_code': 'es', + 'voice_id': 'Lucia', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Spanish (European)', + 'lang_code': 'es-ES', + 'whisper_lang_code': 'es', + 'voice_id': 'Enrique', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Spanish (Mexican)', + 'lang_code': 'es-MX', + 'whisper_lang_code': 'es', + 'voice_id': 'Mia', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Spanish (US)', + 'lang_code': 'es-US', + 'whisper_lang_code': 'es', + 'voice_id': 'Lupe', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'Yes'}, + {'language': 'Spanish (US)', + 'lang_code': 'es-US', + 'whisper_lang_code': 'es', + 'voice_id': 'Penelope', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Spanish (US)', + 'lang_code': 'es-US', + 'whisper_lang_code': 'es', + 'voice_id': 'Miguel', + 'gender': 'Male', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Spanish (US)', + 'lang_code': 'es-US', + 'whisper_lang_code': 'es', + 'voice_id': 'Pedro', + 'gender': 'Male', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Swedish', + 'lang_code': 'sv-SE', + 'whisper_lang_code': 'sv', + 'voice_id': 'Astrid', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Swedish', + 'lang_code': 'sv-SE', + 'whisper_lang_code': 'sv', + 'voice_id': 'Elin', + 'gender': 'Female', + 'neural': 'Yes', + 'standard': 'No'}, + {'language': 'Turkish', + 'lang_code': 'tr-TR', + 'whisper_lang_code': 'tr', + 'voice_id': 'Filiz', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'}, + {'language': 'Welsh', + 'lang_code': 'cy-GB', + 'whisper_lang_code': 'cy', + 'voice_id': 
'Gwyneth', + 'gender': 'Female', + 'neural': 'No', + 'standard': 'Yes'} + ] + + +# Run from the command-line +if __name__ == '__main__': + polly_voice_data = PollyVoiceData() + + voice_id, language_code, engine = polly_voice_data.get_voice('English (US)', 'Male') + print('English (US)', 'Male', voice_id, language_code, engine) + + voice_id, language_code, engine = polly_voice_data.get_voice('English (US)', 'Female') + print('English (US)', 'Female', voice_id, language_code, engine) + + voice_id, language_code, engine = polly_voice_data.get_voice('French', 'Female') + print('French', 'Female', voice_id, language_code, engine) + + voice_id, language_code, engine = polly_voice_data.get_voice('French', 'Male') + print('French', 'Male', voice_id, language_code, engine) + + voice_id, language_code, engine = polly_voice_data.get_voice('Japanese', 'Female') + print('Japanese', 'Female', voice_id, language_code, engine) + + voice_id, language_code, engine = polly_voice_data.get_voice('Japanese', 'Male') + print('Japanese', 'Male', voice_id, language_code, engine) + + voice_id, language_code, engine = polly_voice_data.get_voice('Hindi', 'Female') + print('Hindi', 'Female', voice_id, language_code, engine) + + voice_id, language_code, engine = polly_voice_data.get_voice('Hindi', 'Male') + print('Hindi', 'Male', voice_id, language_code, engine) + + whisper_lang_code = polly_voice_data.get_whisper_lang_code('English (US)') + print('English (US) whisper_lang_code:', whisper_lang_code) + + whisper_lang_code = polly_voice_data.get_whisper_lang_code('Chinese (Mandarin)') + print('Chinese (Mandarin) whisper_lang_code:', whisper_lang_code) + + whisper_lang_code = polly_voice_data.get_whisper_lang_code('Norwegian') + print('Norwegian whisper_lang_code:', whisper_lang_code) + + whisper_lang_code = polly_voice_data.get_whisper_lang_code('Dutch') + print('Dutch whisper_lang_code:', whisper_lang_code) + + whisper_lang_code = polly_voice_data.get_whisper_lang_code('Foo') + print('Foo whisper_lang_code:', whisper_lang_code) + + diff --git a/chat_anything/sad_talker/__init__.py b/chat_anything/sad_talker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/sad_talker/audio2exp_models/audio2exp.py b/chat_anything/sad_talker/audio2exp_models/audio2exp.py new file mode 100644 index 0000000000000000000000000000000000000000..9e79a929560592687a505e13188796e2b0ca8772 --- /dev/null +++ b/chat_anything/sad_talker/audio2exp_models/audio2exp.py @@ -0,0 +1,41 @@ +from tqdm import tqdm +import torch +from torch import nn + + +class Audio2Exp(nn.Module): + def __init__(self, netG, cfg, device, prepare_training_loss=False): + super(Audio2Exp, self).__init__() + self.cfg = cfg + self.device = device + self.netG = netG.to(device) + + def test(self, batch): + + mel_input = batch['indiv_mels'] # bs T 1 80 16 + bs = mel_input.shape[0] + T = mel_input.shape[1] + + exp_coeff_pred = [] + + for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames + + current_mel_input = mel_input[:,i:i+10] + + #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 + ref = batch['ref'][:, :, :64][:, i:i+10] + ratio = batch['ratio_gt'][:, i:i+10] #bs T + + audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 + + curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 + + exp_coeff_pred += [curr_exp_coeff_pred] + + # BS x T x 64 + results_dict = { + 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) + } + return 
results_dict + + diff --git a/chat_anything/sad_talker/audio2exp_models/networks.py b/chat_anything/sad_talker/audio2exp_models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..f052e18101f5446a527ae354b3621e7d0d4991cc --- /dev/null +++ b/chat_anything/sad_talker/audio2exp_models/networks.py @@ -0,0 +1,74 @@ +import torch +import torch.nn.functional as F +from torch import nn + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + self.use_act = use_act + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + + if self.use_act: + return self.act(out) + else: + return out + +class SimpleWrapperV2(nn.Module): + def __init__(self) -> None: + super().__init__() + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0), + ) + + #### load the pre-trained audio_encoder + #self.audio_encoder = self.audio_encoder.to(device) + ''' + wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] + state_dict = self.audio_encoder.state_dict() + + for k,v in wav2lip_state_dict.items(): + if 'audio_encoder' in k: + print('init:', k) + state_dict[k.replace('module.audio_encoder.', '')] = v + self.audio_encoder.load_state_dict(state_dict) + ''' + + self.mapping1 = nn.Linear(512+64+1, 64) + #self.mapping2 = nn.Linear(30, 64) + #nn.init.constant_(self.mapping1.weight, 0.) + nn.init.constant_(self.mapping1.bias, 0.) 
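+
+    # Shape note (inferred from the caller Audio2Exp.test, which feeds this module in 10-frame chunks):
+    # x is (bs*T, 1, 80, 16) mel-spectrogram windows, ref is (bs, T, 64) reference expression coefficients,
+    # ratio is (bs, T) per-frame control ratios, and the forward pass below maps them to (bs, T, 64)
+    # predicted expression coefficients.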
+ + def forward(self, x, ref, ratio): + x = self.audio_encoder(x).view(x.size(0), -1) + ref_reshape = ref.reshape(x.size(0), -1) + ratio = ratio.reshape(x.size(0), -1) + + y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) + out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial + return out diff --git a/chat_anything/sad_talker/audio2pose_models/audio2pose.py b/chat_anything/sad_talker/audio2pose_models/audio2pose.py new file mode 100644 index 0000000000000000000000000000000000000000..ded6992e41adce9da935e6ad8b6111b020b64177 --- /dev/null +++ b/chat_anything/sad_talker/audio2pose_models/audio2pose.py @@ -0,0 +1,94 @@ +import torch +from torch import nn +from chat_anything.sad_talker.audio2pose_models.cvae import CVAE +from chat_anything.sad_talker.audio2pose_models.discriminator import PoseSequenceDiscriminator +from chat_anything.sad_talker.audio2pose_models.audio_encoder import AudioEncoder + +class Audio2Pose(nn.Module): + def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): + super().__init__() + self.cfg = cfg + self.seq_len = cfg.MODEL.CVAE.SEQ_LEN + self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE + self.device = device + + self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device) + self.audio_encoder.eval() + for param in self.audio_encoder.parameters(): + param.requires_grad = False + + self.netG = CVAE(cfg) + self.netD_motion = PoseSequenceDiscriminator(cfg) + + + def forward(self, x): + + batch = {} + coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 + batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6 + batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6 + batch['class'] = x['class'].squeeze(0).cuda() # bs + indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 + + # forward + audio_emb_list = [] + audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 + batch['audio_emb'] = audio_emb + batch = self.netG(batch) + + pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 + pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6 + pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6 + + batch['pose_pred'] = pose_pred + batch['pose_gt'] = pose_gt + + return batch + + def test(self, x): + + batch = {} + ref = x['ref'] #bs 1 70 + batch['ref'] = x['ref'][:,0,-6:] + batch['class'] = x['class'] + bs = ref.shape[0] + + indiv_mels= x['indiv_mels'] # bs T 1 80 16 + indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame + num_frames = x['num_frames'] + num_frames = int(num_frames) - 1 + + # + div = num_frames//self.seq_len + re = num_frames%self.seq_len + audio_emb_list = [] + pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, + device=batch['ref'].device)] + + for i in range(div): + z = torch.randn(bs, self.latent_dim).to(ref.device) + batch['z'] = z + audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 + batch['audio_emb'] = audio_emb + batch = self.netG.test(batch) + pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 + + if re != 0: + z = torch.randn(bs, self.latent_dim).to(ref.device) + batch['z'] = z + audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 + if audio_emb.shape[1] != self.seq_len: + pad_dim = self.seq_len-audio_emb.shape[1] + pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) + audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) + batch['audio_emb'] = 
audio_emb + batch = self.netG.test(batch) + pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) + + pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) + batch['pose_motion_pred'] = pose_motion_pred + + pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 + + batch['pose_pred'] = pose_pred + return batch diff --git a/chat_anything/sad_talker/audio2pose_models/audio_encoder.py b/chat_anything/sad_talker/audio2pose_models/audio_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6279d2014a2e786a6c549f084339e18d00e50331 --- /dev/null +++ b/chat_anything/sad_talker/audio2pose_models/audio_encoder.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + +class AudioEncoder(nn.Module): + def __init__(self, wav2lip_checkpoint, device): + super(AudioEncoder, self).__init__() + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. 
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] + # state_dict = self.audio_encoder.state_dict() + + # for k,v in wav2lip_state_dict.items(): + # if 'audio_encoder' in k: + # state_dict[k.replace('module.audio_encoder.', '')] = v + # self.audio_encoder.load_state_dict(state_dict) + + + def forward(self, audio_sequences): + # audio_sequences = (B, T, 1, 80, 16) + B = audio_sequences.size(0) + + audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) + + audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 + dim = audio_embedding.shape[1] + audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) + + return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 diff --git a/chat_anything/sad_talker/audio2pose_models/cvae.py b/chat_anything/sad_talker/audio2pose_models/cvae.py new file mode 100644 index 0000000000000000000000000000000000000000..6db5e2e0c90bde17398627b1d0d9520d1d1a39d2 --- /dev/null +++ b/chat_anything/sad_talker/audio2pose_models/cvae.py @@ -0,0 +1,149 @@ +import torch +import torch.nn.functional as F +from torch import nn +from chat_anything.sad_talker.audio2pose_models.res_unet import ResUnet + +def class2onehot(idx, class_num): + + assert torch.max(idx).item() < class_num + onehot = torch.zeros(idx.size(0), class_num).to(idx.device) + onehot.scatter_(1, idx, 1) + return onehot + +class CVAE(nn.Module): + def __init__(self, cfg): + super().__init__() + encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES + decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES + latent_size = cfg.MODEL.CVAE.LATENT_SIZE + num_classes = cfg.DATASET.NUM_CLASSES + audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE + audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE + seq_len = cfg.MODEL.CVAE.SEQ_LEN + + self.latent_size = latent_size + + self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len) + self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len) + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, batch): + batch = self.encoder(batch) + mu = batch['mu'] + logvar = batch['logvar'] + z = self.reparameterize(mu, logvar) + batch['z'] = z + return self.decoder(batch) + + def test(self, batch): + ''' + class_id = batch['class'] + z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device) + batch['z'] = z + ''' + return self.decoder(batch) + +class ENCODER(nn.Module): + def __init__(self, layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len): + super().__init__() + + self.resunet = ResUnet() + self.num_classes = num_classes + self.seq_len = seq_len + + self.MLP = nn.Sequential() + layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6 + for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])): + self.MLP.add_module( + name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) + self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) + + self.linear_means = nn.Linear(layer_sizes[-1], latent_size) + self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size) + self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) + + self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) + + def forward(self, batch): + class_id = 
batch['class'] + pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6 + ref = batch['ref'] #bs 6 + bs = pose_motion_gt.shape[0] + audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size + + #pose encode + pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6 + pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6 + + #audio mapping + print(audio_in.shape) + audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size + audio_out = audio_out.reshape(bs, -1) + + class_bias = self.classbias[class_id] #bs latent_size + x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size + x_out = self.MLP(x_in) + + mu = self.linear_means(x_out) + logvar = self.linear_means(x_out) #bs latent_size + + batch.update({'mu':mu, 'logvar':logvar}) + return batch + +class DECODER(nn.Module): + def __init__(self, layer_sizes, latent_size, num_classes, + audio_emb_in_size, audio_emb_out_size, seq_len): + super().__init__() + + self.resunet = ResUnet() + self.num_classes = num_classes + self.seq_len = seq_len + + self.MLP = nn.Sequential() + input_size = latent_size + seq_len*audio_emb_out_size + 6 + for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)): + self.MLP.add_module( + name="L{:d}".format(i), module=nn.Linear(in_size, out_size)) + if i+1 < len(layer_sizes): + self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU()) + else: + self.MLP.add_module(name="sigmoid", module=nn.Sigmoid()) + + self.pose_linear = nn.Linear(6, 6) + self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size) + + self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size)) + + def forward(self, batch): + + z = batch['z'] #bs latent_size + bs = z.shape[0] + class_id = batch['class'] + ref = batch['ref'] #bs 6 + audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size + #print('audio_in: ', audio_in[:, :, :10]) + + audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size + #print('audio_out: ', audio_out[:, :, :10]) + audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size + class_bias = self.classbias[class_id] #bs latent_size + + z = z + class_bias + x_in = torch.cat([ref, z, audio_out], dim=-1) + x_out = self.MLP(x_in) # bs layer_sizes[-1] + x_out = x_out.reshape((bs, self.seq_len, -1)) + + #print('x_out: ', x_out) + + pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6 + + pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6 + + batch.update({'pose_motion_pred':pose_motion_pred}) + return batch diff --git a/chat_anything/sad_talker/audio2pose_models/discriminator.py b/chat_anything/sad_talker/audio2pose_models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..339c38e4812ff38a810f0f3a1c01812f6d5d78db --- /dev/null +++ b/chat_anything/sad_talker/audio2pose_models/discriminator.py @@ -0,0 +1,76 @@ +import torch +import torch.nn.functional as F +from torch import nn + +class ConvNormRelu(nn.Module): + def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, + kernel_size=None, stride=None, padding=None, norm='BN', leaky=False): + super().__init__() + if kernel_size is None: + if downsample: + kernel_size, stride, padding = 4, 2, 1 + else: + kernel_size, stride, padding = 3, 1, 1 + + if conv_type == '2d': + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=False, + ) + if norm == 'BN': + self.norm = nn.BatchNorm2d(out_channels) 
+ elif norm == 'IN': + self.norm = nn.InstanceNorm2d(out_channels) + else: + raise NotImplementedError + elif conv_type == '1d': + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias=False, + ) + if norm == 'BN': + self.norm = nn.BatchNorm1d(out_channels) + elif norm == 'IN': + self.norm = nn.InstanceNorm1d(out_channels) + else: + raise NotImplementedError + nn.init.kaiming_normal_(self.conv.weight) + + self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + if isinstance(self.norm, nn.InstanceNorm1d): + x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] + else: + x = self.norm(x) + x = self.act(x) + return x + + +class PoseSequenceDiscriminator(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU + + self.seq = nn.Sequential( + ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 + ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 + ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 + nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 + ) + + def forward(self, x): + x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) + x = self.seq(x) + x = x.squeeze(1) + return x \ No newline at end of file diff --git a/chat_anything/sad_talker/audio2pose_models/networks.py b/chat_anything/sad_talker/audio2pose_models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa0b1390e7b4bb0e16057ac94d2fe84f48421af --- /dev/null +++ b/chat_anything/sad_talker/audio2pose_models/networks.py @@ -0,0 +1,140 @@ +import torch.nn as nn +import torch + + +class ResidualConv(nn.Module): + def __init__(self, input_dim, output_dim, stride, padding): + super(ResidualConv, self).__init__() + + self.conv_block = nn.Sequential( + nn.BatchNorm2d(input_dim), + nn.ReLU(), + nn.Conv2d( + input_dim, output_dim, kernel_size=3, stride=stride, padding=padding + ), + nn.BatchNorm2d(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + self.conv_skip = nn.Sequential( + nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), + nn.BatchNorm2d(output_dim), + ) + + def forward(self, x): + + return self.conv_block(x) + self.conv_skip(x) + + +class Upsample(nn.Module): + def __init__(self, input_dim, output_dim, kernel, stride): + super(Upsample, self).__init__() + + self.upsample = nn.ConvTranspose2d( + input_dim, output_dim, kernel_size=kernel, stride=stride + ) + + def forward(self, x): + return self.upsample(x) + + +class Squeeze_Excite_Block(nn.Module): + def __init__(self, channel, reduction=16): + super(Squeeze_Excite_Block, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) + + +class ASPP(nn.Module): + def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): + super(ASPP, self).__init__() + + self.aspp_block1 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] + ), + nn.ReLU(inplace=True), + 
nn.BatchNorm2d(out_dims), + ) + self.aspp_block2 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + self.aspp_block3 = nn.Sequential( + nn.Conv2d( + in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] + ), + nn.ReLU(inplace=True), + nn.BatchNorm2d(out_dims), + ) + + self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) + self._init_weights() + + def forward(self, x): + x1 = self.aspp_block1(x) + x2 = self.aspp_block2(x) + x3 = self.aspp_block3(x) + out = torch.cat([x1, x2, x3], dim=1) + return self.output(out) + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class Upsample_(nn.Module): + def __init__(self, scale=2): + super(Upsample_, self).__init__() + + self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) + + def forward(self, x): + return self.upsample(x) + + +class AttentionBlock(nn.Module): + def __init__(self, input_encoder, input_decoder, output_dim): + super(AttentionBlock, self).__init__() + + self.conv_encoder = nn.Sequential( + nn.BatchNorm2d(input_encoder), + nn.ReLU(), + nn.Conv2d(input_encoder, output_dim, 3, padding=1), + nn.MaxPool2d(2, 2), + ) + + self.conv_decoder = nn.Sequential( + nn.BatchNorm2d(input_decoder), + nn.ReLU(), + nn.Conv2d(input_decoder, output_dim, 3, padding=1), + ) + + self.conv_attn = nn.Sequential( + nn.BatchNorm2d(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, 1, 1), + ) + + def forward(self, x1, x2): + out = self.conv_encoder(x1) + self.conv_decoder(x2) + out = self.conv_attn(out) + return out * x2 \ No newline at end of file diff --git a/chat_anything/sad_talker/audio2pose_models/res_unet.py b/chat_anything/sad_talker/audio2pose_models/res_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb176d766f903f4a7ae1e78cbafbe23244e4fd9 --- /dev/null +++ b/chat_anything/sad_talker/audio2pose_models/res_unet.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from chat_anything.sad_talker.audio2pose_models.networks import ResidualConv, Upsample + + +class ResUnet(nn.Module): + def __init__(self, channel=1, filters=[32, 64, 128, 256]): + super(ResUnet, self).__init__() + + self.input_layer = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), + nn.BatchNorm2d(filters[0]), + nn.ReLU(), + nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), + ) + self.input_skip = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) + ) + + self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) + self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) + + self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) + + self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) + self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) + + self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) + self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) + + self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) + self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) + + self.output_layer = nn.Sequential( + nn.Conv2d(filters[0], 1, 1, 1), 
+ nn.Sigmoid(), + ) + + def forward(self, x): + # Encode + x1 = self.input_layer(x) + self.input_skip(x) + x2 = self.residual_conv_1(x1) + x3 = self.residual_conv_2(x2) + # Bridge + x4 = self.bridge(x3) + + # Decode + x4 = self.upsample_1(x4) + x5 = torch.cat([x4, x3], dim=1) + + x6 = self.up_residual_conv1(x5) + + x6 = self.upsample_2(x6) + x7 = torch.cat([x6, x2], dim=1) + + x8 = self.up_residual_conv2(x7) + + x8 = self.upsample_3(x8) + x9 = torch.cat([x8, x1], dim=1) + + x10 = self.up_residual_conv3(x9) + + output = self.output_layer(x10) + + return output \ No newline at end of file diff --git a/chat_anything/sad_talker/config/auido2exp.yaml b/chat_anything/sad_talker/config/auido2exp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7369dbf350476e14a1d600507f1f8b7d8aa6ecd3 --- /dev/null +++ b/chat_anything/sad_talker/config/auido2exp.yaml @@ -0,0 +1,58 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt + EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt + TRAIN_BATCH_SIZE: 32 + EVAL_BATCH_SIZE: 32 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm + LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + NUM_REPEATS: 2 + T: 40 + + +MODEL: + FRAMEWORK: V2 + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 128 + SEQ_LEN: 32 + LATENT_SIZE: 256 + ENCODER_LAYER_SIZES: [192, 1024] + DECODER_LAYER_SIZES: [1024, 192] + + +TRAIN: + MAX_EPOCH: 300 + GENERATOR: + LR: 2.0e-5 + DISCRIMINATOR: + LR: 1.0e-5 + LOSS: + W_FEAT: 0 + W_COEFF_EXP: 2 + W_LM: 1.0e-2 + W_LM_MOUTH: 0 + W_REG: 0 + W_SYNC: 0 + W_COLOR: 0 + W_EXPRESSION: 0 + W_LIPREADING: 0.01 + W_LIPREADING_VV: 0 + W_EYE_BLINK: 4 + +TAG: + NAME: small_dataset + + diff --git a/chat_anything/sad_talker/config/auido2pose.yaml b/chat_anything/sad_talker/config/auido2pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc61f94d12f406f2d8d02545e55b61075051484d --- /dev/null +++ b/chat_anything/sad_talker/config/auido2pose.yaml @@ -0,0 +1,49 @@ +DATASET: + TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt + EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt + TRAIN_BATCH_SIZE: 64 + EVAL_BATCH_SIZE: 1 + EXP: True + EXP_DIM: 64 + FRAME_LEN: 32 + COEFF_LEN: 73 + NUM_CLASSES: 46 + AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav + COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb + DEBUG: True + + +MODEL: + AUDIOENCODER: + LEAKY_RELU: True + NORM: 'IN' + DISCRIMINATOR: + LEAKY_RELU: False + INPUT_CHANNELS: 6 + CVAE: + AUDIO_EMB_IN_SIZE: 512 + AUDIO_EMB_OUT_SIZE: 6 + SEQ_LEN: 32 + LATENT_SIZE: 64 + ENCODER_LAYER_SIZES: [192, 128] + DECODER_LAYER_SIZES: [128, 192] + + +TRAIN: + MAX_EPOCH: 150 + GENERATOR: + LR: 1.0e-4 + DISCRIMINATOR: + LR: 1.0e-4 + LOSS: + LAMBDA_REG: 1 + LAMBDA_LANDMARKS: 0 + LAMBDA_VERTICES: 0 + LAMBDA_GAN_MOTION: 0.7 + LAMBDA_GAN_COEFF: 0 + LAMBDA_KL: 1 + +TAG: + NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder + + diff --git a/chat_anything/sad_talker/config/facerender.yaml b/chat_anything/sad_talker/config/facerender.yaml new 
file mode 100644 index 0000000000000000000000000000000000000000..9494ef82dfa16b16b7aa0b848ebdd6b23e739e2a --- /dev/null +++ b/chat_anything/sad_talker/config/facerender.yaml @@ -0,0 +1,45 @@ +model_params: + common_params: + num_kp: 15 + image_channel: 3 + feature_channel: 32 + estimate_jacobian: False # True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 # 0.25 + num_blocks: 5 + reshape_channel: 16384 # 16384 = 1024 * 16 + reshape_depth: 16 + he_estimator_params: + block_expansion: 64 + max_features: 2048 + num_bins: 66 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + reshape_depth: 16 # 512 = 32 * 16 + num_resblocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + mapping_params: + coeff_nc: 70 + descriptor_nc: 1024 + layer: 3 + num_kp: 15 + num_bins: 66 + diff --git a/chat_anything/sad_talker/config/facerender_still.yaml b/chat_anything/sad_talker/config/facerender_still.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6b4d66dade3e655ac4cfc25a994ca28e53821d80 --- /dev/null +++ b/chat_anything/sad_talker/config/facerender_still.yaml @@ -0,0 +1,45 @@ +model_params: + common_params: + num_kp: 15 + image_channel: 3 + feature_channel: 32 + estimate_jacobian: False # True + kp_detector_params: + temperature: 0.1 + block_expansion: 32 + max_features: 1024 + scale_factor: 0.25 # 0.25 + num_blocks: 5 + reshape_channel: 16384 # 16384 = 1024 * 16 + reshape_depth: 16 + he_estimator_params: + block_expansion: 64 + max_features: 2048 + num_bins: 66 + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 2 + reshape_channel: 32 + reshape_depth: 16 # 512 = 32 * 16 + num_resblocks: 6 + estimate_occlusion_map: True + dense_motion_params: + block_expansion: 32 + max_features: 1024 + num_blocks: 5 + reshape_depth: 16 + compress: 4 + discriminator_params: + scales: [1] + block_expansion: 32 + max_features: 512 + num_blocks: 4 + sn: True + mapping_params: + coeff_nc: 73 + descriptor_nc: 1024 + layer: 3 + num_kp: 15 + num_bins: 66 + diff --git a/chat_anything/sad_talker/config/similarity_Lm3D_all.mat b/chat_anything/sad_talker/config/similarity_Lm3D_all.mat new file mode 100644 index 0000000000000000000000000000000000000000..a0e23588302bc71fc899eef53ff06df5f4df4c1d Binary files /dev/null and b/chat_anything/sad_talker/config/similarity_Lm3D_all.mat differ diff --git a/chat_anything/sad_talker/face3d/data/__init__.py b/chat_anything/sad_talker/face3d/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9761c518a1b07c5996165869742af0a52c82bc --- /dev/null +++ b/chat_anything/sad_talker/face3d/data/__init__.py @@ -0,0 +1,116 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. 
+ +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import numpy as np +import importlib +import torch.utils.data +from face3d.data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt, rank=0): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from data import create_dataset + >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt, rank=rank) + dataset = data_loader.load_data() + return dataset + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt, rank=0): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. 
+ """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + self.sampler = None + print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) + if opt.use_ddp and opt.isTrain: + world_size = opt.world_size + self.sampler = torch.utils.data.distributed.DistributedSampler( + self.dataset, + num_replicas=world_size, + rank=rank, + shuffle=not opt.serial_batches + ) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + sampler=self.sampler, + num_workers=int(opt.num_threads / world_size), + batch_size=int(opt.batch_size / world_size), + drop_last=True) + else: + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=(not opt.serial_batches) and opt.isTrain, + num_workers=int(opt.num_threads), + drop_last=True + ) + + def set_epoch(self, epoch): + self.dataset.current_epoch = epoch + if self.sampler is not None: + self.sampler.set_epoch(epoch) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/chat_anything/sad_talker/face3d/data/base_dataset.py b/chat_anything/sad_talker/face3d/data/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd57d082d519f512d7114b4f867b6695fb7de06 --- /dev/null +++ b/chat_anything/sad_talker/face3d/data/base_dataset.py @@ -0,0 +1,125 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + # self.root = opt.dataroot + self.current_epoch = 0 + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. 
It ususally contains the data itself and its metadata information. + """ + pass + + +def get_transform(grayscale=False): + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + transform_list += [transforms.ToTensor()] + return transforms.Compose(transform_list) + +def get_affine_mat(opt, size): + shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False + w, h = size + + if 'shift' in opt.preprocess: + shift_pixs = int(opt.shift_pixs) + shift_x = random.randint(-shift_pixs, shift_pixs) + shift_y = random.randint(-shift_pixs, shift_pixs) + if 'scale' in opt.preprocess: + scale = 1 + opt.scale_delta * (2 * random.random() - 1) + if 'rot' in opt.preprocess: + rot_angle = opt.rot_angle * (2 * random.random() - 1) + rot_rad = -rot_angle * np.pi/180 + if 'flip' in opt.preprocess: + flip = random.random() > 0.5 + + shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) + flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) + shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) + rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) + scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) + shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) + + affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin + affine_inv = np.linalg.inv(affine) + return affine, affine_inv, flip + +def apply_img_affine(img, affine_inv, method=Image.BICUBIC): + return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) + +def apply_lm_affine(landmark, affine, flip, size): + _, h = size + lm = landmark.copy() + lm[:, 1] = h - 1 - lm[:, 1] + lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) + lm = lm @ np.transpose(affine) + lm[:, :2] = lm[:, :2] / lm[:, 2:] + lm = lm[:, :2] + lm[:, 1] = h - 1 - lm[:, 1] + if flip: + lm_ = lm.copy() + lm_[:17] = lm[16::-1] + lm_[17:22] = lm[26:21:-1] + lm_[22:27] = lm[21:16:-1] + lm_[31:36] = lm[35:30:-1] + lm_[36:40] = lm[45:41:-1] + lm_[40:42] = lm[47:45:-1] + lm_[42:46] = lm[39:35:-1] + lm_[46:48] = lm[41:39:-1] + lm_[48:55] = lm[54:47:-1] + lm_[55:60] = lm[59:54:-1] + lm_[60:65] = lm[64:59:-1] + lm_[65:68] = lm[67:64:-1] + lm = lm_ + return lm diff --git a/chat_anything/sad_talker/face3d/data/flist_dataset.py b/chat_anything/sad_talker/face3d/data/flist_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b6945c80aa756074a5d3c02b9443b15ddcfc57 --- /dev/null +++ b/chat_anything/sad_talker/face3d/data/flist_dataset.py @@ -0,0 +1,125 @@ +"""This script defines the custom dataset for Deep3DFaceRecon_pytorch +""" + +import os.path +from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine +from data.image_folder import make_dataset +from PIL import Image +import random +import util.util as util +import numpy as np +import json +import torch +from scipy.io import loadmat, savemat +import pickle +from util.preprocess import align_img, estimate_norm +from util.load_mats import load_lm3d + + +def default_flist_reader(flist): + """ + flist format: impath label\nimpath label\n ...(same to caffe's filelist) + """ + imlist = [] + with open(flist, 'r') as rf: + for line in rf.readlines(): + impath = line.strip() + imlist.append(impath) + + return imlist + +def jason_flist_reader(flist): + with open(flist, 'r') as fp: + 
info = json.load(fp) + return info + +def parse_label(label): + return torch.tensor(np.array(label).astype(np.float32)) + + +class FlistDataset(BaseDataset): + """ + It requires one directories to host training images '/path/to/data/train' + You can train the model with the dataset flag '--dataroot /path/to/data'. + """ + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + + self.lm3d_std = load_lm3d(opt.bfm_folder) + + msk_names = default_flist_reader(opt.flist) + self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] + + self.size = len(self.msk_paths) + self.opt = opt + + self.name = 'train' if opt.isTrain else 'val' + if '_' in opt.flist: + self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] + + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index (int) -- a random integer for data indexing + + Returns a dictionary that contains A, B, A_paths and B_paths + img (tensor) -- an image in the input domain + msk (tensor) -- its corresponding attention mask + lm (tensor) -- its corresponding 3d landmarks + im_paths (str) -- image paths + aug_flag (bool) -- a flag used to tell whether its raw or augmented + """ + msk_path = self.msk_paths[index % self.size] # make sure index is within then range + img_path = msk_path.replace('mask/', '') + lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' + + raw_img = Image.open(img_path).convert('RGB') + raw_msk = Image.open(msk_path).convert('RGB') + raw_lm = np.loadtxt(lm_path).astype(np.float32) + + _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) + + aug_flag = self.opt.use_aug and self.opt.isTrain + if aug_flag: + img, lm, msk = self._augmentation(img, lm, self.opt, msk) + + _, H = img.size + M = estimate_norm(lm, H) + transform = get_transform() + img_tensor = transform(img) + msk_tensor = transform(msk)[:1, ...] + lm_tensor = parse_label(lm) + M_tensor = parse_label(M) + + + return {'imgs': img_tensor, + 'lms': lm_tensor, + 'msks': msk_tensor, + 'M': M_tensor, + 'im_paths': img_path, + 'aug_flag': aug_flag, + 'dataset': self.name} + + def _augmentation(self, img, lm, opt, msk=None): + affine, affine_inv, flip = get_affine_mat(opt, img.size) + img = apply_img_affine(img, affine_inv) + lm = apply_lm_affine(lm, affine, flip, img.size) + if msk is not None: + msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) + return img, lm, msk + + + + + def __len__(self): + """Return the total number of images in the dataset. + """ + return self.size diff --git a/chat_anything/sad_talker/face3d/data/image_folder.py b/chat_anything/sad_talker/face3d/data/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..efadc2ecbe2fb4b53b78230aba25ec505eff0e55 --- /dev/null +++ b/chat_anything/sad_talker/face3d/data/image_folder.py @@ -0,0 +1,66 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. 
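+
+Illustrative usage sketch (not part of the original docstring; './images' is a placeholder directory and
+`make_dataset` / `ImageFolder` are defined below):
+
+    paths = make_dataset('./images')                      # recursive list of image file paths
+    dataset = ImageFolder('./images', return_paths=True)  # yields (PIL.Image.Image, path) pairs
+    img, path = dataset[0]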
+""" +import numpy as np +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/chat_anything/sad_talker/face3d/data/template_dataset.py b/chat_anything/sad_talker/face3d/data/template_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bfdf16be2a8a834b204c45d88c86857b37b9bd25 --- /dev/null +++ b/chat_anything/sad_talker/face3d/data/template_dataset.py @@ -0,0 +1,75 @@ +"""Dataset class template + +This module provides a template for users to implement custom datasets. +You can specify '--dataset_mode template' to use this dataset. +The class name should be consistent with both the filename and its dataset_mode option. +The filename should be _dataset.py +The class name should be Dataset.py +You need to implement the following functions: + -- : Add dataset-specific options and rewrite default values for existing options. + -- <__init__>: Initialize this dataset class. + -- <__getitem__>: Return a data point and its metadata information. + -- <__len__>: Return the number of images. +""" +from data.base_dataset import BaseDataset, get_transform +# from data.image_folder import make_dataset +# from PIL import Image + + +class TemplateDataset(BaseDataset): + """A template dataset class for you to implement custom datasets.""" + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') + parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values + return parser + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + + A few things can be done here. 
+ - save the options (have been done in BaseDataset) + - get image paths and meta information of the dataset. + - define the image transformation. + """ + # save the option and dataset root + BaseDataset.__init__(self, opt) + # get the image paths of your dataset; + self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root + # define the default transform function. You can use ; You can also define your custom transform function + self.transform = get_transform(opt) + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index -- a random integer for data indexing + + Returns: + a dictionary of data with their names. It usually contains the data itself and its metadata information. + + Step 1: get a random image path: e.g., path = self.image_paths[index] + Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). + Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) + Step 4: return a data point as a dictionary. + """ + path = 'temp' # needs to be a string + data_A = None # needs to be a tensor + data_B = None # needs to be a tensor + return {'data_A': data_A, 'data_B': data_B, 'path': path} + + def __len__(self): + """Return the total number of images.""" + return len(self.image_paths) diff --git a/chat_anything/sad_talker/face3d/extract_kp_videos.py b/chat_anything/sad_talker/face3d/extract_kp_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..21616a3b4b5077ffdce99621395237b4edcff58c --- /dev/null +++ b/chat_anything/sad_talker/face3d/extract_kp_videos.py @@ -0,0 +1,108 @@ +import os +import cv2 +import time +import glob +import argparse +import face_alignment +import numpy as np +from PIL import Image +from tqdm import tqdm +from itertools import cycle + +from torch.multiprocessing import Pool, Process, set_start_method + +class KeypointExtractor(): + def __init__(self, device): + self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, + device=device) + + def extract_keypoint(self, images, name=None, info=True): + if isinstance(images, list): + keypoints = [] + if info: + i_range = tqdm(images,desc='landmark Det:') + else: + i_range = images + + for image in i_range: + current_kp = self.extract_keypoint(image) + if np.mean(current_kp) == -1 and keypoints: + keypoints.append(keypoints[-1]) + else: + keypoints.append(current_kp[None]) + + keypoints = np.concatenate(keypoints, 0) + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + else: + while True: + try: + keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] + break + except RuntimeError as e: + if str(e).startswith('CUDA'): + print("Warning: out of memory, sleep for 1s") + time.sleep(1) + else: + print(e) + break + except TypeError: + print('No face detected in this image') + shape = [68, 2] + keypoints = -1. 
* np.ones(shape) + break + if name is not None: + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + +def read_video(filename): + frames = [] + cap = cv2.VideoCapture(filename) + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + else: + break + cap.release() + return frames + +def run(data): + filename, opt, device = data + os.environ['CUDA_VISIBLE_DEVICES'] = device + kp_extractor = KeypointExtractor() + images = read_video(filename) + name = filename.split('/')[-2:] + os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) + kp_extractor.extract_keypoint( + images, + name=os.path.join(opt.output_dir, name[-2], name[-1]) + ) + +if __name__ == '__main__': + set_start_method('spawn') + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--output_dir', type=str, help='the folder of the output files') + parser.add_argument('--device_ids', type=str, default='0,1') + parser.add_argument('--workers', type=int, default=4) + + opt = parser.parse_args() + filenames = list() + VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} + VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) + extensions = VIDEO_EXTENSIONS + + for ext in extensions: + os.listdir(f'{opt.input_dir}') + print(f'{opt.input_dir}/*.{ext}') + filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) + print('Total number of videos:', len(filenames)) + pool = Pool(opt.workers) + args_list = cycle([opt]) + device_ids = opt.device_ids.split(",") + device_ids = cycle(device_ids) + for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): + None diff --git a/chat_anything/sad_talker/face3d/extract_kp_videos_safe.py b/chat_anything/sad_talker/face3d/extract_kp_videos_safe.py new file mode 100644 index 0000000000000000000000000000000000000000..895ef79cba7fbb491fc1131e4d46966d870fd5b5 --- /dev/null +++ b/chat_anything/sad_talker/face3d/extract_kp_videos_safe.py @@ -0,0 +1,162 @@ +import os +import cv2 +import time +import glob +import argparse +import numpy as np +from PIL import Image +import torch +from tqdm import tqdm +from itertools import cycle +from torch.multiprocessing import Pool, Process, set_start_method + +from facexlib.alignment import landmark_98_to_68 +from facexlib.detection import init_detection_model + +from facexlib.utils import load_file_from_url +from chat_anything.sad_talker.face3d.util.my_awing_arch import FAN + +def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None): + if model_name == 'awing_fan': + model = FAN(num_modules=4, num_landmarks=98, device=device) + model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url( + url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) + model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True) + model.eval() + model = model.to(device) + return model + + +class KeypointExtractor(): + def __init__(self, device='cuda'): + + ### gfpgan/weights + try: + import webui # in webui + root_path = 'extensions/SadTalker/gfpgan/weights' + + except: + # root_path = 
'gfpgan/weights' + root_path = 'MODELS/gfpgan/weights' + + self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path) + self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path) + + def extract_keypoint(self, images, name=None, info=True): + if isinstance(images, list): + keypoints = [] + if info: + i_range = tqdm(images,desc='landmark Det:') + else: + i_range = images + + for image in i_range: + print("detect landmarks") + current_kp = self.extract_keypoint(image) + # current_kp = self.detector.get_landmarks(np.array(image)) + if np.mean(current_kp) == -1 and keypoints: + keypoints.append(keypoints[-1]) + else: + keypoints.append(current_kp[None]) + + keypoints = np.concatenate(keypoints, 0) + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + else: + print("here") + while True: + try: + with torch.no_grad(): + # face detection -> face alignment. + img = np.array(images) + bboxes = self.det_net.detect_faces(images, 0.97) + + bboxes = bboxes[0] + img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :] + + landmarks=self.detector.get_landmarks(img) + print(landmarks.shape) + start_time=time.time() + keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0] + end_time=time.time() + print(type(keypoints)) + print(keypoints.shape) + + elapsed_time = end_time - start_time # 计算时间差 + print("landmark检测时间:%.4f秒" % elapsed_time) + #### keypoints to the original location + keypoints[:,0] += int(bboxes[0]) + keypoints[:,1] += int(bboxes[1]) + + break + except RuntimeError as e: + if str(e).startswith('CUDA'): + print("Warning: out of memory, sleep for 1s") + time.sleep(1) + else: + print(e) + break + except TypeError: + print('No face detected in this image') + shape = [68, 2] + keypoints = -1. 
* np.ones(shape) + break + if name is not None: + np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) + return keypoints + +def read_video(filename): + frames = [] + cap = cv2.VideoCapture(filename) + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + else: + break + cap.release() + return frames + +def run(data): + filename, opt, device = data + os.environ['CUDA_VISIBLE_DEVICES'] = device + kp_extractor = KeypointExtractor() + images = read_video(filename) + name = filename.split('/')[-2:] + os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) + kp_extractor.extract_keypoint( + images, + name=os.path.join(opt.output_dir, name[-2], name[-1]) + ) + +if __name__ == '__main__': + set_start_method('spawn') + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--output_dir', type=str, help='the folder of the output files') + parser.add_argument('--device_ids', type=str, default='0,1') + parser.add_argument('--workers', type=int, default=4) + + opt = parser.parse_args() + filenames = list() + VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} + VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) + extensions = VIDEO_EXTENSIONS + + for ext in extensions: + os.listdir(f'{opt.input_dir}') + print(f'{opt.input_dir}/*.{ext}') + filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) + print('Total number of videos:', len(filenames)) + pool = Pool(opt.workers) + args_list = cycle([opt]) + device_ids = opt.device_ids.split(",") + device_ids = cycle(device_ids) + for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): + None diff --git a/chat_anything/sad_talker/face3d/models/__init__.py b/chat_anything/sad_talker/face3d/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7a1d26e18c1e6f3cd70fcb818c02ee6150c1de --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/__init__.py @@ -0,0 +1,67 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. 
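+
+A minimal sketch of such a file, saved as dummy_model.py in this package so that
+find_model_using_name('dummy') can locate it. The method names set_input, forward and
+optimize_parameters follow the pix2pix/CycleGAN template this package derives from;
+they are filled in here as an assumption because the bullet list above lost them:
+
+    from chat_anything.sad_talker.face3d.models.base_model import BaseModel
+
+    class DummyModel(BaseModel):
+        @staticmethod
+        def modify_commandline_options(parser, is_train=True):
+            return parser
+
+        def __init__(self, opt):
+            BaseModel.__init__(self, opt)
+            self.loss_names = []    # losses to plot and save
+            self.model_names = []   # networks used in training
+            self.visual_names = []  # images to display and save
+            self.optimizers = []    # one optimizer per network
+
+        def set_input(self, input):     # unpack data from the dataset
+            pass
+
+        def forward(self):              # produce intermediate results
+            pass
+
+        def optimize_parameters(self):  # compute losses, gradients, update weights
+            pass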
+""" + +import importlib +from chat_anything.sad_talker.face3d.models.base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "face3d.models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/README.md b/chat_anything/sad_talker/face3d/models/arcface_torch/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2ee63a861229b68873561fa39bfa7c9a8b53b947 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/README.md @@ -0,0 +1,164 @@ +# Distributed Arcface Training in Pytorch + +This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions +identity on a single server. + +## Requirements + +- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md). +- `pip install -r requirements.txt`. +- Download the dataset + from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) + . + +## How to Training + +To train a model, run `train.py` with the path to the configs: + +### 1. Single node, 8 GPUs: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +``` + +### 2. Multiple nodes, each node 8 GPUs: + +Node 0: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +``` + +Node 1: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50 +``` + +### 3.Training resnet2060 with 8 GPUs: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py +``` + +## Model Zoo + +- The models are available for non-commercial research purposes only. +- All models can be found in here. 
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw +- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) + +### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/) + +ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face +recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities. +As the result, we can evaluate the FAIR performance for different algorithms. + +For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The +globalised multi-racial testset contains 242,143 identities and 1,624,305 images. + +For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4). +Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. +There are totally 13,928 positive pairs and 96,983,824 negative pairs. + +| Datasets | backbone | Training throughout | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** | +| :---: | :--- | :--- | :--- |:--- |:--- | +| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** | +| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** | +| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** | +| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** | +| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** | +| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** | +| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** | +| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** | +| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** | +| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** | + +### Performance on IJB-C and Verification Datasets + +| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log | +| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- | +| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)| +| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)| +| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)| +| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)| +| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)| +| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| +| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| +| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| +| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 
|[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)| + +[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.) + + +## [Speed Benchmark](docs/speed_benchmark.md) + +**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of +classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same +accuracy with several times faster training performance and smaller GPU memory. +Partial FC is a sparse variant of the model parallel architecture for large sacle face recognition. Partial FC use a +sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a +sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC, +we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed +training and mixed precision training. + +![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png) + +More details see +[speed_benchmark.md](docs/speed_benchmark.md) in docs. + +### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better) + +`-` means training failed because of gpu memory limitations. + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 4681 | 4824 | 5004 | +|1400000 | **1672** | 3043 | 4738 | +|5500000 | **-** | **1389** | 3975 | +|8000000 | **-** | **-** | 3565 | +|16000000 | **-** | **-** | 2679 | +|29000000 | **-** | **-** | **1855** | + +### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|1400000 | 32252 | 11178 | 6056 | +|5500000 | **-** | 32188 | 9854 | +|8000000 | **-** | **-** | 12310 | +|16000000 | **-** | **-** | 19950 | +|29000000 | **-** | **-** | 32324 | + +## Evaluation ICCV2021-MFR and IJB-C + +More details see [eval.md](docs/eval.md) in docs. + +## Test + +We tested many versions of PyTorch. Please create an issue if you are having trouble. 
+ +- [x] torch 1.6.0 +- [x] torch 1.7.1 +- [x] torch 1.8.0 +- [x] torch 1.9.0 + +## Citation + +``` +@inproceedings{deng2019arcface, + title={Arcface: Additive angular margin loss for deep face recognition}, + author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={4690--4699}, + year={2019} +} +@inproceedings{an2020partical_fc, + title={Partial FC: Training 10 Million Identities on a Single Machine}, + author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and + Zhang, Debing and Fu Ying}, + booktitle={Arxiv 2010.05222}, + year={2020} +} +``` diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/__init__.py b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..55bd4c5d1889a1a998b52eb56793bbc1eef1b691 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/__init__.py @@ -0,0 +1,25 @@ +from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 +from .mobilefacenet import get_mbf + + +def get_model(name, **kwargs): + # resnet + if name == "r18": + return iresnet18(False, **kwargs) + elif name == "r34": + return iresnet34(False, **kwargs) + elif name == "r50": + return iresnet50(False, **kwargs) + elif name == "r100": + return iresnet100(False, **kwargs) + elif name == "r200": + return iresnet200(False, **kwargs) + elif name == "r2060": + from .iresnet2060 import iresnet2060 + return iresnet2060(False, **kwargs) + elif name == "mbf": + fp16 = kwargs.get("fp16", False) + num_features = kwargs.get("num_features", 512) + return get_mbf(fp16=fp16, num_features=num_features) + else: + raise ValueError() \ No newline at end of file diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet.py b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d3b9c240c24687d432197f976ee01fbf423216 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet.py @@ -0,0 +1,187 @@ +import torch +from torch import nn + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward(self, x): 
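+        # Added note: this is the pre-activation residual unit used by ArcFace's
+        # IResNet: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv (stride) -> BN, with
+        # the (optionally downsampled) identity added at the end. Unlike the
+        # classic ResNet block, no activation follows the addition.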
+ identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return 
_iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) + diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet2060.py b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet2060.py new file mode 100644 index 0000000000000000000000000000000000000000..21d1122144d207637d2444cba1f68fe630c89f31 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/iresnet2060.py @@ -0,0 +1,176 @@ +import torch +from torch import nn + +assert torch.__version__ >= "1.8.1" +from torch.utils.checkpoint import checkpoint_sequential + +__all__ = ['iresnet2060'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = 
self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def checkpoint(self, func, num_seg, x): + if self.training: + return checkpoint_sequential(func, num_seg, x) + else: + return func(x) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.checkpoint(self.layer2, 20, x) + x = self.checkpoint(self.layer3, 100, x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet2060(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/mobilefacenet.py b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/mobilefacenet.py new file mode 100644 index 0000000000000000000000000000000000000000..87731491d76f9ff61cc70e57bb3f18c54fae308c --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/backbones/mobilefacenet.py @@ -0,0 +1,130 @@ +''' +Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py +Original author cavalleria +''' + +import torch.nn as nn +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module +import torch + + +class Flatten(Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ConvBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), 
padding=(0, 0), groups=1): + super(ConvBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), + BatchNorm2d(num_features=out_c), + PReLU(num_parameters=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class LinearBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(LinearBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), + BatchNorm2d(num_features=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class DepthWise(Module): + def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): + super(DepthWise, self).__init__() + self.residual = residual + self.layers = nn.Sequential( + ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), + ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), + LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + ) + + def forward(self, x): + short_cut = None + if self.residual: + short_cut = x + x = self.layers(x) + if self.residual: + output = short_cut + x + else: + output = x + return output + + +class Residual(Module): + def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): + super(Residual, self).__init__() + modules = [] + for _ in range(num_block): + modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) + self.layers = Sequential(*modules) + + def forward(self, x): + return self.layers(x) + + +class GDC(Module): + def __init__(self, embedding_size): + super(GDC, self).__init__() + self.layers = nn.Sequential( + LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), + Flatten(), + Linear(512, embedding_size, bias=False), + BatchNorm1d(embedding_size)) + + def forward(self, x): + return self.layers(x) + + +class MobileFaceNet(Module): + def __init__(self, fp16=False, num_features=512): + super(MobileFaceNet, self).__init__() + scale = 2 + self.fp16 = fp16 + self.layers = nn.Sequential( + ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), + ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), + DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), + Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), + Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), + Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + ) + self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) + self.features = GDC(num_features) + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + 
m.bias.data.zero_() + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.layers(x) + x = self.conv_sep(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def get_mbf(fp16, num_features): + return MobileFaceNet(fp16, num_features) \ No newline at end of file diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions.py new file mode 100644 index 0000000000000000000000000000000000000000..c9edc2f1414e35f93abfd3dfe11a61f1f406580e --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions_pfc.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions_pfc.py new file mode 100644 index 0000000000000000000000000000000000000000..77caafdbb300d8109d5bfdb844f131710ef81f20 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/3millions_pfc.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 300 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/__init__.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/base.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/base.py new file mode 100644 index 0000000000000000000000000000000000000000..78e4b36a9142b649ec39a8c59331bb2557f2ad57 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/base.py @@ -0,0 +1,56 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = "ms1mv3_arcface_r50" + +config.dataset = "ms1m-retinaface-t1" +config.embedding_size = 512 +config.sample_rate = 1 +config.fp16 = False +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +if config.dataset == "emore": + config.rec = "/train_tmp/faces_emore" + config.num_classes = 85742 + config.num_image = 5822653 + config.num_epoch = 16 + config.warmup_epoch = -1 + config.decay_epoch = [8, 14, ] + config.val_targets = ["lfw", ] + +elif 
config.dataset == "ms1m-retinaface-t1": + config.rec = "/train_tmp/ms1m-retinaface-t1" + config.num_classes = 93431 + config.num_image = 5179510 + config.num_epoch = 25 + config.warmup_epoch = -1 + config.decay_epoch = [11, 17, 22] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "glint360k": + config.rec = "/train_tmp/glint360k" + config.num_classes = 360232 + config.num_image = 17091657 + config.num_epoch = 20 + config.warmup_epoch = -1 + config.decay_epoch = [8, 12, 15, 18] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] + +elif config.dataset == "webface": + config.rec = "/train_tmp/faces_webface_112x112" + config.num_classes = 10572 + config.num_image = "forget" + config.num_epoch = 34 + config.warmup_epoch = -1 + config.decay_epoch = [20, 28, 32] + config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_mbf.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..46ae777cc97af41a531cba4e5d1ff31f2efcb468 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r100.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r100.py new file mode 100644 index 0000000000000000000000000000000000000000..93d0701c0094517cec147c382b005e8063938548 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r100.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r18.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8db34cd547e8e667103c93585296e47a894e97 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = 
"cosface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r34.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..fda2701758a839a7161d09c25f0ca3d26033baff --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r50.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..37e7922f1f63284e356dcc45a5f979f9c105f25e --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/glint360k_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "cosface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = -1 +config.decay_epoch = [8, 12, 15, 18] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_mbf.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a00d6305eeda5a94788017afc1cda0d4a4cd2a --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_mbf.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 2e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 
5179510 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 20, 25] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r18.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4e0d31f1aedf4590628d394e1606920fefb5c9 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r18.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r2060.py new file mode 100644 index 0000000000000000000000000000000000000000..23ad81e082c4b6390b67b164d0ceb84bb0635684 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r2060.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r2060" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 64 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r34.py new file mode 100644 index 0000000000000000000000000000000000000000..5f78337a3d1f9eb6e9145eb5093618796c6842d2 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r34.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r34" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r50.py new file mode 100644 index 
0000000000000000000000000000000000000000..08ba55dbbea6df0afffddbb3d1ed173efad99604 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/ms1mv3_r50.py @@ -0,0 +1,26 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 25 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/configs/speed.py b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/speed.py new file mode 100644 index 0000000000000000000000000000000000000000..45e95237da65e44f35a172c25ac6dc4e313e4eae --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/configs/speed.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.loss = "arcface" +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 100 * 10000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.decay_epoch = [10, 16, 22] +config.val_targets = [] diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/dataset.py b/chat_anything/sad_talker/face3d/models/arcface_torch/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..96bbb8bb6da99122f350bc8e1a6390245840e32b --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/dataset.py @@ -0,0 +1,124 @@ +import numbers +import os +import queue as Queue +import threading + +import mxnet as mx +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class BackgroundGenerator(threading.Thread): + def __init__(self, generator, local_rank, max_prefetch=6): + super(BackgroundGenerator, self).__init__() + self.queue = Queue.Queue(max_prefetch) + self.generator = generator + self.local_rank = local_rank + self.daemon = True + self.start() + + def run(self): + torch.cuda.set_device(self.local_rank) + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def next(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +class DataLoaderX(DataLoader): + + def __init__(self, local_rank, **kwargs): + super(DataLoaderX, self).__init__(**kwargs) + self.stream = torch.cuda.Stream(local_rank) + self.local_rank = local_rank + + def __iter__(self): + self.iter = super(DataLoaderX, self).__iter__() + self.iter = BackgroundGenerator(self.iter, self.local_rank) + self.preload() + return self + + def preload(self): + self.batch = next(self.iter, None) + if self.batch is None: + return None + with torch.cuda.stream(self.stream): + for k in range(len(self.batch)): + self.batch[k] = 
self.batch[k].to(device=self.local_rank, non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is None: + raise StopIteration + self.preload() + return batch + + +class MXFaceDataset(Dataset): + def __init__(self, root_dir, local_rank): + super(MXFaceDataset, self).__init__() + self.transform = transforms.Compose( + [transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + self.root_dir = root_dir + self.local_rank = local_rank + path_imgrec = os.path.join(root_dir, 'train.rec') + path_imgidx = os.path.join(root_dir, 'train.idx') + self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + s = self.imgrec.read_idx(0) + header, _ = mx.recordio.unpack(s) + if header.flag > 0: + self.header0 = (int(header.label[0]), int(header.label[1])) + self.imgidx = np.array(range(1, int(header.label[0]))) + else: + self.imgidx = np.array(list(self.imgrec.keys)) + + def __getitem__(self, index): + idx = self.imgidx[index] + s = self.imgrec.read_idx(idx) + header, img = mx.recordio.unpack(s) + label = header.label + if not isinstance(label, numbers.Number): + label = label[0] + label = torch.tensor(label, dtype=torch.long) + sample = mx.image.imdecode(img).asnumpy() + if self.transform is not None: + sample = self.transform(sample) + return sample, label + + def __len__(self): + return len(self.imgidx) + + +class SyntheticDataset(Dataset): + def __init__(self, local_rank): + super(SyntheticDataset, self).__init__() + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).squeeze(0).float() + img = ((img / 255) - 0.5) / 0.5 + self.img = img + self.label = 1 + + def __getitem__(self, index): + return self.img, self.label + + def __len__(self): + return 1000000 diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/docs/eval.md b/chat_anything/sad_talker/face3d/models/arcface_torch/docs/eval.md new file mode 100644 index 0000000000000000000000000000000000000000..dd1d9e257367b6422680966198646c45e5a2671d --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/docs/eval.md @@ -0,0 +1,31 @@ +## Eval on ICCV2021-MFR + +coming soon. + + +## Eval IJBC +You can eval ijbc with pytorch or onnx. + + +1. Eval IJBC With Onnx +```shell +CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 +``` + +2. 
Eval IJBC With Pytorch +```shell +CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ +--model-prefix ms1mv3_arcface_r50/backbone.pth \ +--image-path IJB_release/IJBC \ +--result-dir ms1mv3_arcface_r50 \ +--batch-size 128 \ +--job ms1mv3_arcface_r50 \ +--target IJBC \ +--network iresnet50 +``` + +## Inference + +```shell +python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 +``` diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/docs/install.md b/chat_anything/sad_talker/face3d/models/arcface_torch/docs/install.md new file mode 100644 index 0000000000000000000000000000000000000000..6314a40441285e9236438e468caf8b71a407531a --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/docs/install.md @@ -0,0 +1,51 @@ +## v1.8.0 +### Linux and Windows +```shell +# CUDA 11.0 +pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 10.2 +pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 + +# CPU only +pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html + +``` + + +## v1.7.1 +### Linux and Windows +```shell +# CUDA 11.0 +pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 10.2 +pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 + +# CUDA 10.1 +pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 9.2 +pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html + +# CPU only +pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html +``` + + +## v1.6.0 + +### Linux and Windows +```shell +# CUDA 10.2 +pip install torch==1.6.0 torchvision==0.7.0 + +# CUDA 10.1 +pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html + +# CUDA 9.2 +pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html + +# CPU only +pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html +``` \ No newline at end of file diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/docs/modelzoo.md b/chat_anything/sad_talker/face3d/models/arcface_torch/docs/modelzoo.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/docs/speed_benchmark.md b/chat_anything/sad_talker/face3d/models/arcface_torch/docs/speed_benchmark.md new file mode 100644 index 0000000000000000000000000000000000000000..055aee0defe2c43a523ced48260242f0f99b7cea --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/docs/speed_benchmark.md @@ -0,0 +1,93 @@ +## Test Training Speed + +- Test Commands + +You need to use the following two commands to test the Partial FC training performance. +The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50, +batch size is 1024. 
+```shell +# Model Parallel +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions +# Partial FC 0.1 +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc +``` + +- GPU Memory + +``` +# (Model Parallel) gpustat -i +[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB +[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB +[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB +[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB +[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB +[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB +[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB +[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB + +# (Partial FC 0.1) gpustat -i +[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· +[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· +[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· +[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· +[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· +[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· +[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· +[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· +``` + +- Training Speed + +```python +# (Model Parallel) trainging.log +Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 + +# (Partial FC 0.1) trainging.log +Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 +``` + +In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel, +and the training speed is 2.5 times faster than the model parallel. + + +## Speed Benchmark + +1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 4681 | 4824 | 5004 | +|250000 | 4047 | 4521 | 4976 | +|500000 | 3087 | 4013 | 4900 | +|1000000 | 2090 | 3449 | 4803 | +|1400000 | 1672 | 3043 | 4738 | +|2000000 | - | 2593 | 4626 | +|4000000 | - | 1748 | 4208 | +|5500000 | - | 1389 | 3975 | +|8000000 | - | - | 3565 | +|16000000 | - | - | 2679 | +|29000000 | - | - | 1855 | + +2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. 
(Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|250000 | 9940 | 5826 | 5004 | +|500000 | 14220 | 7114 | 5202 | +|1000000 | 23708 | 9966 | 5620 | +|1400000 | 32252 | 11178 | 6056 | +|2000000 | - | 13978 | 6472 | +|4000000 | - | 23238 | 8284 | +|5500000 | - | 32188 | 9854 | +|8000000 | - | - | 12310 | +|16000000 | - | - | 19950 | +|29000000 | - | - | 32324 | diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/eval/__init__.py b/chat_anything/sad_talker/face3d/models/arcface_torch/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/eval/verification.py b/chat_anything/sad_talker/face3d/models/arcface_torch/eval/verification.py new file mode 100644 index 0000000000000000000000000000000000000000..253343b83dbf9d1bd154d14ec068e098bf0968db --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/eval/verification.py @@ -0,0 +1,407 @@ +"""Helper for evaluation on the Labeled Faces in the Wild dataset +""" + +# MIT License +# +# Copyright (c) 2016 David Sandberg +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
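+# Note on this module (editorial summary, based on the functions below): pairs of
+# embeddings are compared by squared L2 distance; calculate_roc picks the best
+# accuracy threshold on each k-fold training split (LFold) and reports TPR/FPR/
+# accuracy on the held-out fold, while calculate_val reports the validation rate
+# at a target FAR (1e-3 as called from evaluate()).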
+ + +import datetime +import os +import pickle + +import mxnet as mx +import numpy as np +import sklearn +import torch +from mxnet import ndarray as nd +from scipy import interpolate +from sklearn.decomposition import PCA +from sklearn.model_selection import KFold + + +class LFold: + def __init__(self, n_splits=2, shuffle=False): + self.n_splits = n_splits + if self.n_splits > 1: + self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) + + def split(self, indices): + if self.n_splits > 1: + return self.k_fold.split(indices) + else: + return [(indices, indices)] + + +def calculate_roc(thresholds, + embeddings1, + embeddings2, + actual_issame, + nrof_folds=10, + pca=0): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + tprs = np.zeros((nrof_folds, nrof_thresholds)) + fprs = np.zeros((nrof_folds, nrof_thresholds)) + accuracy = np.zeros((nrof_folds)) + indices = np.arange(nrof_pairs) + + if pca == 0: + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + if pca > 0: + print('doing pca on', fold_idx) + embed1_train = embeddings1[train_set] + embed2_train = embeddings2[train_set] + _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) + pca_model = PCA(n_components=pca) + pca_model.fit(_embed_train) + embed1 = pca_model.transform(embeddings1) + embed2 = pca_model.transform(embeddings2) + embed1 = sklearn.preprocessing.normalize(embed1) + embed2 = sklearn.preprocessing.normalize(embed2) + diff = np.subtract(embed1, embed2) + dist = np.sum(np.square(diff), 1) + + # Find the best threshold for the fold + acc_train = np.zeros((nrof_thresholds)) + for threshold_idx, threshold in enumerate(thresholds): + _, _, acc_train[threshold_idx] = calculate_accuracy( + threshold, dist[train_set], actual_issame[train_set]) + best_threshold_index = np.argmax(acc_train) + for threshold_idx, threshold in enumerate(thresholds): + tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( + threshold, dist[test_set], + actual_issame[test_set]) + _, _, accuracy[fold_idx] = calculate_accuracy( + thresholds[best_threshold_index], dist[test_set], + actual_issame[test_set]) + + tpr = np.mean(tprs, 0) + fpr = np.mean(fprs, 0) + return tpr, fpr, accuracy + + +def calculate_accuracy(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + tp = np.sum(np.logical_and(predict_issame, actual_issame)) + fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) + tn = np.sum( + np.logical_and(np.logical_not(predict_issame), + np.logical_not(actual_issame))) + fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) + + tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) + fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) + acc = float(tp + tn) / dist.size + return tpr, fpr, acc + + +def calculate_val(thresholds, + embeddings1, + embeddings2, + actual_issame, + far_target, + nrof_folds=10): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + val = np.zeros(nrof_folds) + far = 
np.zeros(nrof_folds) + + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + indices = np.arange(nrof_pairs) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + + # Find the threshold that gives FAR = far_target + far_train = np.zeros(nrof_thresholds) + for threshold_idx, threshold in enumerate(thresholds): + _, far_train[threshold_idx] = calculate_val_far( + threshold, dist[train_set], actual_issame[train_set]) + if np.max(far_train) >= far_target: + f = interpolate.interp1d(far_train, thresholds, kind='slinear') + threshold = f(far_target) + else: + threshold = 0.0 + + val[fold_idx], far[fold_idx] = calculate_val_far( + threshold, dist[test_set], actual_issame[test_set]) + + val_mean = np.mean(val) + far_mean = np.mean(far) + val_std = np.std(val) + return val_mean, val_std, far_mean + + +def calculate_val_far(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) + false_accept = np.sum( + np.logical_and(predict_issame, np.logical_not(actual_issame))) + n_same = np.sum(actual_issame) + n_diff = np.sum(np.logical_not(actual_issame)) + # print(true_accept, false_accept) + # print(n_same, n_diff) + val = float(true_accept) / float(n_same) + far = float(false_accept) / float(n_diff) + return val, far + + +def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): + # Calculate evaluation metrics + thresholds = np.arange(0, 4, 0.01) + embeddings1 = embeddings[0::2] + embeddings2 = embeddings[1::2] + tpr, fpr, accuracy = calculate_roc(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + nrof_folds=nrof_folds, + pca=pca) + thresholds = np.arange(0, 4, 0.001) + val, val_std, far = calculate_val(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + 1e-3, + nrof_folds=nrof_folds) + return tpr, fpr, accuracy, val, val_std, far + +@torch.no_grad() +def load_bin(path, image_size): + try: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f) # py2 + except UnicodeDecodeError as e: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f, encoding='bytes') # py3 + data_list = [] + for flip in [0, 1]: + data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) + data_list.append(data) + for idx in range(len(issame_list) * 2): + _bin = bins[idx] + img = mx.image.imdecode(_bin) + if img.shape[1] != image_size[0]: + img = mx.image.resize_short(img, image_size[0]) + img = nd.transpose(img, axes=(2, 0, 1)) + for flip in [0, 1]: + if flip == 1: + img = mx.ndarray.flip(data=img, axis=2) + data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) + if idx % 1000 == 0: + print('loading bin', idx) + print(data_list[0].shape) + return data_list, issame_list + +@torch.no_grad() +def test(data_set, backbone, batch_size, nfolds=10): + print('testing verification..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + _data = data[bb - batch_size: bb] + time0 = datetime.datetime.now() + img = ((_data / 255) - 0.5) / 0.5 + net_out: torch.Tensor = backbone(img) + _embeddings = net_out.detach().cpu().numpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], 
_embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + + _xnorm = 0.0 + _xnorm_cnt = 0 + for embed in embeddings_list: + for i in range(embed.shape[0]): + _em = embed[i] + _norm = np.linalg.norm(_em) + _xnorm += _norm + _xnorm_cnt += 1 + _xnorm /= _xnorm_cnt + + acc1 = 0.0 + std1 = 0.0 + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + print(embeddings.shape) + print('infer time', time_consumed) + _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) + acc2, std2 = np.mean(accuracy), np.std(accuracy) + return acc1, std1, acc2, std2, _xnorm, embeddings_list + + +def dumpR(data_set, + backbone, + batch_size, + name='', + data_extra=None, + label_shape=None): + print('dump verification embedding..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + + _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) + time0 = datetime.datetime.now() + if data_extra is None: + db = mx.io.DataBatch(data=(_data,), label=(_label,)) + else: + db = mx.io.DataBatch(data=(_data, _data_extra), + label=(_label,)) + model.forward(db, is_train=False) + net_out = model.get_outputs() + _embeddings = net_out[0].asnumpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + actual_issame = np.asarray(issame_list) + outname = os.path.join('temp.bin') + with open(outname, 'wb') as f: + pickle.dump((embeddings, issame_list), + f, + protocol=pickle.HIGHEST_PROTOCOL) + + +# if __name__ == '__main__': +# +# parser = argparse.ArgumentParser(description='do verification') +# # general +# parser.add_argument('--data-dir', default='', help='') +# parser.add_argument('--model', +# default='../model/softmax,50', +# help='path to load model.') +# parser.add_argument('--target', +# default='lfw,cfp_ff,cfp_fp,agedb_30', +# help='test targets.') +# parser.add_argument('--gpu', default=0, type=int, help='gpu id') +# parser.add_argument('--batch-size', default=32, type=int, help='') +# parser.add_argument('--max', default='', type=str, help='') +# parser.add_argument('--mode', default=0, type=int, help='') +# parser.add_argument('--nfolds', default=10, type=int, help='') +# args = parser.parse_args() +# image_size = [112, 112] +# print('image_size', image_size) +# ctx = mx.gpu(args.gpu) +# nets = [] +# vec = args.model.split(',') +# prefix = args.model.split(',')[0] +# epochs = [] +# if len(vec) == 1: +# pdir = os.path.dirname(prefix) +# for fname in os.listdir(pdir): +# if not fname.endswith('.params'): +# continue +# _file = os.path.join(pdir, fname) +# if _file.startswith(prefix): +# epoch = int(fname.split('.')[0].split('-')[1]) +# epochs.append(epoch) +# epochs = sorted(epochs, reverse=True) +# if len(args.max) > 0: +# _max = [int(x) for x in args.max.split(',')] +# assert len(_max) == 2 +# if len(epochs) > _max[1]: +# epochs = epochs[_max[0]:_max[1]] +# +# else: +# 
epochs = [int(x) for x in vec[1].split('|')] +# print('model number', len(epochs)) +# time0 = datetime.datetime.now() +# for epoch in epochs: +# print('loading', prefix, epoch) +# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) +# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) +# all_layers = sym.get_internals() +# sym = all_layers['fc1_output'] +# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) +# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) +# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], +# image_size[1]))]) +# model.set_params(arg_params, aux_params) +# nets.append(model) +# time_now = datetime.datetime.now() +# diff = time_now - time0 +# print('model loading time', diff.total_seconds()) +# +# ver_list = [] +# ver_name_list = [] +# for name in args.target.split(','): +# path = os.path.join(args.data_dir, name + ".bin") +# if os.path.exists(path): +# print('loading.. ', name) +# data_set = load_bin(path, image_size) +# ver_list.append(data_set) +# ver_name_list.append(name) +# +# if args.mode == 0: +# for i in range(len(ver_list)): +# results = [] +# for model in nets: +# acc1, std1, acc2, std2, xnorm, embeddings_list = test( +# ver_list[i], model, args.batch_size, args.nfolds) +# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) +# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) +# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) +# results.append(acc2) +# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) +# elif args.mode == 1: +# raise ValueError +# else: +# model = nets[0] +# dumpR(ver_list[0], model, args.batch_size, args.target) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/eval_ijbc.py b/chat_anything/sad_talker/face3d/models/arcface_torch/eval_ijbc.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5a650d486d18eb02d6f60d448fc3b315261f5d --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/eval_ijbc.py @@ -0,0 +1,483 @@ +# coding: utf-8 + +import os +import pickle + +import matplotlib +import pandas as pd + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import timeit +import sklearn +import argparse +import cv2 +import numpy as np +import torch +from skimage import transform as trans +from backbones import get_model +from sklearn.metrics import roc_curve, auc + +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from pathlib import Path + +import sys +import warnings + +sys.path.insert(0, "../") +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser(description='do ijb test') +# general +parser.add_argument('--model-prefix', default='', help='path to load model.') +parser.add_argument('--image-path', default='', type=str, help='') +parser.add_argument('--result-dir', default='.', type=str, help='') +parser.add_argument('--batch-size', default=128, type=int, help='') +parser.add_argument('--network', default='iresnet50', type=str, help='') +parser.add_argument('--job', default='insightface', type=str, help='job name') +parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') +args = parser.parse_args() + +target = args.target +model_path = args.model_prefix +image_path = args.image_path +result_dir = args.result_dir +gpu_id = None +use_norm_score = True # if 
Ture, TestMode(N1) +use_detector_score = True # if Ture, TestMode(D1) +use_flip_test = True # if Ture, TestMode(F1) +job = args.job +batch_size = args.batch_size + + +class Embedding(object): + def __init__(self, prefix, data_shape, batch_size=1): + image_size = (112, 112) + self.image_size = image_size + weight = torch.load(prefix) + resnet = get_model(args.network, dropout=0, fp16=False).cuda() + resnet.load_state_dict(weight) + model = torch.nn.DataParallel(resnet) + self.model = model + self.model.eval() + src = np.array([ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]], dtype=np.float32) + src[:, 0] += 8.0 + self.src = src + self.batch_size = batch_size + self.data_shape = data_shape + + def get(self, rimg, landmark): + + assert landmark.shape[0] == 68 or landmark.shape[0] == 5 + assert landmark.shape[1] == 2 + if landmark.shape[0] == 68: + landmark5 = np.zeros((5, 2), dtype=np.float32) + landmark5[0] = (landmark[36] + landmark[39]) / 2 + landmark5[1] = (landmark[42] + landmark[45]) / 2 + landmark5[2] = landmark[30] + landmark5[3] = landmark[48] + landmark5[4] = landmark[54] + else: + landmark5 = landmark + tform = trans.SimilarityTransform() + tform.estimate(landmark5, self.src) + M = tform.params[0:2, :] + img = cv2.warpAffine(rimg, + M, (self.image_size[1], self.image_size[0]), + borderValue=0.0) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_flip = np.fliplr(img) + img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB + img_flip = np.transpose(img_flip, (2, 0, 1)) + input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) + input_blob[0] = img + input_blob[1] = img_flip + return input_blob + + @torch.no_grad() + def forward_db(self, batch_data): + imgs = torch.Tensor(batch_data).cuda() + imgs.div_(255).sub_(0.5).div_(0.5) + feat = self.model(imgs) + feat = feat.reshape([self.batch_size, 2 * feat.shape[1]]) + return feat.cpu().numpy() + + +# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] +def divideIntoNstrand(listTemp, n): + twoList = [[] for i in range(n)] + for i, e in enumerate(listTemp): + twoList[i % n].append(e) + return twoList + + +def read_template_media_list(path): + # ijb_meta = np.loadtxt(path, dtype=str) + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +# In[ ]: + + +def read_template_pair_list(path): + # pairs = np.loadtxt(path, dtype=str) + pairs = pd.read_csv(path, sep=' ', header=None).values + # print(pairs.shape) + # print(pairs[:, 0].astype(np.int)) + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +# In[ ]: + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# In[ ]: + + +def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): + batch_size = args.batch_size + data_shape = (3, 112, 112) + + files = files_list + print('files:', len(files)) + rare_size = len(files) % batch_size + faceness_scores = [] + batch = 0 + img_feats = np.empty((len(files), 1024), dtype=np.float32) + + batch_data = np.empty((2 * batch_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, batch_size) + for img_index, each_line in enumerate(files[:len(files) - rare_size]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = 
cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + + batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] + batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] + if (img_index + 1) % batch_size == 0: + print('batch', batch) + img_feats[batch * batch_size:batch * batch_size + + batch_size][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + + batch_data = np.empty((2 * rare_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, rare_size) + for img_index, each_line in enumerate(files[len(files) - rare_size:]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + batch_data[2 * img_index][:] = input_blob[0] + batch_data[2 * img_index + 1][:] = input_blob[1] + if (img_index + 1) % rare_size == 0: + print('batch', batch) + img_feats[len(files) - + rare_size:][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 + # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) + return img_feats, faceness_scores + + +# In[ ]: + + +def image2template_feature(img_feats=None, templates=None, medias=None): + # ========================================================== + # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] + # 2. compute media feature. + # 3. compute template feature. + # ========================================================== + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + + for count_template, uqt in enumerate(unique_templates): + + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, + return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [ + np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) + ] + media_norm_feats = np.array(media_norm_feats) + # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) + template_norm_feats = sklearn.preprocessing.normalize(template_feats) + # print(template_norm_feats.shape) + return template_norm_feats, unique_templates + + +# In[ ]: + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + # ========================================================== + # Compute set-to-set Similarity Score. 
+ # ========================================================== + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + + score = np.zeros((len(p1),)) # save cosine distance between pairs + + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +# In[ ]: +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def read_score(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# # Step1: Load Meta Data + +# In[ ]: + +assert target == 'IJBC' or target == 'IJBB' + +# ============================================================= +# load image and template relationships for template feature embedding +# tid --> template id, mid --> media id +# format: +# image_name tid mid +# ============================================================= +start = timeit.default_timer() +templates, medias = read_template_media_list( + os.path.join('%s/meta' % image_path, + '%s_face_tid_mid.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# In[ ]: + +# ============================================================= +# load template pairs for template-to-template verification +# tid : template id, label : 1/0 +# format: +# tid_1 tid_2 label +# ============================================================= +start = timeit.default_timer() +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. 
' % (stop - start)) + +# # Step 2: Get Image Features + +# In[ ]: + +# ============================================================= +# load image features +# format: +# img_feats: [image_num x feats_dim] (227630, 512) +# ============================================================= +start = timeit.default_timer() +img_path = '%s/loose_crop' % image_path +img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) +img_list = open(img_list_path) +files = img_list.readlines() +# files_list = divideIntoNstrand(files, rank_size) +files_list = files + +# img_feats +# for i in range(rank_size): +img_feats, faceness_scores = get_image_feature(img_path, files_list, + model_path, 0, gpu_id) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) +print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], + img_feats.shape[1])) + +# # Step3: Get Template Features + +# In[ ]: + +# ============================================================= +# compute template features from image features. +# ============================================================= +start = timeit.default_timer() +# ========================================================== +# Norm feature before aggregation into template feature? +# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). +# ========================================================== +# 1. FaceScore (Feature Norm) +# 2. FaceScore (Detector) + +if use_flip_test: + # concat --- F1 + # img_input_feats = img_feats + # add --- F2 + img_input_feats = img_feats[:, 0:img_feats.shape[1] // + 2] + img_feats[:, img_feats.shape[1] // 2:] +else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + +if use_norm_score: + img_input_feats = img_input_feats +else: + # normalise features to remove norm information + img_input_feats = img_input_feats / np.sqrt( + np.sum(img_input_feats ** 2, -1, keepdims=True)) + +if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] +else: + img_input_feats = img_input_feats + +template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 4: Get Template Similarity Scores + +# In[ ]: + +# ============================================================= +# compute verification scores between template pairs. +# ============================================================= +start = timeit.default_timer() +score = verification(template_norm_feats, unique_templates, p1, p2) +stop = timeit.default_timer() +print('Time: %.2f s. 
' % (stop - start)) + +# In[ ]: +save_path = os.path.join(result_dir, args.job) +# save_path = result_dir + '/%s_result' % target + +if not os.path.exists(save_path): + os.makedirs(save_path) + +score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) +np.save(score_save_file, score) + +# # Step 5: Get ROC Curves and TPR@FPR Table + +# In[ ]: + +files = [score_save_file] +methods = [] +scores = [] +for file in files: + methods.append(Path(file).stem) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) +print(tpr_fpr_table) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/inference.py b/chat_anything/sad_talker/face3d/models/arcface_torch/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5156e8d649954837e397c2ff15ec29995e7502 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/inference.py @@ -0,0 +1,35 @@ +import argparse + +import cv2 +import numpy as np +import torch + +from backbones import get_model + + +@torch.no_grad() +def inference(weight, name, img): + if img is None: + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) + else: + img = cv2.imread(img) + img = cv2.resize(img, (112, 112)) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + img.div_(255).sub_(0.5).div_(0.5) + net = get_model(name, fp16=False) + net.load_state_dict(torch.load(weight)) + net.eval() + feat = net(img).numpy() + print(feat) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('--network', type=str, default='r50', help='backbone network') + parser.add_argument('--weight', type=str, default='') + parser.add_argument('--img', type=str, default=None) + args = parser.parse_args() + inference(args.weight, args.network, args.img) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/losses.py b/chat_anything/sad_talker/face3d/models/arcface_torch/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..87aeaa107af4d53f5a6132b3739d5cafdcded7fc --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/losses.py @@ -0,0 +1,42 @@ +import torch +from torch import nn 
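+# Usage sketch (illustrative only; not wired up elsewhere in this file): the
+# margin losses below expect cosine logits from a normalized linear layer,
+# as in partial_fc.py. Names here (embeddings, weight, labels) are placeholders.
+#
+#   margin_softmax = get_loss("arcface")                  # or "cosface"
+#   cosine = torch.nn.functional.linear(
+#       torch.nn.functional.normalize(embeddings),        # (B, 512) features
+#       torch.nn.functional.normalize(weight))            # (C, 512) class centers
+#   logits = margin_softmax(cosine, labels)               # margin added at the label column
+#   loss = torch.nn.functional.cross_entropy(logits, labels)
+#
+# Both losses modify `cosine` in place (ArcFace via acos_/cos_, CosFace via the
+# indexed subtraction), so pass a tensor you do not need afterwards.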
+ + +def get_loss(name): + if name == "cosface": + return CosFace() + elif name == "arcface": + return ArcFace() + else: + raise ValueError() + + +class CosFace(nn.Module): + def __init__(self, s=64.0, m=0.40): + super(CosFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine[index] -= m_hot + ret = cosine * self.s + return ret + + +class ArcFace(nn.Module): + def __init__(self, s=64.0, m=0.5): + super(ArcFace, self).__init__() + self.s = s + self.m = m + + def forward(self, cosine: torch.Tensor, label): + index = torch.where(label != -1)[0] + m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) + m_hot.scatter_(1, label[index, None], self.m) + cosine.acos_() + cosine[index] += m_hot + cosine.cos_().mul_(self.s) + return cosine diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/onnx_helper.py b/chat_anything/sad_talker/face3d/models/arcface_torch/onnx_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ca922ca6d410655029e459cf8fd1c323d276c34c --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/onnx_helper.py @@ -0,0 +1,250 @@ +from __future__ import division +import datetime +import os +import os.path as osp +import glob +import numpy as np +import cv2 +import sys +import onnxruntime +import onnx +import argparse +from onnx import numpy_helper +from insightface.data import get_image + +class ArcFaceORT: + def __init__(self, model_path, cpu=False): + self.model_path = model_path + # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" + self.providers = ['CPUExecutionProvider'] if cpu else None + + #input_size is (w,h), return error message, return None if success + def check(self, track='cfat', test_img = None): + #default is cfat + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=15 + if track.startswith('ms1m'): + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=10 + elif track.startswith('glint'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=20 + elif track.startswith('cfat'): + max_model_size_mb = 1024 + max_feat_dim = 512 + max_time_cost = 15 + elif track.startswith('unconstrained'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=30 + else: + return "track not found" + + if not os.path.exists(self.model_path): + return "model_path not exists" + if not os.path.isdir(self.model_path): + return "model_path should be directory" + onnx_files = [] + for _file in os.listdir(self.model_path): + if _file.endswith('.onnx'): + onnx_files.append(osp.join(self.model_path, _file)) + if len(onnx_files)==0: + return "do not have onnx files" + self.model_file = sorted(onnx_files)[-1] + print('use onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('input-shape:', input_shape) + if len(input_shape)!=4: + return "length of input_shape should be 4" + if not isinstance(input_shape[0], str): + #return "input_shape[0] should be str to support batch-inference" + print('reset input-shape[0] to None') + model = onnx.load(self.model_file) + model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + new_model_file = 
osp.join(self.model_path, 'zzzzrefined.onnx') + onnx.save(model, new_model_file) + self.model_file = new_model_file + print('use new onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('new-input-shape:', input_shape) + + self.image_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + outputs = session.get_outputs() + output_names = [] + for o in outputs: + output_names.append(o.name) + #print(o.name, o.shape) + if len(output_names)!=1: + return "number of output nodes should be 1" + self.session = session + self.input_name = input_name + self.output_names = output_names + #print(self.output_names) + model = onnx.load(self.model_file) + graph = model.graph + if len(graph.node)<8: + return "too small onnx graph" + + input_size = (112,112) + self.crop = None + if track=='cfat': + crop_file = osp.join(self.model_path, 'crop.txt') + if osp.exists(crop_file): + lines = open(crop_file,'r').readlines() + if len(lines)!=6: + return "crop.txt should contain 6 lines" + lines = [int(x) for x in lines] + self.crop = lines[:4] + input_size = tuple(lines[4:6]) + if input_size!=self.image_size: + return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) + + self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) + if self.model_size_mb > max_model_size_mb: + return "max model size exceed, given %.3f-MB"%self.model_size_mb + + input_mean = None + input_std = None + if track=='cfat': + pn_file = osp.join(self.model_path, 'pixel_norm.txt') + if osp.exists(pn_file): + lines = open(pn_file,'r').readlines() + if len(lines)!=2: + return "pixel_norm.txt should contain 2 lines" + input_mean = float(lines[0]) + input_std = float(lines[1]) + if input_mean is not None or input_std is not None: + if input_mean is None or input_std is None: + return "please set input_mean and input_std simultaneously" + else: + find_sub = False + find_mul = False + for nid, node in enumerate(graph.node[:8]): + print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): + find_mul = True + if find_sub and find_mul: + print("find sub and mul") + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 127.5 + self.input_mean = input_mean + self.input_std = input_std + for initn in graph.initializer: + weight_array = numpy_helper.to_array(initn) + dt = weight_array.dtype + if dt.itemsize<4: + return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) + if test_img is None: + test_img = get_image('Tom_Hanks_54745') + test_img = cv2.resize(test_img, self.image_size) + else: + test_img = cv2.resize(test_img, self.image_size) + feat, cost = self.benchmark(test_img) + batch_result = self.check_batch(test_img) + batch_result_sum = float(np.sum(batch_result)) + if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: + print(batch_result) + print(batch_result_sum) + return "batch result output contains NaN!" 
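+            # (check above) a float compares unequal to itself only when it is NaN,
+            # so `batch_result_sum != batch_result_sum` detects NaN without math.isnan.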
+ + if len(feat.shape) < 2: + return "the shape of the feature must be two, but get {}".format(str(feat.shape)) + + if feat.shape[1] > max_feat_dim: + return "max feat dim exceed, given %d"%feat.shape[1] + self.feat_dim = feat.shape[1] + cost_ms = cost*1000 + if cost_ms>max_time_cost: + return "max time cost exceed, given %.4f"%cost_ms + self.cost_ms = cost_ms + print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) + return None + + def check_batch(self, img): + if not isinstance(img, list): + imgs = [img, ] * 32 + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] + if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: + nimg = cv2.resize(nimg, self.image_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages( + images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, + mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + + def meta_info(self): + return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} + + + def forward(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.image_size + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + return net_out + + def benchmark(self, img): + input_size = self.image_size + if self.crop is not None: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + img = nimg + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + costs = [] + for _ in range(50): + ta = datetime.datetime.now() + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + tb = datetime.datetime.now() + cost = (tb-ta).total_seconds() + costs.append(cost) + costs = sorted(costs) + cost = costs[5] + return net_out, cost + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + # general + parser.add_argument('workdir', help='submitted work dir', type=str) + parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') + args = parser.parse_args() + handler = ArcFaceORT(args.workdir) + err = handler.check(args.track) + print('err:', err) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/onnx_ijbc.py b/chat_anything/sad_talker/face3d/models/arcface_torch/onnx_ijbc.py new file mode 100644 index 0000000000000000000000000000000000000000..05b50bfad4b4cf38903b89f596263a8e29a50d3e --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/onnx_ijbc.py @@ -0,0 +1,267 @@ +import argparse +import os +import pickle +import timeit + +import cv2 +import mxnet as mx +import numpy as np +import pandas as pd +import 
prettytable +import skimage.transform +from sklearn.metrics import roc_curve +from sklearn.preprocessing import normalize + +from onnx_helper import ArcFaceORT + +SRC = np.array( + [ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]] + , dtype=np.float32) +SRC[:, 0] += 8.0 + + +class AlignedDataSet(mx.gluon.data.Dataset): + def __init__(self, root, lines, align=True): + self.lines = lines + self.root = root + self.align = align + + def __len__(self): + return len(self.lines) + + def __getitem__(self, idx): + each_line = self.lines[idx] + name_lmk_score = each_line.strip().split(' ') + name = os.path.join(self.root, name_lmk_score[0]) + img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) + landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) + st = skimage.transform.SimilarityTransform() + st.estimate(landmark5, SRC) + img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) + img_1 = np.expand_dims(img, 0) + img_2 = np.expand_dims(np.fliplr(img), 0) + output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) + output = np.transpose(output, (0, 3, 1, 2)) + output = mx.nd.array(output) + return output + + +def extract(model_root, dataset): + model = ArcFaceORT(model_path=model_root) + model.check() + feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) + + def batchify_fn(data): + return mx.nd.concat(*data, dim=0) + + data_loader = mx.gluon.data.DataLoader( + dataset, 128, last_batch='keep', num_workers=4, + thread_pool=True, prefetch=16, batchify_fn=batchify_fn) + num_iter = 0 + for batch in data_loader: + batch = batch.asnumpy() + batch = (batch - model.input_mean) / model.input_std + feat = model.session.run(model.output_names, {model.input_name: batch})[0] + feat = np.reshape(feat, (-1, model.feat_dim * 2)) + feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat + num_iter += 1 + if num_iter % 50 == 0: + print(num_iter) + return feat_mat + + +def read_template_media_list(path): + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +def image2template_feature(img_feats=None, + templates=None, + medias=None): + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + for count_template, uqt in enumerate(unique_templates): + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] + media_norm_feats = np.array(media_norm_feats) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} 
template features.'.format( + count_template)) + template_norm_feats = normalize(template_feats) + return template_norm_feats, unique_templates + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) + total_pairs = np.array(range(len(p1))) + batchsize = 100000 + sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def main(args): + use_norm_score = True # if Ture, TestMode(N1) + use_detector_score = True # if Ture, TestMode(D1) + use_flip_test = True # if Ture, TestMode(F1) + assert args.target == 'IJBC' or args.target == 'IJBB' + + start = timeit.default_timer() + templates, medias = read_template_media_list( + os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % args.image_path, + '%s_template_pair_label.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + img_path = '%s/loose_crop' % args.image_path + img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) + img_list = open(img_list_path) + files = img_list.readlines() + dataset = AlignedDataSet(root=img_path, lines=files, align=True) + img_feats = extract(args.model_root, dataset) + + faceness_scores = [] + for each_line in files: + name_lmk_score = each_line.split() + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + stop = timeit.default_timer() + print('Time: %.2f s. 
' % (stop - start)) + print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) + start = timeit.default_timer() + + if use_flip_test: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] + else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + + if use_norm_score: + img_input_feats = img_input_feats + else: + img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) + + if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] + else: + img_input_feats = img_input_feats + + template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + score = verification(template_norm_feats, unique_templates, p1, p2) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) + if not os.path.exists(save_path): + os.makedirs(save_path) + score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) + np.save(score_save_file, score) + files = [score_save_file] + methods = [] + scores = [] + for file in files: + methods.append(os.path.basename(file)) + scores.append(np.load(file)) + methods = np.array(methods) + scores = dict(zip(methods, scores)) + x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] + tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) + for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, args.target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) + print(tpr_fpr_table) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='do ijb test') + # general + parser.add_argument('--model-root', default='', help='path to load model.') + parser.add_argument('--image-path', default='', type=str, help='') + parser.add_argument('--result-dir', default='.', type=str, help='') + parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') + main(parser.parse_args()) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/partial_fc.py b/chat_anything/sad_talker/face3d/models/arcface_torch/partial_fc.py new file mode 100644 index 0000000000000000000000000000000000000000..17e2d25715d10ba446c957e1d2528b0687ed71d5 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/partial_fc.py @@ -0,0 +1,222 @@ +import logging +import os + +import torch +import torch.distributed as dist +from torch.nn import Module +from torch.nn.functional import normalize, linear +from torch.nn.parameter import Parameter + + +class PartialFC(Module): + """ + Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, + Partial FC: Training 10 Million Identities on a Single Machine + See the original paper: + https://arxiv.org/abs/2010.05222 + """ + + @torch.no_grad() + def __init__(self, rank, local_rank, world_size, batch_size, resume, + margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"): + """ + rank: int + Unique process(GPU) 
ID from 0 to world_size - 1. + local_rank: int + Unique process(GPU) ID within the server from 0 to 7. + world_size: int + Number of GPU. + batch_size: int + Batch size on current rank(GPU). + resume: bool + Select whether to restore the weight of softmax. + margin_softmax: callable + A function of margin softmax, eg: cosface, arcface. + num_classes: int + The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, + required. + sample_rate: float + The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling + can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. + embedding_size: int + The feature dimension, default is 512. + prefix: str + Path for save checkpoint, default is './'. + """ + super(PartialFC, self).__init__() + # + self.num_classes: int = num_classes + self.rank: int = rank + self.local_rank: int = local_rank + self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) + self.world_size: int = world_size + self.batch_size: int = batch_size + self.margin_softmax: callable = margin_softmax + self.sample_rate: float = sample_rate + self.embedding_size: int = embedding_size + self.prefix: str = prefix + self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) + self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) + self.num_sample: int = int(self.sample_rate * self.num_local) + + self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) + self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) + + if resume: + try: + self.weight: torch.Tensor = torch.load(self.weight_name) + self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) + if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: + raise IndexError + logging.info("softmax weight resume successfully!") + logging.info("softmax weight mom resume successfully!") + except (FileNotFoundError, KeyError, IndexError): + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init!") + logging.info("softmax weight mom init!") + else: + self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) + self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) + logging.info("softmax weight init successfully!") + logging.info("softmax weight mom init successfully!") + self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) + + self.index = None + if int(self.sample_rate) == 1: + self.update = lambda: 0 + self.sub_weight = Parameter(self.weight) + self.sub_weight_mom = self.weight_mom + else: + self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) + + def save_params(self): + """ Save softmax weight for each rank on prefix + """ + torch.save(self.weight.data, self.weight_name) + torch.save(self.weight_mom, self.weight_mom_name) + + @torch.no_grad() + def sample(self, total_label): + """ + Sample all positive class centers in each rank, and random select neg class centers to filling a fixed + `num_sample`. + + total_label: tensor + Label after all gather, which cross all GPUs. 
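+
+        Positive class centers present in `total_label` on this rank are always kept
+        (their random scores are forced to the top before the top-k selection); the
+        remaining slots up to `num_sample` are filled with randomly chosen negative
+        centers. Labels are then remapped into the sampled index range with
+        `torch.searchsorted`, and `sub_weight` / `sub_weight_mom` are sliced down to
+        the sampled centers.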
+ """ + index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) + total_label[~index_positive] = -1 + total_label[index_positive] -= self.class_start + if int(self.sample_rate) != 1: + positive = torch.unique(total_label[index_positive], sorted=True) + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local], device=self.device) + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1] + index = index.sort()[0] + else: + index = positive + self.index = index + total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) + self.sub_weight = Parameter(self.weight[index]) + self.sub_weight_mom = self.weight_mom[index] + + def forward(self, total_features, norm_weight): + """ Partial fc forward, `logits = X * sample(W)` + """ + torch.cuda.current_stream().wait_stream(self.stream) + logits = linear(total_features, norm_weight) + return logits + + @torch.no_grad() + def update(self): + """ Set updated weight and weight_mom to memory bank. + """ + self.weight_mom[self.index] = self.sub_weight_mom + self.weight[self.index] = self.sub_weight + + def prepare(self, label, optimizer): + """ + get sampled class centers for cal softmax. + + label: tensor + Label tensor on each rank. + optimizer: opt + Optimizer for partial fc, which need to get weight mom. + """ + with torch.cuda.stream(self.stream): + total_label = torch.zeros( + size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) + dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) + self.sample(total_label) + optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) + optimizer.param_groups[-1]['params'][0] = self.sub_weight + optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom + norm_weight = normalize(self.sub_weight) + return total_label, norm_weight + + def forward_backward(self, label, features, optimizer): + """ + Partial fc forward and backward with model parallel + + label: tensor + Label tensor on each rank(GPU) + features: tensor + Features tensor on each rank(GPU) + optimizer: optimizer + Optimizer for partial fc + + Returns: + -------- + x_grad: tensor + The gradient of features. + loss_v: tensor + Loss value for cross entropy. 
+ """ + total_label, norm_weight = self.prepare(label, optimizer) + total_features = torch.zeros( + size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) + dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) + total_features.requires_grad = True + + logits = self.forward(total_features, norm_weight) + logits = self.margin_softmax(logits, total_label) + + with torch.no_grad(): + max_fc = torch.max(logits, dim=1, keepdim=True)[0] + dist.all_reduce(max_fc, dist.ReduceOp.MAX) + + # calculate exp(logits) and all-reduce + logits_exp = torch.exp(logits - max_fc) + logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) + dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) + + # calculate prob + logits_exp.div_(logits_sum_exp) + + # get one-hot + grad = logits_exp + index = torch.where(total_label != -1)[0] + one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) + one_hot.scatter_(1, total_label[index, None], 1) + + # calculate loss + loss = torch.zeros(grad.size()[0], 1, device=grad.device) + loss[index] = grad[index].gather(1, total_label[index, None]) + dist.all_reduce(loss, dist.ReduceOp.SUM) + loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) + + # calculate grad + grad[index] -= one_hot + grad.div_(self.batch_size * self.world_size) + + logits.backward(grad) + if total_features.grad is not None: + total_features.grad.detach_() + x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) + # feature gradient all-reduce + dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) + x_grad = x_grad * self.world_size + # backward backbone + return x_grad, loss_v diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/requirement.txt b/chat_anything/sad_talker/face3d/models/arcface_torch/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..f72c1b3ba814ae1e0bc1c1f56402026978b9e870 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/requirement.txt @@ -0,0 +1,5 @@ +tensorboard +easydict +mxnet +onnx +sklearn diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/run.sh b/chat_anything/sad_talker/face3d/models/arcface_torch/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..61af4b4950eb11334e55362e3e3c5e2796979a01 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/run.sh @@ -0,0 +1,2 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 +ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/torch2onnx.py b/chat_anything/sad_talker/face3d/models/arcface_torch/torch2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..fc26ab82e552331bc8d75b34e81000418f4d38ec --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/torch2onnx.py @@ -0,0 +1,59 @@ +import numpy as np +import onnx +import torch + + +def convert_onnx(net, path_module, output, opset=11, simplify=False): + assert isinstance(net, torch.nn.Module) + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = img.astype(np.float) + img = (img / 255. 
- 0.5) / 0.5 # torch style norm + img = img.transpose((2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + + weight = torch.load(path_module) + net.load_state_dict(weight) + net.eval() + torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) + model = onnx.load(output) + graph = model.graph + graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + if simplify: + from onnxsim import simplify + model, check = simplify(model) + assert check, "Simplified ONNX model could not be validated" + onnx.save(model, output) + + +if __name__ == '__main__': + import os + import argparse + from backbones import get_model + + parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') + parser.add_argument('input', type=str, help='input backbone.pth file or path') + parser.add_argument('--output', type=str, default=None, help='output onnx path') + parser.add_argument('--network', type=str, default=None, help='backbone network') + parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') + args = parser.parse_args() + input_file = args.input + if os.path.isdir(input_file): + input_file = os.path.join(input_file, "backbone.pth") + assert os.path.exists(input_file) + model_name = os.path.basename(os.path.dirname(input_file)).lower() + params = model_name.split("_") + if len(params) >= 3 and params[1] in ('arcface', 'cosface'): + if args.network is None: + args.network = params[2] + assert args.network is not None + print(args) + backbone_onnx = get_model(args.network, dropout=0) + + output_path = args.output + if output_path is None: + output_path = os.path.join(os.path.dirname(__file__), 'onnx') + if not os.path.exists(output_path): + os.makedirs(output_path) + assert os.path.isdir(output_path) + output_file = os.path.join(output_path, "%s.onnx" % model_name) + convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/train.py b/chat_anything/sad_talker/face3d/models/arcface_torch/train.py new file mode 100644 index 0000000000000000000000000000000000000000..55eca2d0ad9463415970e09bccab8b722e496704 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/train.py @@ -0,0 +1,141 @@ +import argparse +import logging +import os + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.data.distributed +from torch.nn.utils import clip_grad_norm_ + +import losses +from backbones import get_model +from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX +from partial_fc import PartialFC +from utils.utils_amp import MaxClipGradScaler +from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint +from utils.utils_config import get_config +from utils.utils_logging import AverageMeter, init_logging + + +def main(args): + cfg = get_config(args.config) + try: + world_size = int(os.environ['WORLD_SIZE']) + rank = int(os.environ['RANK']) + dist.init_process_group('nccl') + except KeyError: + world_size = 1 + rank = 0 + dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) + + local_rank = args.local_rank + torch.cuda.set_device(local_rank) + os.makedirs(cfg.output, exist_ok=True) + init_logging(rank, cfg.output) + + if cfg.rec == "synthetic": + train_set = SyntheticDataset(local_rank=local_rank) + else: + train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) + + 
train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) + train_loader = DataLoaderX( + local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, + sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) + backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) + + if cfg.resume: + try: + backbone_pth = os.path.join(cfg.output, "backbone.pth") + backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) + if rank == 0: + logging.info("backbone resume successfully!") + except (FileNotFoundError, KeyError, IndexError, RuntimeError): + if rank == 0: + logging.info("resume fail, backbone init successfully!") + + backbone = torch.nn.parallel.DistributedDataParallel( + module=backbone, broadcast_buffers=False, device_ids=[local_rank]) + backbone.train() + margin_softmax = losses.get_loss(cfg.loss) + module_partial_fc = PartialFC( + rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, + batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, + sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output) + + opt_backbone = torch.optim.SGD( + params=[{'params': backbone.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + opt_pfc = torch.optim.SGD( + params=[{'params': module_partial_fc.parameters()}], + lr=cfg.lr / 512 * cfg.batch_size * world_size, + momentum=0.9, weight_decay=cfg.weight_decay) + + num_image = len(train_set) + total_batch_size = cfg.batch_size * world_size + cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch + cfg.total_step = num_image // total_batch_size * cfg.num_epoch + + def lr_step_func(current_step): + cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] + if current_step < cfg.warmup_step: + return current_step / cfg.warmup_step + else: + return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) + + scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_backbone, lr_lambda=lr_step_func) + scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( + optimizer=opt_pfc, lr_lambda=lr_step_func) + + for key, value in cfg.items(): + num_space = 25 - len(key) + logging.info(": " + key + " " * num_space + str(value)) + + val_target = cfg.val_targets + callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec) + callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) + callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) + + loss = AverageMeter() + start_epoch = 0 + global_step = 0 + grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None + for epoch in range(start_epoch, cfg.num_epoch): + train_sampler.set_epoch(epoch) + for step, (img, label) in enumerate(train_loader): + global_step += 1 + features = F.normalize(backbone(img)) + x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) + if cfg.fp16: + features.backward(grad_amp.scale(x_grad)) + grad_amp.unscale_(opt_backbone) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + grad_amp.step(opt_backbone) + grad_amp.update() + else: + features.backward(x_grad) + clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) + opt_backbone.step() + + opt_pfc.step() + module_partial_fc.update() + opt_backbone.zero_grad() + 
opt_pfc.zero_grad() + loss.update(loss_v, 1) + callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) + callback_verification(global_step, backbone) + scheduler_backbone.step() + scheduler_pfc.step() + callback_checkpoint(global_step, backbone, module_partial_fc) + dist.destroy_process_group() + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('config', type=str, help='py config file') + parser.add_argument('--local_rank', type=int, default=0, help='local_rank') + main(parser.parse_args()) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/utils/__init__.py b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/utils/plot.py b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc588e5c01ca550b69c385aeb3fd139c59fb88a --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/plot.py @@ -0,0 +1,72 @@ +# coding: utf-8 + +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from sklearn.metrics import roc_curve, auc + +image_path = "/data/anxiang/IJB_release/IJBC" +files = [ + "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" +] + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % 'ijbc')) + +methods = [] +scores = [] +for file in files: + methods.append(file.split('/')[-2]) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, "IJBC")) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +print(tpr_fpr_table) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_amp.py 
b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac2a03f4212faa129faed447a8f4519c0a00a8b --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_amp.py @@ -0,0 +1,88 @@ +from typing import Dict, List + +import torch + +if torch.__version__ < '1.9': + Iterable = torch._six.container_abcs.Iterable +else: + import collections + + Iterable = collections.abc.Iterable +from torch.cuda.amp import GradScaler + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +class MaxClipGradScaler(GradScaler): + def __init__(self, init_scale, max_scale: float, growth_interval=100): + GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) + self.max_scale = max_scale + + def scale_clip(self): + if self.get_scale() == self.max_scale: + self.set_growth_factor(1) + elif self.get_scale() < self.max_scale: + self.set_growth_factor(2) + elif self.get_scale() > self.max_scale: + self._scale.fill_(self.max_scale) + self.set_growth_factor(1) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Arguments: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + self.scale_clip() + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. 
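+        # apply_scale below recurses through nested lists/tuples of tensors; the first
+        # tensor encountered lazily initialises the scale and caches a per-device copy
+        # of it in `stash`, so one _MultiDeviceReplicator is reused for every tensor in
+        # the structure.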
+ stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_callbacks.py b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2f56cba47c57de102710ff56eaac591e59f4da --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_callbacks.py @@ -0,0 +1,117 @@ +import logging +import os +import time +from typing import List + +import torch + +from eval import verification +from utils.utils_logging import AverageMeter + + +class CallBackVerification(object): + def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): + self.frequent: int = frequent + self.rank: int = rank + self.highest_acc: float = 0.0 + self.highest_acc_list: List[float] = [0.0] * len(val_targets) + self.ver_list: List[object] = [] + self.ver_name_list: List[str] = [] + if self.rank is 0: + self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) + + def ver_test(self, backbone: torch.nn.Module, global_step: int): + results = [] + for i in range(len(self.ver_list)): + acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( + self.ver_list[i], backbone, 10, 10) + logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) + logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) + if acc2 > self.highest_acc_list[i]: + self.highest_acc_list[i] = acc2 + logging.info( + '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) + results.append(acc2) + + def init_dataset(self, val_targets, data_dir, image_size): + for name in val_targets: + path = os.path.join(data_dir, name + ".bin") + if os.path.exists(path): + data_set = verification.load_bin(path, image_size) + self.ver_list.append(data_set) + self.ver_name_list.append(name) + + def __call__(self, num_update, backbone: torch.nn.Module): + if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: + backbone.eval() + self.ver_test(backbone, num_update) + backbone.train() + + +class CallBackLogging(object): + def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): + self.frequent: int = frequent + self.rank: int = rank + self.time_start = time.time() + self.total_step: int = total_step + self.batch_size: int = batch_size + self.world_size: int = world_size + self.writer = writer + + self.init = False + self.tic = 0 + + def __call__(self, + global_step: int, + loss: AverageMeter, + epoch: int, + fp16: bool, + learning_rate: float, + grad_scaler: torch.cuda.amp.GradScaler): + if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: + if self.init: + try: + speed: float = self.frequent * self.batch_size / (time.time() - 
self.tic) + speed_total = speed * self.world_size + except ZeroDivisionError: + speed_total = float('inf') + + time_now = (time.time() - self.time_start) / 3600 + time_total = time_now / ((global_step + 1) / self.total_step) + time_for_end = time_total - time_now + if self.writer is not None: + self.writer.add_scalar('time_for_end', time_for_end, global_step) + self.writer.add_scalar('learning_rate', learning_rate, global_step) + self.writer.add_scalar('loss', loss.avg, global_step) + if fp16: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, + grad_scaler.get_scale(), time_for_end + ) + else: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ + "Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end + ) + logging.info(msg) + loss.reset() + self.tic = time.time() + else: + self.init = True + self.tic = time.time() + + +class CallBackModelCheckpoint(object): + def __init__(self, rank, output="./"): + self.rank: int = rank + self.output: str = output + + def __call__(self, global_step, backbone, partial_fc, ): + if global_step > 100 and self.rank == 0: + path_module = os.path.join(self.output, "backbone.pth") + torch.save(backbone.module.state_dict(), path_module) + logging.info("Pytorch Model Saved in '{}'".format(path_module)) + + if global_step > 100 and partial_fc is not None: + partial_fc.save_params() diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_config.py b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0c02eaf70fc0140aca7925f621c29a496f491cae --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_config.py @@ -0,0 +1,16 @@ +import importlib +import os.path as osp + + +def get_config(config_file): + assert config_file.startswith('configs/'), 'config file setting must start with configs/' + temp_config_name = osp.basename(config_file) + temp_module_name = osp.splitext(temp_config_name)[0] + config = importlib.import_module("configs.base") + cfg = config.config + config = importlib.import_module("configs.%s" % temp_module_name) + job_cfg = config.config + cfg.update(job_cfg) + if cfg.output is None: + cfg.output = osp.join('work_dirs', temp_module_name) + return cfg \ No newline at end of file diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_logging.py b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..c787b6aae7cd037a4718df44d672b8ffa9e5c249 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_logging.py @@ -0,0 +1,41 @@ +import logging +import os +import sys + + +class AverageMeter(object): + """Computes and stores the average and current value + """ + + def __init__(self): + self.val = None + self.avg = None + self.sum = None + self.count = None + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def init_logging(rank, models_root): + if rank == 0: + log_root = logging.getLogger() + log_root.setLevel(logging.INFO) + formatter = logging.Formatter("Training: 
%(asctime)s-%(message)s") + handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) + handler_stream = logging.StreamHandler(sys.stdout) + handler_file.setFormatter(formatter) + handler_stream.setFormatter(formatter) + log_root.addHandler(handler_file) + log_root.addHandler(handler_stream) + log_root.info('rank_id: %d' % rank) diff --git a/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_os.py b/chat_anything/sad_talker/face3d/models/arcface_torch/utils/utils_os.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/sad_talker/face3d/models/base_model.py b/chat_anything/sad_talker/face3d/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe64a7f739ad8f8cfbf3073a2bf49e1468127fd --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/base_model.py @@ -0,0 +1,316 @@ +"""This script defines the base network model for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch +from collections import OrderedDict +from abc import ABC, abstractmethod +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this fucntion, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): specify the images that you want to display and save. + -- self.visual_names (str list): define networks used in our training. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.isTrain = False + self.device = torch.device('cpu') + self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.parallel_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def dict_grad_hook_factory(add_func=lambda x: x): + saved_dict = dict() + + def hook_gen(name): + def grad_hook(grad): + saved_vals = add_func(grad) + saved_dict[name] = saved_vals + return grad_hook + return hook_gen, saved_dict + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. 
+ """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + + if not self.isTrain or opt.continue_train: + load_suffix = opt.epoch + self.load_networks(load_suffix) + + + # self.print_networks(opt.verbose) + + def parallelize(self, convert_sync_batchnorm=True): + if not self.opt.use_ddp: + for name in self.parallel_names: + if isinstance(name, str): + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + else: + for name in self.model_names: + if isinstance(name, str): + module = getattr(self, name) + if convert_sync_batchnorm: + module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) + setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), + device_ids=[self.device.index], + find_unused_parameters=True, broadcast_buffers=True)) + + # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. + for name in self.parallel_names: + if isinstance(name, str) and name not in self.model_names: + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + + # put state_dict of optimizer to gpu device + if self.opt.phase != 'test': + if self.opt.continue_train: + for optim in self.optimizers: + for state in optim.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(self.device) + + def data_dependent_initialize(self, data): + pass + + def train(self): + """Make models train mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train() + + def eval(self): + """Make models eval mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self, name='A'): + """ Return image paths that are used to load current data""" + return self.image_paths if name =='A' else self.image_paths_B + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. 
train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name)[:, :3, ...] + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if not os.path.isdir(self.save_dir): + os.makedirs(self.save_dir) + + save_filename = 'epoch_%s.pth' % (epoch) + save_path = os.path.join(self.save_dir, save_filename) + + save_dict = {} + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel) or isinstance(net, + torch.nn.parallel.DistributedDataParallel): + net = net.module + save_dict[name] = net.state_dict() + + + for i, optim in enumerate(self.optimizers): + save_dict['opt_%02d'%i] = optim.state_dict() + + for i, sched in enumerate(self.schedulers): + save_dict['sched_%02d'%i] = sched.state_dict() + + torch.save(save_dict, save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. 
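+        The checkpoint is the dictionary written by `save_networks`: one `state_dict`
+        per entry of `model_names`, plus 'opt_%02d' / 'sched_%02d' entries holding the
+        optimizer and scheduler states.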
+ + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if self.opt.isTrain and self.opt.pretrained_name is not None: + load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) + else: + load_dir = self.save_dir + load_filename = 'epoch_%s.pth' % (epoch) + load_path = os.path.join(load_dir, load_filename) + state_dict = torch.load(load_path, map_location=self.device) + print('loading the model from %s' % load_path) + + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + net.load_state_dict(state_dict[name]) + + if self.opt.phase != 'test': + if self.opt.continue_train: + print('loading the optim from %s' % load_path) + for i, optim in enumerate(self.optimizers): + optim.load_state_dict(state_dict['opt_%02d'%i]) + + try: + print('loading the sched from %s' % load_path) + for i, sched in enumerate(self.schedulers): + sched.load_state_dict(state_dict['sched_%02d'%i]) + except: + print('Failed to load schedulers, set schedulers according to epoch count manually') + for i, sched in enumerate(self.schedulers): + sched.last_epoch = self.opt.epoch_count - 1 + + + + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def generate_visuals_for_evaluation(self, data, mode): + return {} diff --git a/chat_anything/sad_talker/face3d/models/bfm.py b/chat_anything/sad_talker/face3d/models/bfm.py new file mode 100644 index 0000000000000000000000000000000000000000..6558d4b45a2aa5c138370566d017f1c6d27e7458 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/bfm.py @@ -0,0 +1,331 @@ +"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.io import loadmat +from chat_anything.sad_talker.face3d.util.load_mats import transferBFM09 +import os + +def perspective_projection(focal, center): + # return p.T (N, 3) @ (3, 3) + return np.array([ + focal, 0, center, + 0, focal, center, + 0, 0, 1 + ]).reshape([3, 3]).astype(np.float32).transpose() + +class SH: + def __init__(self): + self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] + self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) 
/ np.sqrt(12 * np.pi)] + + + +class ParametricFaceModel: + def __init__(self, + bfm_folder='./BFM', + recenter=True, + camera_distance=10., + init_lit=np.array([ + 0.8, 0, 0, 0, 0, 0, 0, 0, 0 + ]), + focal=1015., + center=112., + is_train=True, + default_name='BFM_model_front.mat'): + + if not os.path.isfile(os.path.join(bfm_folder, default_name)): + transferBFM09(bfm_folder) + + model = loadmat(os.path.join(bfm_folder, default_name)) + # mean face shape. [3*N,1] + self.mean_shape = model['meanshape'].astype(np.float32) + # identity basis. [3*N,80] + self.id_base = model['idBase'].astype(np.float32) + # expression basis. [3*N,64] + self.exp_base = model['exBase'].astype(np.float32) + # mean face texture. [3*N,1] (0-255) + self.mean_tex = model['meantex'].astype(np.float32) + # texture basis. [3*N,80] + self.tex_base = model['texBase'].astype(np.float32) + # face indices for each vertex that lies in. starts from 0. [N,8] + self.point_buf = model['point_buf'].astype(np.int64) - 1 + # vertex indices for each face. starts from 0. [F,3] + self.face_buf = model['tri'].astype(np.int64) - 1 + # vertex indices for 68 landmarks. starts from 0. [68,1] + self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 + + if is_train: + # vertex indices for small face region to compute photometric error. starts from 0. + self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 + # vertex indices for each face from small face region. starts from 0. [f,3] + self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 + # vertex indices for pre-defined skin region to compute reflectance loss + self.skin_mask = np.squeeze(model['skinmask']) + + if recenter: + mean_shape = self.mean_shape.reshape([-1, 3]) + mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) + self.mean_shape = mean_shape.reshape([-1, 1]) + + self.persc_proj = perspective_projection(focal, center) + self.device = 'cpu' + self.camera_distance = camera_distance + self.SH = SH() + self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) + + + def to(self, device): + self.device = device + for key, value in self.__dict__.items(): + if type(value).__module__ == np.__name__: + setattr(self, key, torch.tensor(value).to(device)) + + + def compute_shape(self, id_coeff, exp_coeff): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) + + Parameters: + id_coeff -- torch.tensor, size (B, 80), identity coeffs + exp_coeff -- torch.tensor, size (B, 64), expression coeffs + """ + batch_size = id_coeff.shape[0] + id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) + exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) + face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) + return face_shape.reshape([batch_size, -1, 3]) + + + def compute_texture(self, tex_coeff, normalize=True): + """ + Return: + face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) + + Parameters: + tex_coeff -- torch.tensor, size (B, 80) + """ + batch_size = tex_coeff.shape[0] + face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex + if normalize: + face_texture = face_texture / 255. 
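+        # BFM stores the mean/basis textures in 0-255; the division above keeps the
+        # returned texture in the (0, 1.) range stated in this docstring and assumed
+        # by compute_color.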
+ return face_texture.reshape([batch_size, -1, 3]) + + + def compute_norm(self, face_shape): + """ + Return: + vertex_norm -- torch.tensor, size (B, N, 3) + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + + v1 = face_shape[:, self.face_buf[:, 0]] + v2 = face_shape[:, self.face_buf[:, 1]] + v3 = face_shape[:, self.face_buf[:, 2]] + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = torch.cross(e1, e2, dim=-1) + face_norm = F.normalize(face_norm, dim=-1, p=2) + face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) + + vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) + vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) + return vertex_norm + + + def compute_color(self, face_texture, face_norm, gamma): + """ + Return: + face_color -- torch.tensor, size (B, N, 3), range (0, 1.) + + Parameters: + face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) + face_norm -- torch.tensor, size (B, N, 3), rotated face normal + gamma -- torch.tensor, size (B, 27), SH coeffs + """ + batch_size = gamma.shape[0] + v_num = face_texture.shape[1] + a, c = self.SH.a, self.SH.c + gamma = gamma.reshape([batch_size, 3, 9]) + gamma = gamma + self.init_lit + gamma = gamma.permute(0, 2, 1) + Y = torch.cat([ + a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), + -a[1] * c[1] * face_norm[..., 1:2], + a[1] * c[1] * face_norm[..., 2:], + -a[1] * c[1] * face_norm[..., :1], + a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], + -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], + 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1), + -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], + 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) + ], dim=-1) + r = Y @ gamma[..., :1] + g = Y @ gamma[..., 1:2] + b = Y @ gamma[..., 2:] + face_color = torch.cat([r, g, b], dim=-1) * face_texture + return face_color + + + def compute_rotation(self, angles): + """ + Return: + rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat + + Parameters: + angles -- torch.tensor, size (B, 3), radian + """ + + batch_size = angles.shape[0] + ones = torch.ones([batch_size, 1]).to(self.device) + zeros = torch.zeros([batch_size, 1]).to(self.device) + x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], + + rot_x = torch.cat([ + ones, zeros, zeros, + zeros, torch.cos(x), -torch.sin(x), + zeros, torch.sin(x), torch.cos(x) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_y = torch.cat([ + torch.cos(y), zeros, torch.sin(y), + zeros, ones, zeros, + -torch.sin(y), zeros, torch.cos(y) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_z = torch.cat([ + torch.cos(z), -torch.sin(z), zeros, + torch.sin(z), torch.cos(z), zeros, + zeros, zeros, ones + ], dim=1).reshape([batch_size, 3, 3]) + + rot = rot_z @ rot_y @ rot_x + return rot.permute(0, 2, 1) + + + def to_camera(self, face_shape): + face_shape[..., -1] = self.camera_distance - face_shape[..., -1] + return face_shape + + def to_image(self, face_shape): + """ + Return: + face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + # to image_plane + face_proj = face_shape @ self.persc_proj + face_proj = face_proj[..., :2] / face_proj[..., 2:] + + return face_proj + + + def transform(self, face_shape, rot, trans): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + rot -- 
torch.tensor, size (B, 3, 3) + trans -- torch.tensor, size (B, 3) + """ + return face_shape @ rot + trans.unsqueeze(1) + + + def get_landmarks(self, face_proj): + """ + Return: + face_lms -- torch.tensor, size (B, 68, 2) + + Parameters: + face_proj -- torch.tensor, size (B, N, 2) + """ + return face_proj[:, self.keypoints] + + def split_coeff(self, coeffs): + """ + Return: + coeffs_dict -- a dict of torch.tensors + + Parameters: + coeffs -- torch.tensor, size (B, 256) + """ + id_coeffs = coeffs[:, :80] + exp_coeffs = coeffs[:, 80: 144] + tex_coeffs = coeffs[:, 144: 224] + angles = coeffs[:, 224: 227] + gammas = coeffs[:, 227: 254] + translations = coeffs[:, 254:] + return { + 'id': id_coeffs, + 'exp': exp_coeffs, + 'tex': tex_coeffs, + 'angle': angles, + 'gamma': gammas, + 'trans': translations + } + def compute_for_render(self, coeffs): + """ + Return: + face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate + face_color -- torch.tensor, size (B, N, 3), in RGB order + landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction + Parameters: + coeffs -- torch.tensor, size (B, 257) + """ + coef_dict = self.split_coeff(coeffs) + face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) + rotation = self.compute_rotation(coef_dict['angle']) + + + face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) + face_vertex = self.to_camera(face_shape_transformed) + + face_proj = self.to_image(face_vertex) + landmark = self.get_landmarks(face_proj) + + face_texture = self.compute_texture(coef_dict['tex']) + face_norm = self.compute_norm(face_shape) + face_norm_roted = face_norm @ rotation + face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) + + return face_vertex, face_texture, face_color, landmark + + def compute_for_render_woRotation(self, coeffs): + """ + Return: + face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate + face_color -- torch.tensor, size (B, N, 3), in RGB order + landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction + Parameters: + coeffs -- torch.tensor, size (B, 257) + """ + coef_dict = self.split_coeff(coeffs) + face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) + #rotation = self.compute_rotation(coef_dict['angle']) + + + #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) + face_vertex = self.to_camera(face_shape) + + face_proj = self.to_image(face_vertex) + landmark = self.get_landmarks(face_proj) + + face_texture = self.compute_texture(coef_dict['tex']) + face_norm = self.compute_norm(face_shape) + face_norm_roted = face_norm # @ rotation + face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) + + return face_vertex, face_texture, face_color, landmark + + +if __name__ == '__main__': + transferBFM09() \ No newline at end of file diff --git a/chat_anything/sad_talker/face3d/models/facerecon_model.py b/chat_anything/sad_talker/face3d/models/facerecon_model.py new file mode 100644 index 0000000000000000000000000000000000000000..16a1d76056dbfd110bc439e8cdb5c858438f3583 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/facerecon_model.py @@ -0,0 +1,220 @@ +"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +from chat_anything.sad_talker.face3d.models.base_model import BaseModel +from chat_anything.sad_talker.face3d.models import networks +from chat_anything.sad_talker.face3d.models.bfm import 
ParametricFaceModel +from chat_anything.sad_talker.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss +from chat_anything.sad_talker.face3d.util import util +from chat_anything.sad_talker.face3d.util.nvdiffrast import MeshRenderer +# from chat_anything.sad_talker.face3d.util.preprocess import estimate_norm_torch + +import trimesh +from scipy.io import savemat + +class FaceReconModel(BaseModel): + + @staticmethod + def modify_commandline_options(parser, is_train=False): + """ Configures options specific for CUT model + """ + # net structure and parameters + parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') + parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth') + parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') + parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/') + parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') + + # renderer parameters + parser.add_argument('--focal', type=float, default=1015.) + parser.add_argument('--center', type=float, default=112.) + parser.add_argument('--camera_d', type=float, default=10.) + parser.add_argument('--z_near', type=float, default=5.) + parser.add_argument('--z_far', type=float, default=15.) + + if is_train: + # training parameters + parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') + parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') + parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') + parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') + + + # augmentation parameters + parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') + parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') + parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') + + # loss weights + parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') + parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') + parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') + parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') + parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') + parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') + parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') + parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') + parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') + + opt, _ = parser.parse_known_args() + parser.set_defaults( + focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. + ) + if is_train: + parser.set_defaults( + use_crop_face=True, use_predef_M=False + ) + return parser + + def __init__(self, opt): + """Initialize this model class. 
+ + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + + self.visual_names = ['output_vis'] + self.model_names = ['net_recon'] + self.parallel_names = self.model_names + ['renderer'] + + self.facemodel = ParametricFaceModel( + bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, + is_train=self.isTrain, default_name=opt.bfm_model + ) + + fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi + self.renderer = MeshRenderer( + rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center) + ) + + if self.isTrain: + self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] + + self.net_recog = networks.define_net_recog( + net_recog=opt.net_recog, pretrained_path=opt.net_recog_path + ) + # loss func name: (compute_%s_loss) % loss_name + self.compute_feat_loss = perceptual_loss + self.comupte_color_loss = photo_loss + self.compute_lm_loss = landmark_loss + self.compute_reg_loss = reg_loss + self.compute_reflc_loss = reflectance_loss + + self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) + self.optimizers = [self.optimizer] + self.parallel_names += ['net_recog'] + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + self.input_img = input['imgs'].to(self.device) + self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None + self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None + self.trans_m = input['M'].to(self.device) if 'M' in input else None + self.image_paths = input['im_paths'] if 'im_paths' in input else None + + def forward(self, output_coeff, device): + self.facemodel.to(device) + self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ + self.facemodel.compute_for_render(output_coeff) + self.pred_mask, _, self.pred_face = self.renderer( + self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) + + self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) + + + def compute_losses(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + + assert self.net_recog.training == False + trans_m = self.trans_m + if not self.opt.use_predef_M: + trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) + + pred_feat = self.net_recog(self.pred_face, trans_m) + gt_feat = self.net_recog(self.input_img, self.trans_m) + self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) + + face_mask = self.pred_mask + if self.opt.use_crop_face: + face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) + + face_mask = face_mask.detach() + self.loss_color = self.opt.w_color * self.comupte_color_loss( + self.pred_face, self.input_img, self.atten_mask * face_mask) + + loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) + self.loss_reg = self.opt.w_reg * loss_reg + self.loss_gamma = self.opt.w_gamma * loss_gamma + + self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) + + self.loss_reflc = 
self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask) + + self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ + + self.loss_lm + self.loss_reflc + + + def optimize_parameters(self, isTrain=True): + self.forward() + self.compute_losses() + """Update network weights; it will be called in every training iteration.""" + if isTrain: + self.optimizer.zero_grad() + self.loss_all.backward() + self.optimizer.step() + + def compute_visuals(self): + with torch.no_grad(): + input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() + output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img + output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() + + if self.gt_lm is not None: + gt_lm_numpy = self.gt_lm.cpu().numpy() + pred_lm_numpy = self.pred_lm.detach().cpu().numpy() + output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') + output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') + + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw, output_vis_numpy), axis=-2) + else: + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw), axis=-2) + + self.output_vis = torch.tensor( + output_vis_numpy / 255., dtype=torch.float32 + ).permute(0, 3, 1, 2).to(self.device) + + def save_mesh(self, name): + + recon_shape = self.pred_vertex # get reconstructed shape + recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space + recon_shape = recon_shape.cpu().numpy()[0] + recon_color = self.pred_color + recon_color = recon_color.cpu().numpy()[0] + tri = self.facemodel.face_buf.cpu().numpy() + mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. 
* recon_color, 0, 255).astype(np.uint8)) + mesh.export(name) + + def save_coeff(self,name): + + pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} + pred_lm = self.pred_lm.cpu().numpy() + pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate + pred_coeffs['lm68'] = pred_lm + savemat(name,pred_coeffs) + + + diff --git a/chat_anything/sad_talker/face3d/models/losses.py b/chat_anything/sad_talker/face3d/models/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..09d6a85870af1ef2b857e4a3fdd4b2f7fc991317 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/losses.py @@ -0,0 +1,113 @@ +import numpy as np +import torch +import torch.nn as nn +from kornia.geometry import warp_affine +import torch.nn.functional as F + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) + +### perceptual level loss +class PerceptualLoss(nn.Module): + def __init__(self, recog_net, input_size=112): + super(PerceptualLoss, self).__init__() + self.recog_net = recog_net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + def forward(imageA, imageB, M): + """ + 1 - cosine distance + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order + imageB --same as imageA + """ + + imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) + imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) + + # freeze bn + self.recog_net.eval() + + id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) + id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +def perceptual_loss(id_featureA, id_featureB): + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +### image level loss +def photo_loss(imageA, imageB, mask, eps=1e-6): + """ + l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur) + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order + imageB --same as imageA + """ + loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask + loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) + return loss + +def landmark_loss(predict_lm, gt_lm, weight=None): + """ + weighted mse loss + Parameters: + predict_lm --torch.tensor (B, 68, 2) + gt_lm --torch.tensor (B, 68, 2) + weight --numpy.array (1, 68) + """ + if not weight: + weight = np.ones([68]) + weight[28:31] = 20 + weight[-8:] = 20 + weight = np.expand_dims(weight, 0) + weight = torch.tensor(weight).to(predict_lm.device) + loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight + loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) + return loss + + +### regulization +def reg_loss(coeffs_dict, opt=None): + """ + l2 norm without the sqrt, from yu's implementation (mse) + tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss + Parameters: + coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans + + """ + # coefficient regularization to ensure plausible 3d faces + if opt: + w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex + else: + w_id, 
w_exp, w_tex = 1, 1, 1, 1 + creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ + w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ + w_tex * torch.sum(coeffs_dict['tex'] ** 2) + creg_loss = creg_loss / coeffs_dict['id'].shape[0] + + # gamma regularization to ensure a nearly-monochromatic light + gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) + gamma_mean = torch.mean(gamma, dim=1, keepdims=True) + gamma_loss = torch.mean((gamma - gamma_mean) ** 2) + + return creg_loss, gamma_loss + +def reflectance_loss(texture, mask): + """ + minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo + Parameters: + texture --torch.tensor, (B, N, 3) + mask --torch.tensor, (N), 1 or 0 + + """ + mask = mask.reshape([1, mask.shape[0], 1]) + texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) + loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) + return loss + diff --git a/chat_anything/sad_talker/face3d/models/networks.py b/chat_anything/sad_talker/face3d/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..ead9cdcb8720b845c233de79dc8a8d1668492108 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/networks.py @@ -0,0 +1,521 @@ +"""This script defines deep neural networks for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch.nn.functional as F +from torch.nn import init +import functools +from torch.optim import lr_scheduler +import torch +from torch import Tensor +import torch.nn as nn +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional +from .arcface_torch.backbones import get_model +from kornia.geometry import warp_affine + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. 
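The `linear` policy described here keeps the learning rate flat for roughly the first `opt.n_epochs` epochs and then shrinks it by `1/(n_epochs + 1)` per epoch, as the `lambda_rule` in the code below implements. A minimal, self-contained sketch of that behaviour, assuming illustrative values `n_epochs=20` and `epoch_count=1` (stand-ins for this walkthrough, not values taken from this repository's configs):
```
from types import SimpleNamespace

import torch
from torch.optim import lr_scheduler

# assumed illustrative options; the real values come from TrainOptions
opt = SimpleNamespace(lr_policy='linear', n_epochs=20, epoch_count=1)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)

def lambda_rule(epoch):
    # same rule as the 'linear' branch of get_scheduler:
    # 1.0 until n_epochs, then a linear decay of 1/(n_epochs + 1) per epoch
    return 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)

scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

for epoch in range(30):
    optimizer.step()
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])  # flat at 1e-4, then decreasing
```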
+ """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def define_net_recon(net_recon, use_last_fc=False, init_path=None): + return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) + +def define_net_recog(net_recog, pretrained_path=None): + net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) + net.eval() + return net + +class ReconNetWrapper(nn.Module): + fc_dim=257 + def __init__(self, net_recon, use_last_fc=False, init_path=None): + super(ReconNetWrapper, self).__init__() + self.use_last_fc = use_last_fc + if net_recon not in func_dict: + return NotImplementedError('network [%s] is not implemented', net_recon) + func, last_dim = func_dict[net_recon] + backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) + if init_path and os.path.isfile(init_path): + state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) + backbone.load_state_dict(state_dict) + print("loading init net_recon %s from %s" %(net_recon, init_path)) + self.backbone = backbone + if not use_last_fc: + self.final_layers = nn.ModuleList([ + conv1x1(last_dim, 80, bias=True), # id layer + conv1x1(last_dim, 64, bias=True), # exp layer + conv1x1(last_dim, 80, bias=True), # tex layer + conv1x1(last_dim, 3, bias=True), # angle layer + conv1x1(last_dim, 27, bias=True), # gamma layer + conv1x1(last_dim, 2, bias=True), # tx, ty + conv1x1(last_dim, 1, bias=True) # tz + ]) + for m in self.final_layers: + nn.init.constant_(m.weight, 0.) + nn.init.constant_(m.bias, 0.) 
+ + def forward(self, x): + x = self.backbone(x) + if not self.use_last_fc: + output = [] + for layer in self.final_layers: + output.append(layer(x)) + x = torch.flatten(torch.cat(output, dim=1), 1) + return x + + +class RecogNetWrapper(nn.Module): + def __init__(self, net_recog, pretrained_path=None, input_size=112): + super(RecogNetWrapper, self).__init__() + net = get_model(name=net_recog, fp16=False) + if pretrained_path: + state_dict = torch.load(pretrained_path, map_location='cpu') + net.load_state_dict(state_dict) + print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) + for param in net.parameters(): + param.requires_grad = False + self.net = net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + + def forward(self, image, M): + image = self.preprocess(resize_n_crop(image, M, self.input_size)) + id_feature = F.normalize(self.net(image), dim=-1, p=2) + return id_feature + + +# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = 
self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + use_last_fc: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.use_last_fc = use_last_fc + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = 
nn.AdaptiveAvgPool2d((1, 1)) + + if self.use_last_fc: + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + if self.use_last_fc: + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +func_dict = { + 'resnet18': (resnet18, 512), + 'resnet50': (resnet50, 2048) +} diff --git a/chat_anything/sad_talker/face3d/models/template_model.py b/chat_anything/sad_talker/face3d/models/template_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dac7b33d5889777eb63c9882a3b9fa094dcab293 --- /dev/null +++ b/chat_anything/sad_talker/face3d/models/template_model.py @@ -0,0 +1,100 @@ +"""Model class template + +This module provides a template for users to implement custom models. +You can specify '--model template' to use this model. +The class name should be consistent with both the filename and its model option. +The filename should be _dataset.py +The class name should be Dataset.py +It implements a simple image-to-image translation baseline based on regression loss. +Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: + min_ ||netG(data_A) - data_B||_1 +You need to implement the following functions: + : Add model-specific options and rewrite default values for existing options. + <__init__>: Initialize this model class. + : Unpack input data and perform data pre-processing. + : Run forward pass. This will be called by both and . + : Update network weights; it will be called in every training iteration. +""" +import numpy as np +import torch +from .base_model import BaseModel +from . import networks + + +class TemplateModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new model-specific options and rewrite default values for existing options. + + Parameters: + parser -- the option parser + is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. + if is_train: + parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. + + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. + self.loss_names = ['loss_G'] + # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. + self.visual_names = ['data_A', 'data_B', 'output'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. + # you can use opt.isTrain to specify different behaviors for training and test. 
For example, some networks will not be used during test, and you don't need to load them. + self.model_names = ['G'] + # define networks; you can use opt.isTrain to specify different behaviors for training and test. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + if self.isTrain: # only defined during training time + # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. + # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) + self.criterionLoss = torch.nn.L1Loss() + # define and initialize optimizers. You can define one optimizer for each network. + # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer] + + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B + self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A + self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B + self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths + + def forward(self): + """Run forward pass. This will be called by both functions and .""" + self.output = self.netG(self.data_A) # generate output image given the input data_A + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + # caculate the intermediate results if necessary; here self.output has been computed during function + # calculate loss given the input and intermediate results + self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression + self.loss_G.backward() # calculate gradients of network G w.r.t. 
loss_G + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.forward() # first call forward to calculate intermediate results + self.optimizer.zero_grad() # clear network G's existing gradients + self.backward() # calculate gradients for network G + self.optimizer.step() # update gradients for network G diff --git a/chat_anything/sad_talker/face3d/options/__init__.py b/chat_anything/sad_talker/face3d/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90 --- /dev/null +++ b/chat_anything/sad_talker/face3d/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/chat_anything/sad_talker/face3d/options/base_options.py b/chat_anything/sad_talker/face3d/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f921d5a43434ae802a55a0fa3889c4b7ab9f6d --- /dev/null +++ b/chat_anything/sad_talker/face3d/options/base_options.py @@ -0,0 +1,169 @@ +"""This script contains base options for Deep3DFaceRecon_pytorch +""" + +import argparse +import os +from util import util +import numpy as np +import torch +import face3d.models as models +import face3d.data as data + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self, cmd_line=None): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + self.cmd_line = None + if cmd_line is not None: + self.cmd_line = cmd_line.split() + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') + parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') + parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') + parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') + parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') + parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') + parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') + + # model parameters + parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') + + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + if self.cmd_line is None: + opt, _ = parser.parse_known_args() + else: + opt, _ = parser.parse_known_args(self.cmd_line) + + # set cuda visible devices + os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + if self.cmd_line is None: + opt, _ = parser.parse_known_args() # parse again with new defaults + else: + opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults + + # modify dataset-related parser options + if opt.dataset_mode: + dataset_name = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + if self.cmd_line is None: + return parser.parse_args() + else: + return parser.parse_args(self.cmd_line) + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). 
+ It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + try: + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + except PermissionError as error: + print("permission error {}".format(error)) + pass + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + gpu_ids.append(id) + opt.world_size = len(gpu_ids) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(gpu_ids[0]) + if opt.world_size == 1: + opt.use_ddp = False + + if opt.phase != 'test': + # set continue_train automatically + if opt.pretrained_name is None: + model_dir = os.path.join(opt.checkpoints_dir, opt.name) + else: + model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) + if os.path.isdir(model_dir): + model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] + if os.path.isdir(model_dir) and len(model_pths) != 0: + opt.continue_train= True + + # update the latest epoch count + if opt.continue_train: + if opt.epoch == 'latest': + epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] + if len(epoch_counts) != 0: + opt.epoch_count = max(epoch_counts) + 1 + else: + opt.epoch_count = int(opt.epoch) + 1 + + + self.print_options(opt) + self.opt = opt + return self.opt diff --git a/chat_anything/sad_talker/face3d/options/inference_options.py b/chat_anything/sad_talker/face3d/options/inference_options.py new file mode 100644 index 0000000000000000000000000000000000000000..c453965959ab4cfb31acbc424f994db68c3d4df5 --- /dev/null +++ b/chat_anything/sad_talker/face3d/options/inference_options.py @@ -0,0 +1,23 @@ +from face3d.options.base_options import BaseOptions + + +class InferenceOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. 
[None | flist]') + + parser.add_argument('--input_dir', type=str, help='the folder of the input files') + parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') + parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') + parser.add_argument('--save_split_files', action='store_true', help='save split files or not') + parser.add_argument('--inference_batch_size', type=int, default=8) + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/chat_anything/sad_talker/face3d/options/test_options.py b/chat_anything/sad_talker/face3d/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff3ad142779850d1d5a1640bc00f70d34d4a862 --- /dev/null +++ b/chat_anything/sad_talker/face3d/options/test_options.py @@ -0,0 +1,21 @@ +"""This script contains the test options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/chat_anything/sad_talker/face3d/options/train_options.py b/chat_anything/sad_talker/face3d/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..1337bfdd5f372b5c686a91b394a2aadbe5741f44 --- /dev/null +++ b/chat_anything/sad_talker/face3d/options/train_options.py @@ -0,0 +1,53 @@ +"""This script contains the training options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions +from util import util + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # dataset parameters + # for train + parser.add_argument('--data_root', type=str, default='./', help='dataset root') + parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') + parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') + + # for val + parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') + parser.add_argument('--batch_size_val', type=int, default=32) + + + # visualization parameters + parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + + # network saving and loading parameters + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') + + # training parameters + parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') + parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') + parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') + + self.isTrain = True + return parser diff --git a/chat_anything/sad_talker/face3d/util/BBRegressorParam_r.mat b/chat_anything/sad_talker/face3d/util/BBRegressorParam_r.mat new file mode 100644 index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084 Binary files /dev/null and b/chat_anything/sad_talker/face3d/util/BBRegressorParam_r.mat differ diff --git a/chat_anything/sad_talker/face3d/util/__init__.py b/chat_anything/sad_talker/face3d/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e507b6cfb3351898050514a279f62e37f92a0830 --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/__init__.py @@ -0,0 +1,3 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" +from chat_anything.sad_talker.face3d.util import * + diff --git a/chat_anything/sad_talker/face3d/util/detect_lm68.py b/chat_anything/sad_talker/face3d/util/detect_lm68.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e40997289e17405e1fb6c408d21adce7b626ce --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/detect_lm68.py @@ -0,0 +1,106 @@ +import os +import cv2 +import numpy as np +from scipy.io import loadmat +import tensorflow as tf +from util.preprocess import align_for_lm +from shutil import move + +mean_face = np.loadtxt('util/test_mean_face.txt') +mean_face = mean_face.reshape([68, 2]) + +def save_label(labels, save_path): + np.savetxt(save_path, labels) + +def draw_landmarks(img, landmark, save_name): + landmark = landmark + lm_img = np.zeros([img.shape[0], img.shape[1], 3]) + lm_img[:] = img.astype(np.float32) + landmark = np.round(landmark).astype(np.int32) + + for i in range(len(landmark)): + for j in range(-1, 1): + for k in range(-1, 1): + if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ + img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ + landmark[i, 0]+k > 0 and \ + landmark[i, 0]+k < img.shape[1]: + lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, + :] = np.array([0, 0, 255]) + lm_img = lm_img.astype(np.uint8) + + cv2.imwrite(save_name, lm_img) + + +def load_data(img_name, txt_name): + return cv2.imread(img_name), np.loadtxt(txt_name) + +# create tensorflow graph for landmark detector +def load_lm_graph(graph_filename): + with tf.gfile.GFile(graph_filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='net') + img_224 = graph.get_tensor_by_name('net/input_imgs:0') + output_lm = graph.get_tensor_by_name('net/lm:0') + lm_sess = tf.Session(graph=graph) + + return lm_sess,img_224,output_lm + +# landmark detection +def detect_68p(img_path,sess,input_op,output_op): + print('detecting landmarks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + vis_path = os.path.join(img_path, 'vis') + remove_path = os.path.join(img_path, 'remove') + save_path = os.path.join(img_path, 'landmarks') + if not os.path.isdir(vis_path): + os.makedirs(vis_path) + if not os.path.isdir(remove_path): + os.makedirs(remove_path) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + txt_name = 
'.'.join(name.split('.')[:-1]) + '.txt' + full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image + + # if an image does not have detected 5 facial landmarks, remove it from the training list + if not os.path.isfile(full_txt_name): + move(full_image_name, os.path.join(remove_path, name)) + continue + + # load data + img, five_points = load_data(full_image_name, full_txt_name) + input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection + + # if the alignment fails, remove corresponding image from the training list + if scale == 0: + move(full_txt_name, os.path.join( + remove_path, txt_name)) + move(full_image_name, os.path.join(remove_path, name)) + continue + + # detect landmarks + input_img = np.reshape( + input_img, [1, 224, 224, 3]).astype(np.float32) + landmark = sess.run( + output_op, feed_dict={input_op: input_img}) + + # transform back to original image coordinate + landmark = landmark.reshape([68, 2]) + mean_face + landmark[:, 1] = 223 - landmark[:, 1] + landmark = landmark / scale + landmark[:, 0] = landmark[:, 0] + bbox[0] + landmark[:, 1] = landmark[:, 1] + bbox[1] + landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] + + if i % 100 == 0: + draw_landmarks(img, landmark, os.path.join(vis_path, name)) + save_label(landmark, os.path.join(save_path, txt_name)) diff --git a/chat_anything/sad_talker/face3d/util/generate_list.py b/chat_anything/sad_talker/face3d/util/generate_list.py new file mode 100644 index 0000000000000000000000000000000000000000..943d906781063c3584a7e5b5c784f8aac0694985 --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/generate_list.py @@ -0,0 +1,34 @@ +"""This script is to generate training list files for Deep3DFaceRecon_pytorch +""" + +import os + +# save path to training data +def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): + save_path = os.path.join(save_folder, mode) + if not os.path.isdir(save_path): + os.makedirs(save_path) + with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in lms_list]) + + with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in imgs_list]) + + with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in msks_list]) + +# check if the path is valid +def check_list(rlms_list, rimgs_list, rmsks_list): + lms_list, imgs_list, msks_list = [], [], [] + for i in range(len(rlms_list)): + flag = 'false' + lm_path = rlms_list[i] + im_path = rimgs_list[i] + msk_path = rmsks_list[i] + if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): + flag = 'true' + lms_list.append(rlms_list[i]) + imgs_list.append(rimgs_list[i]) + msks_list.append(rmsks_list[i]) + print(i, rlms_list[i], flag) + return lms_list, imgs_list, msks_list diff --git a/chat_anything/sad_talker/face3d/util/html.py b/chat_anything/sad_talker/face3d/util/html.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3262a1eafda34842e4dbad47bb6ba72f0c5a68 --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/html.py @@ -0,0 +1,86 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. 
+ + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). + It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links, width=400): + """add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + """save the current content to the HMTL file""" + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here. 
+ html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/chat_anything/sad_talker/face3d/util/load_mats.py b/chat_anything/sad_talker/face3d/util/load_mats.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a6fcc71de1d7dad8b0f81c67dc1c213764ff0b --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/load_mats.py @@ -0,0 +1,120 @@ +"""This script is to load 3D face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from PIL import Image +from scipy.io import loadmat, savemat +from array import array +import os.path as osp + +# load expression basis +def LoadExpBasis(bfm_folder='BFM'): + n_vertex = 53215 + Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') + exp_dim = array('i') + exp_dim.fromfile(Expbin, 1) + expMU = array('f') + expPC = array('f') + expMU.fromfile(Expbin, 3*n_vertex) + expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) + Expbin.close() + + expPC = np.array(expPC) + expPC = np.reshape(expPC, [exp_dim[0], -1]) + expPC = np.transpose(expPC) + + expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) + + return expPC, expEV + + +# transfer original BFM09 to our face model +def transferBFM09(bfm_folder='BFM'): + print('Transfer BFM09 to BFM_model_front......') + original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) + shapePC = original_BFM['shapePC'] # shape basis + shapeEV = original_BFM['shapeEV'] # corresponding eigen value + shapeMU = original_BFM['shapeMU'] # mean face + texPC = original_BFM['texPC'] # texture basis + texEV = original_BFM['texEV'] # eigen value + texMU = original_BFM['texMU'] # mean texture + + expPC, expEV = LoadExpBasis(bfm_folder) + + # transfer BFM09 to our face model + + idBase = shapePC*np.reshape(shapeEV, [-1, 199]) + idBase = idBase/1e5 # unify the scale to decimeter + idBase = idBase[:, :80] # use only first 80 basis + + exBase = expPC*np.reshape(expEV, [-1, 79]) + exBase = exBase/1e5 # unify the scale to decimeter + exBase = exBase[:, :64] # use only first 64 basis + + texBase = texPC*np.reshape(texEV, [-1, 199]) + texBase = texBase[:, :80] # use only first 80 basis + + # our face model is cropped along face landmarks and contains only 35709 vertex. + # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. + # thus we select corresponding vertex to get our face model. 
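The double indexing below (`index_shape = index_shape[index_exp]`) is the subtle step: `index_shape` maps the 53215-vertex expression topology into the 53490-vertex BFM09 mesh, while `index_exp` selects the 35709 front-face vertices out of the expression topology, so composing them yields front-face vertices expressed as indices into BFM09. A toy sketch with made-up, much smaller vertex counts:
```
import numpy as np

# toy stand-ins: 6 "BFM09" vertices, 4 kept by the expression topology,
# 2 of those kept for the front face (real counts: 53490 -> 53215 -> 35709)
index_shape = np.array([0, 2, 3, 5])  # expression-topology vertex -> BFM09 vertex
index_exp = np.array([1, 3])          # front-face vertex -> expression-topology vertex

composed = index_shape[index_exp]     # front-face vertex -> BFM09 vertex
print(composed)                       # [2 5]

# bases living on BFM09 vertices are gathered with `composed`, while the
# expression basis (already on the expression topology) uses `index_exp`;
# both end up on the same front-face vertex set
shape_basis = np.arange(6 * 3).reshape(6, 3)
exp_basis = np.arange(4 * 3).reshape(4, 3)
print(shape_basis[composed].shape, exp_basis[index_exp].shape)  # (2, 3) (2, 3)
```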
+ + index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) + index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) + + index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) + index_shape = index_shape['trimIndex'].astype( + np.int32) - 1 # starts from 0 (to 53490) + index_shape = index_shape[index_exp] + + idBase = np.reshape(idBase, [-1, 3, 80]) + idBase = idBase[index_shape, :, :] + idBase = np.reshape(idBase, [-1, 80]) + + texBase = np.reshape(texBase, [-1, 3, 80]) + texBase = texBase[index_shape, :, :] + texBase = np.reshape(texBase, [-1, 80]) + + exBase = np.reshape(exBase, [-1, 3, 64]) + exBase = exBase[index_exp, :, :] + exBase = np.reshape(exBase, [-1, 64]) + + meanshape = np.reshape(shapeMU, [-1, 3])/1e5 + meanshape = meanshape[index_shape, :] + meanshape = np.reshape(meanshape, [1, -1]) + + meantex = np.reshape(texMU, [-1, 3]) + meantex = meantex[index_shape, :] + meantex = np.reshape(meantex, [1, -1]) + + # other info contains triangles, region used for computing photometric loss, + # region used for skin texture regularization, and 68 landmarks index etc. + other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) + frontmask2_idx = other_info['frontmask2_idx'] + skinmask = other_info['skinmask'] + keypoints = other_info['keypoints'] + point_buf = other_info['point_buf'] + tri = other_info['tri'] + tri_mask2 = other_info['tri_mask2'] + + # save our face model + savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, + 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) + + +# load landmarks for standard face, which is used for image preprocessing +def load_lm3d(bfm_folder): + + Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) + Lm3D = Lm3D['lm'] + + # calculate 5 facial landmarks using 68 landmarks + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( + Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) + Lm3D = Lm3D[[1, 2, 0, 3, 4], :] + + return Lm3D + + +if __name__ == '__main__': + transferBFM09() \ No newline at end of file diff --git a/chat_anything/sad_talker/face3d/util/my_awing_arch.py b/chat_anything/sad_talker/face3d/util/my_awing_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..cfab6236b7080a1b00259a40a09baa06ce3c3110 --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/my_awing_arch.py @@ -0,0 +1,378 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def calculate_points(heatmaps): + # change heatmaps to landmarks + B, N, H, W = heatmaps.shape + HW = H * W + BN_range = np.arange(B * N) + + heatline = heatmaps.reshape(B, N, HW) + indexes = np.argmax(heatline, axis=2) + + preds = np.stack((indexes % W, indexes // W), axis=2) + preds = preds.astype(np.float32, copy=False) + + inr = indexes.ravel() + + heatline = heatline.reshape(B * N, HW) + x_up = heatline[BN_range, inr + 1] + x_down = heatline[BN_range, inr - 1] + # y_up = heatline[BN_range, inr + W] + + if any((inr + W) >= 4096): + y_up = heatline[BN_range, 4095] + else: + y_up = heatline[BN_range, inr + W] + if any((inr - W) <= 0): + y_down = heatline[BN_range, 0] + else: + y_down = heatline[BN_range, inr - W] + + think_diff = np.sign(np.stack((x_up - x_down, y_up - 
y_down), axis=1)) + think_diff *= .25 + + preds += think_diff.reshape(B, N, 2) + preds += .5 + return preds + + +class AddCoordsTh(nn.Module): + + def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False): + super(AddCoordsTh, self).__init__() + self.x_dim = x_dim + self.y_dim = y_dim + self.with_r = with_r + self.with_boundary = with_boundary + + def forward(self, input_tensor, heatmap=None): + """ + input_tensor: (batch, c, x_dim, y_dim) + """ + batch_size_tensor = input_tensor.shape[0] + + xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device) + xx_ones = xx_ones.unsqueeze(-1) + + xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) + xx_range = xx_range.unsqueeze(1) + + xx_channel = torch.matmul(xx_ones.float(), xx_range.float()) + xx_channel = xx_channel.unsqueeze(-1) + + yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device) + yy_ones = yy_ones.unsqueeze(1) + + yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) + yy_range = yy_range.unsqueeze(-1) + + yy_channel = torch.matmul(yy_range.float(), yy_ones.float()) + yy_channel = yy_channel.unsqueeze(-1) + + xx_channel = xx_channel.permute(0, 3, 2, 1) + yy_channel = yy_channel.permute(0, 3, 2, 1) + + xx_channel = xx_channel / (self.x_dim - 1) + yy_channel = yy_channel / (self.y_dim - 1) + + xx_channel = xx_channel * 2 - 1 + yy_channel = yy_channel * 2 - 1 + + xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1) + yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1) + + if self.with_boundary and heatmap is not None: + boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0) + + zero_tensor = torch.zeros_like(xx_channel) + xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor) + yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor) + if self.with_boundary and heatmap is not None: + xx_boundary_channel = xx_boundary_channel.to(input_tensor.device) + yy_boundary_channel = yy_boundary_channel.to(input_tensor.device) + ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2)) + rr = rr / torch.max(rr) + ret = torch.cat([ret, rr], dim=1) + + if self.with_boundary and heatmap is not None: + ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1) + return ret + + +class CoordConvTh(nn.Module): + """CoordConv layer as in the paper.""" + + def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs): + super(CoordConvTh, self).__init__() + self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary) + in_channels += 2 + if with_r: + in_channels += 1 + if with_boundary and not first_one: + in_channels += 2 + self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs) + + def forward(self, input_tensor, heatmap=None): + ret = self.addcoords(input_tensor, heatmap) + last_channel = ret[:, -2:, :, :] + ret = self.conv(ret) + return ret, last_channel + + +def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1): + '3x3 convolution with padding' + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, 
self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + # self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + # self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + + out = self.conv2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ConvBlock(nn.Module): + + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.BatchNorm2d(in_planes), + nn.ReLU(True), + nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + +class HourGlass(nn.Module): + + def __init__(self, num_modules, depth, num_features, first_one=False): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + self.coordconv = CoordConvTh( + x_dim=64, + y_dim=64, + with_r=True, + with_boundary=True, + in_channels=256, + first_one=first_one, + out_channels=256, + kernel_size=1, + stride=1, + padding=0) + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_module('b1_' + str(level), ConvBlock(256, 256)) + + self.add_module('b2_' + str(level), ConvBlock(256, 256)) + + if level > 1: + self._generate_network(level - 1) + else: + self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) + + self.add_module('b3_' + str(level), ConvBlock(256, 256)) + + def _forward(self, level, inp): + # Upper branch + up1 = inp + up1 = self._modules['b1_' + str(level)](up1) + + # Lower branch + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._modules['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._modules['b2_plus_' + str(level)](low2) + + low3 = low2 + low3 = self._modules['b3_' + str(level)](low3) + + up2 = F.interpolate(low3, scale_factor=2, mode='nearest') + + return up1 + up2 + + def forward(self, x, heatmap): + x, last_channel = self.coordconv(x, heatmap) + return self._forward(self.depth, x), last_channel + + +class FAN(nn.Module): + + def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'): + super(FAN, self).__init__() + self.device = device + self.num_modules = num_modules + self.gray_scale = gray_scale + self.end_relu = end_relu + self.num_landmarks = num_landmarks + + # Base part + if self.gray_scale: + self.conv1 = CoordConvTh( + x_dim=256, + y_dim=256, + with_r=True, + with_boundary=False, + 
in_channels=3, + out_channels=64, + kernel_size=7, + stride=2, + padding=3) + else: + self.conv1 = CoordConvTh( + x_dim=256, + y_dim=256, + with_r=True, + with_boundary=False, + in_channels=3, + out_channels=64, + kernel_size=7, + stride=2, + padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + # Stacking part + for hg_module in range(self.num_modules): + if hg_module == 0: + first_one = True + else: + first_one = False + self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, first_one)) + self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) + self.add_module('conv_last' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + self.add_module('l' + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module('bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('al' + str(hg_module), + nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0)) + + def forward(self, x): + x, _ = self.conv1(x) + x = F.relu(self.bn1(x), True) + # x = F.relu(self.bn1(self.conv1(x)), True) + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + boundary_channels = [] + tmp_out = None + for i in range(self.num_modules): + hg, boundary_channel = self._modules['m' + str(i)](previous, tmp_out) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu(self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + if self.end_relu: + tmp_out = F.relu(tmp_out) # HACK: Added relu + outputs.append(tmp_out) + boundary_channels.append(boundary_channel) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + return outputs, boundary_channels + + def get_landmarks(self, img): + H, W, _ = img.shape + offset = W / 64, H / 64, 0, 0 + + img = cv2.resize(img, (256, 256)) + inp = img[..., ::-1] + inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float() + inp = inp.to(self.device) + inp.div_(255.0).unsqueeze_(0) + + outputs, _ = self.forward(inp) + out = outputs[-1][:, :-1, :, :] + heatmaps = out.detach().cpu().numpy() + + pred = calculate_points(heatmaps).reshape(-1, 2) + + pred *= offset[:2] + pred += offset[-2:] + + return pred diff --git a/chat_anything/sad_talker/face3d/util/nvdiffrast.py b/chat_anything/sad_talker/face3d/util/nvdiffrast.py new file mode 100644 index 0000000000000000000000000000000000000000..f3245859c650afbfe841a66b74cddefaf28820d9 --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/nvdiffrast.py @@ -0,0 +1,126 @@ +"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch + Attention, antialiasing step is missing in current version. 
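+ Note: despite the file name, this version rasterizes with PyTorch3D
+ (FoVPerspectiveCameras + MeshRasterizer) rather than nvdiffrast.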
+""" +import pytorch3d.ops +import torch +import torch.nn.functional as F +import kornia +from kornia.geometry.camera import pixel2cam +import numpy as np +from typing import List +from scipy.io import loadmat +from torch import nn + +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + DirectionalLights, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, +) + +# def ndc_projection(x=0.1, n=1.0, f=50.0): +# return np.array([[n/x, 0, 0, 0], +# [ 0, n/-x, 0, 0], +# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], +# [ 0, 0, -1, 0]]).astype(np.float32) + +class MeshRenderer(nn.Module): + def __init__(self, + rasterize_fov, + znear=0.1, + zfar=10, + rasterize_size=224): + super(MeshRenderer, self).__init__() + + # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear + # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( + # torch.diag(torch.tensor([1., -1, -1, 1]))) + self.rasterize_size = rasterize_size + self.fov = rasterize_fov + self.znear = znear + self.zfar = zfar + + self.rasterizer = None + + def forward(self, vertex, tri, feat=None): + """ + Return: + mask -- torch.tensor, size (B, 1, H, W) + depth -- torch.tensor, size (B, 1, H, W) + features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None + + Parameters: + vertex -- torch.tensor, size (B, N, 3) + tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles + feat(optional) -- torch.tensor, size (B, N ,C), features + """ + device = vertex.device + rsize = int(self.rasterize_size) + # ndc_proj = self.ndc_proj.to(device) + # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v + if vertex.shape[-1] == 3: + vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) + vertex[..., 0] = -vertex[..., 0] + + + # vertex_ndc = vertex @ ndc_proj.t() + if self.rasterizer is None: + self.rasterizer = MeshRasterizer() + print("create rasterizer on device cuda:%d"%device.index) + + # ranges = None + # if isinstance(tri, List) or len(tri.shape) == 3: + # vum = vertex_ndc.shape[1] + # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) + # fstartidx = torch.cumsum(fnum, dim=0) - fnum + # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() + # for i in range(tri.shape[0]): + # tri[i] = tri[i] + i*vum + # vertex_ndc = torch.cat(vertex_ndc, dim=0) + # tri = torch.cat(tri, dim=0) + + # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] + tri = tri.type(torch.int32).contiguous() + + # rasterize + cameras = FoVPerspectiveCameras( + device=device, + fov=self.fov, + znear=self.znear, + zfar=self.zfar, + ) + + raster_settings = RasterizationSettings( + image_size=rsize + ) + + # print(vertex.shape, tri.shape) + mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) + + fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) + rast_out = fragments.pix_to_face.squeeze(-1) + depth = fragments.zbuf + + # render depth + depth = depth.permute(0, 3, 1, 2) + mask = (rast_out > 0).float().unsqueeze(1) + depth = mask * depth + + + image = None + if feat is not None: + attributes = feat.reshape(-1,3)[mesh.faces_packed()] + image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, + fragments.bary_coords, + attributes) + # print(image.shape) + image = image.squeeze(-2).permute(0, 3, 1, 2) + image = mask * image + + 
return mask, depth, image + diff --git a/chat_anything/sad_talker/face3d/util/preprocess.py b/chat_anything/sad_talker/face3d/util/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..470e240bdb19d4e05d8c744d44e8ce0585b5ff9b --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/preprocess.py @@ -0,0 +1,103 @@ +"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from scipy.io import loadmat +from PIL import Image +import cv2 +import os +from skimage import transform as trans +import torch +import warnings +warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +# calculating least square problem for image alignment +def POS(xp, x): + npts = xp.shape[1] + + A = np.zeros([2*npts, 8]) + + A[0:2*npts-1:2, 0:3] = x.transpose() + A[0:2*npts-1:2, 3] = 1 + + A[1:2*npts:2, 4:7] = x.transpose() + A[1:2*npts:2, 7] = 1 + + b = np.reshape(xp.transpose(), [2*npts, 1]) + + k, _, _, _ = np.linalg.lstsq(A, b) + + R1 = k[0:3] + R2 = k[4:7] + sTx = k[3] + sTy = k[7] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 + t = np.concatenate([sTx, sTy], axis=0) # bug + + return t, s + +# resize and crop images for face reconstruction +def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): + w0, h0 = img.size + w = (w0*s).astype(np.int32) + h = (h0*s).astype(np.int32) + left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) + right = left + target_size + up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) + below = up + target_size + + img = img.resize((w, h), resample=Image.BICUBIC) + img = img.crop((left, up, right, below)) + + if mask is not None: + mask = mask.resize((w, h), resample=Image.BICUBIC) + mask = mask.crop((left, up, right, below)) + + lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - + t[1] + h0/2], axis=1)*s + lm = lm - np.reshape( + np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) + + return img, lm, mask + +# utils for face reconstruction +def extract_5p(lm): + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( + lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) + lm5p = lm5p[[1, 2, 0, 3, 4], :] + return lm5p + +# utils for face reconstruction +def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): + """ + Return: + transparams --numpy.array (raw_W, raw_H, scale, tx, ty) + img_new --PIL.Image (target_size, target_size, 3) + lm_new --numpy.array (68, 2), y direction is opposite to v direction + mask_new --PIL.Image (target_size, target_size) + + Parameters: + img --PIL.Image (raw_H, raw_W, 3) + lm --numpy.array (68, 2), y direction is opposite to v direction + lm3D --numpy.array (5, 3) + mask --PIL.Image (raw_H, raw_W, 3) + """ + + w0, h0 = img.size + if lm.shape[0] != 5: + lm5p = extract_5p(lm) + else: + lm5p = lm + + # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face + t, s = POS(lm5p.transpose(), lm3D.transpose()) + s = rescale_factor/s + + # processing the image + img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) + trans_params = np.array([w0, h0, s, t[0], t[1]]) + + return trans_params, img_new, lm_new, mask_new diff --git a/chat_anything/sad_talker/face3d/util/skin_mask.py b/chat_anything/sad_talker/face3d/util/skin_mask.py new file mode 
100644 index 0000000000000000000000000000000000000000..a8a74e4c3b40d13b0258b83a12f56321a85bb179 --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/skin_mask.py @@ -0,0 +1,125 @@ +"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch +""" + +import math +import numpy as np +import os +import cv2 + +class GMM: + def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): + self.dim = dim # feature dimension + self.num = num # number of Gaussian components + self.w = w # weights of Gaussian components (a list of scalars) + self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) + self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) + self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) + self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) + + self.factor = [0]*num + for i in range(self.num): + self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 + + def likelihood(self, data): + assert(data.shape[1] == self.dim) + N = data.shape[0] + lh = np.zeros(N) + + for i in range(self.num): + data_ = data - self.mu[i] + + tmp = np.matmul(data_,self.cov_inv[i]) * data_ + tmp = np.sum(tmp,axis=1) + power = -0.5 * tmp + + p = np.array([math.exp(power[j]) for j in range(N)]) + p = p/self.factor[i] + lh += p*self.w[i] + + return lh + + +def _rgb2ycbcr(rgb): + m = np.array([[65.481, 128.553, 24.966], + [-37.797, -74.203, 112], + [112, -93.786, -18.214]]) + shape = rgb.shape + rgb = rgb.reshape((shape[0] * shape[1], 3)) + ycbcr = np.dot(rgb, m.transpose() / 255.) + ycbcr[:, 0] += 16. + ycbcr[:, 1:] += 128. + return ycbcr.reshape(shape) + + +def _bgr2ycbcr(bgr): + rgb = bgr[..., ::-1] + return _rgb2ycbcr(rgb) + + +gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] +gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), + np.array([150.19858, 105.18467, 155.51428]), + np.array([183.92976, 107.62468, 152.71820]), + np.array([114.90524, 113.59782, 151.38217])] +gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] 
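+# Pre-fitted skin-colour GMM (4 components, YCbCr space). The inverse covariance matrices
+# below pair with the weights, means and determinants above; skinmask() evaluates this
+# model against the non-skin model to obtain a per-pixel skin posterior.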
+gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), + np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), + np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), + np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] + +gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) + +gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] +gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), + np.array([110.91392, 125.52969, 130.19237]), + np.array([129.75864, 129.96107, 126.96808]), + np.array([112.29587, 128.85121, 129.05431])] +gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] +gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), + np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), + np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), + np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] + +gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) + +prior_skin = 0.8 +prior_nonskin = 1 - prior_skin + + +# calculate skin attention mask +def skinmask(imbgr): + im = _bgr2ycbcr(imbgr) + + data = im.reshape((-1,3)) + + lh_skin = gmm_skin.likelihood(data) + lh_nonskin = gmm_nonskin.likelihood(data) + + tmp1 = prior_skin * lh_skin + tmp2 = prior_nonskin * lh_nonskin + post_skin = tmp1 / (tmp1+tmp2) # posterior probability + + post_skin = post_skin.reshape((im.shape[0],im.shape[1])) + + post_skin = np.round(post_skin*255) + post_skin = post_skin.astype(np.uint8) + post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 + + return post_skin + + +def get_skin_mask(img_path): + print('generating skin masks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + save_path = os.path.join(img_path, 'mask') + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + img = cv2.imread(full_image_name).astype(np.float32) + skin_img = skinmask(img) + cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) diff --git a/chat_anything/sad_talker/face3d/util/test_mean_face.txt b/chat_anything/sad_talker/face3d/util/test_mean_face.txt new file mode 100644 index 0000000000000000000000000000000000000000..3a46d4db7699ffed8f898fcee64099631509946d --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/test_mean_face.txt @@ -0,0 +1,136 @@ +-5.228591537475585938e+01 +2.078247070312500000e-01 +-5.064269638061523438e+01 +-1.315765380859375000e+01 +-4.952939224243164062e+01 +-2.592591094970703125e+01 +-4.793047332763671875e+01 +-3.832135772705078125e+01 +-4.512159729003906250e+01 +-5.059623336791992188e+01 +-3.917720794677734375e+01 
+-6.043736648559570312e+01 +-2.929953765869140625e+01 +-6.861183166503906250e+01 +-1.719801330566406250e+01 +-7.572736358642578125e+01 +-1.961936950683593750e+00 +-7.862001037597656250e+01 +1.467941284179687500e+01 +-7.607844543457031250e+01 +2.744073486328125000e+01 +-6.915261840820312500e+01 +3.855677795410156250e+01 +-5.950350570678710938e+01 +4.478240966796875000e+01 +-4.867547225952148438e+01 +4.714337158203125000e+01 +-3.800830078125000000e+01 +4.940315246582031250e+01 +-2.496297454833984375e+01 +5.117234802246093750e+01 +-1.241538238525390625e+01 +5.190507507324218750e+01 +8.244247436523437500e-01 +-4.150688934326171875e+01 +2.386329650878906250e+01 +-3.570307159423828125e+01 +3.017010498046875000e+01 +-2.790358734130859375e+01 +3.212951660156250000e+01 +-1.941773223876953125e+01 +3.156523132324218750e+01 +-1.138106536865234375e+01 +2.841992187500000000e+01 +5.993263244628906250e+00 +2.895182800292968750e+01 +1.343590545654296875e+01 +3.189880371093750000e+01 +2.203153991699218750e+01 +3.302221679687500000e+01 +2.992478942871093750e+01 +3.099150085449218750e+01 +3.628388977050781250e+01 +2.765748596191406250e+01 +-1.933914184570312500e+00 +1.405374145507812500e+01 +-2.153038024902343750e+00 +5.772636413574218750e+00 +-2.270050048828125000e+00 +-2.121643066406250000e+00 +-2.218330383300781250e+00 +-1.068978118896484375e+01 +-1.187252044677734375e+01 +-1.997912597656250000e+01 +-6.879402160644531250e+00 +-2.143579864501953125e+01 +-1.227821350097656250e+00 +-2.193494415283203125e+01 +4.623237609863281250e+00 +-2.152721405029296875e+01 +9.721397399902343750e+00 +-1.953671264648437500e+01 +-3.648714447021484375e+01 +9.811126708984375000e+00 +-3.130242919921875000e+01 +1.422447967529296875e+01 +-2.212834930419921875e+01 +1.493019866943359375e+01 +-1.500880432128906250e+01 +1.073588562011718750e+01 +-2.095037078857421875e+01 +9.054298400878906250e+00 +-3.050099182128906250e+01 +8.704177856445312500e+00 +1.173237609863281250e+01 +1.054329681396484375e+01 +1.856353759765625000e+01 +1.535009765625000000e+01 +2.893331909179687500e+01 +1.451992797851562500e+01 +3.452944946289062500e+01 +1.065280151367187500e+01 +2.875990295410156250e+01 +8.654792785644531250e+00 +1.942100524902343750e+01 +9.422447204589843750e+00 +-2.204488372802734375e+01 +-3.983994293212890625e+01 +-1.324458312988281250e+01 +-3.467377471923828125e+01 +-6.749649047851562500e+00 +-3.092894744873046875e+01 +-9.183349609375000000e-01 +-3.196458435058593750e+01 +4.220649719238281250e+00 +-3.090406036376953125e+01 +1.089889526367187500e+01 +-3.497008514404296875e+01 +1.874589538574218750e+01 +-4.065438079833984375e+01 +1.124106597900390625e+01 +-4.438417816162109375e+01 +5.181709289550781250e+00 +-4.649170684814453125e+01 +-1.158607482910156250e+00 +-4.680406951904296875e+01 +-7.918922424316406250e+00 +-4.671575164794921875e+01 +-1.452505493164062500e+01 +-4.416526031494140625e+01 +-2.005007171630859375e+01 +-3.997841644287109375e+01 +-1.054919433593750000e+01 +-3.849683380126953125e+01 +-1.051826477050781250e+00 +-3.794863128662109375e+01 +6.412681579589843750e+00 +-3.804645538330078125e+01 +1.627674865722656250e+01 +-4.039697265625000000e+01 +6.373878479003906250e+00 +-4.087213897705078125e+01 +-8.551712036132812500e-01 +-4.157129669189453125e+01 +-1.014953613281250000e+01 +-4.128469085693359375e+01 diff --git a/chat_anything/sad_talker/face3d/util/util.py b/chat_anything/sad_talker/face3d/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..0d689ca138fc0fbf5bec794511ea0f9e638f9ea9 --- /dev/null +++ 
b/chat_anything/sad_talker/face3d/util/util.py @@ -0,0 +1,208 @@ +"""This script contains basic utilities for Deep3DFaceRecon_pytorch +""" +from __future__ import print_function +import numpy as np +import torch +from PIL import Image +import os +import importlib +import argparse +from argparse import Namespace +import torchvision + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def copyconf(default_opt, **kwargs): + conf = Namespace(**vars(default_opt)) + for key in kwargs: + setattr(conf, key, kwargs[key]) + return conf + +def genvalconf(train_opt, **kwargs): + conf = Namespace(**vars(train_opt)) + attr_dict = train_opt.__dict__ + for key, value in attr_dict.items(): + if 'val' in key and key.split('_')[0] in attr_dict: + setattr(conf, key.split('_')[0], value) + + for key in kwargs: + setattr(conf, key, kwargs[key]) + + return conf + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace('_', '').lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) + + return cls + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array, range(0, 1) + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio is None: + pass + elif aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + elif aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the 
shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) + + +def correct_resize_label(t, size): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i, :1] + one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) + one_np = one_np[:, :, 0] + one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) + resized_t = torch.from_numpy(np.array(one_image)).long() + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + + +def correct_resize(t, size, mode=Image.BICUBIC): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i:i + 1] + one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) + resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + +def draw_landmarks(img, landmark, color='r', step=2): + """ + Return: + img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) + + + Parameters: + img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) + landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction + color -- str, 'r' or 'b' (red or blue) + """ + if color =='r': + c = np.array([255., 0, 0]) + else: + c = np.array([0, 0, 255.]) + + _, H, W, _ = img.shape + img, landmark = img.copy(), landmark.copy() + landmark[..., 1] = H - 1 - landmark[..., 1] + landmark = np.round(landmark).astype(np.int32) + for i in range(landmark.shape[1]): + x, y = landmark[:, i, 0], landmark[:, i, 1] + for j in range(-step, step): + for k in range(-step, step): + u = np.clip(x + j, 0, W - 1) + v = np.clip(y + k, 0, H - 1) + for m in range(landmark.shape[0]): + img[m, v[m], u[m]] = c + return img diff --git a/chat_anything/sad_talker/face3d/util/visualizer.py b/chat_anything/sad_talker/face3d/util/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4023a6d4086acba9bc88e079f625194d324d7c9e --- /dev/null +++ b/chat_anything/sad_talker/face3d/util/visualizer.py @@ -0,0 +1,227 @@ +"""This script defines the visualizer for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from torch.utils.tensorboard import SummaryWriter + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. 
+ + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s/%s.png' % (label, name) + os.makedirs(os.path.join(image_dir, label), exist_ok=True) + save_path = os.path.join(image_dir, image_name) + util.save_image(im, save_path, aspect_ratio=aspect_ratio) + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. + + It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.use_html = opt.isTrain and not opt.no_html + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) + self.win_size = opt.display_winsize + self.name = opt.name + self.saved = False + if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + + + def display_current_results(self, visuals, total_iters, epoch, save_result): + """Display current results on tensorboad; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + save_result (bool) - - if save the current results to an HTML file + """ + for label, image in visuals.items(): + self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') + + if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 
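+            # Save each visual as epoch%.3d_<label>.png under the web image directory,
+            # then rebuild index.html with one row of images per epoch (newest first).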
+ self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + def plot_current_losses(self, total_iters, losses): + # G_loss_collection = {} + # D_loss_collection = {} + # for name, value in losses.items(): + # if 'G' in name or 'NCE' in name or 'idt' in name: + # G_loss_collection[name] = value + # else: + # D_loss_collection[name] = value + # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) + # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) + for name, value in losses.items(): + self.writer.add_scalar(name, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message + + +class MyVisualizer: + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the optio + self.name = opt.name + self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') + + if opt.phase != 'test': + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + + def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, + add_image=True): + """Display current results on tensorboad; save current results to an HTML file. 
+ + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + dataset (str) - - 'train' or 'val' or 'test' + """ + # if (not add_image) and (not save_results): return + + for label, image in visuals.items(): + for i in range(image.shape[0]): + image_numpy = util.tensor2im(image[i]) + if add_image: + self.writer.add_image(label + '%s_%02d'%(dataset, i + count), + image_numpy, total_iters, dataformats='HWC') + + if save_results: + save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + if name is not None: + img_path = os.path.join(save_path, '%s.png' % name) + else: + img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) + util.save_image(image_numpy, img_path) + + + def plot_current_losses(self, total_iters, losses, dataset='train'): + for name, value in losses.items(): + self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( + dataset, epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message diff --git a/chat_anything/sad_talker/face3d/visualize.py b/chat_anything/sad_talker/face3d/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..b7294d991ab2fbc5ad4415d2fe991199ad238ef7 --- /dev/null +++ b/chat_anything/sad_talker/face3d/visualize.py @@ -0,0 +1,48 @@ +# check the sync of 3dmm feature and the audio +import cv2 +import numpy as np +from chat_anything.sad_talker.face3d.models.bfm import ParametricFaceModel +from chat_anything.sad_talker.face3d.models.facerecon_model import FaceReconModel +import torch +import subprocess, platform +import scipy.io as scio +from tqdm import tqdm + +# draft +def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64): + + coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] + + coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] + + coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 + + coeff_full[:, 80:144] = coeff_pred[:, 0:64] + coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation + coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation + + tmp_video_path = '/tmp/face3dtmp.mp4' + + facemodel = FaceReconModel(args) + + video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) + + for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): + cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) + + facemodel.forward(cur_coeff_full, device) + + predicted_landmark = facemodel.pred_lm # TODO. 
+ predicted_landmark = predicted_landmark.cpu().numpy().squeeze() + + rendered_img = facemodel.pred_face + rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0) + out_img = rendered_img[:, :, :3].astype(np.uint8) + + video.write(np.uint8(out_img[:,:,::-1])) + + video.release() + + command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) + subprocess.call(command, shell=platform.system() != 'Windows') + diff --git a/chat_anything/sad_talker/facerender/animate.py b/chat_anything/sad_talker/facerender/animate.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc5308a5da5aff443f9e34244511192159ec859 --- /dev/null +++ b/chat_anything/sad_talker/facerender/animate.py @@ -0,0 +1,269 @@ +import os +import cv2 +import yaml +import numpy as np +import warnings +from skimage import img_as_ubyte +import safetensors +import safetensors.torch +warnings.filterwarnings('ignore') + + +import imageio +import torch +import torchvision + + +from chat_anything.sad_talker.facerender.modules.keypoint_detector import HEEstimator, KPDetector +from chat_anything.sad_talker.facerender.modules.mapping import MappingNet +from chat_anything.sad_talker.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator +from chat_anything.sad_talker.facerender.modules.make_animation import make_animation + +from pydub import AudioSegment +from chat_anything.sad_talker.utils.face_enhancer import enhancer_generator_with_len, enhancer_list +from chat_anything.sad_talker.utils.paste_pic import paste_pic +from chat_anything.sad_talker.utils.videoio import save_video_with_watermark + +try: + import webui # in webui + in_webui = True +except: + in_webui = False + +class AnimateFromCoeff(): + + def __init__(self, sadtalker_path, device): + + with open(sadtalker_path['facerender_yaml']) as f: + config = yaml.safe_load(f) + + generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], + **config['model_params']['common_params']) + kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], + **config['model_params']['common_params']) + he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], + **config['model_params']['common_params']) + mapping = MappingNet(**config['model_params']['mapping_params']) + + generator.to(device) + kp_extractor.to(device) + he_estimator.to(device) + mapping.to(device) + for param in generator.parameters(): + param.requires_grad = False + for param in kp_extractor.parameters(): + param.requires_grad = False + for param in he_estimator.parameters(): + param.requires_grad = False + for param in mapping.parameters(): + param.requires_grad = False + + if sadtalker_path is not None: + if 'checkpoint' in sadtalker_path: # use safe tensor + self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None) + else: + self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) + else: + raise AttributeError("Checkpoint should be specified for video head pose estimator.") + + if sadtalker_path['mappingnet_checkpoint'] is not None: + self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping) + else: + raise AttributeError("Checkpoint should be specified for video head pose estimator.") + + devices = list(range(torch.cuda.device_count())) + device = 0 + generator = 
torch.nn.DataParallel(generator, device_ids=devices, output_device=device) + kp_extractor = torch.nn.DataParallel(kp_extractor, device_ids=devices, output_device=device) + he_estimator = torch.nn.DataParallel(he_estimator, device_ids=devices, output_device=device) + mapping = torch.nn.DataParallel(mapping, device_ids=devices, output_device=device) + + self.kp_extractor = kp_extractor + self.generator = generator + self.he_estimator = he_estimator + self.mapping = mapping + + self.kp_extractor.eval() + self.generator.eval() + self.he_estimator.eval() + self.mapping.eval() + + self.device = device + + def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None, + kp_detector=None, he_estimator=None, + device="cpu"): + + checkpoint = safetensors.torch.load_file(checkpoint_path) + + if generator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'generator' in k: + x_generator[k.replace('generator.', '')] = v + generator.load_state_dict(x_generator) + if kp_detector is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'kp_extractor' in k: + x_generator[k.replace('kp_extractor.', '')] = v + kp_detector.load_state_dict(x_generator) + if he_estimator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'he_estimator' in k: + x_generator[k.replace('he_estimator.', '')] = v + he_estimator.load_state_dict(x_generator) + + return None + + def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, + kp_detector=None, he_estimator=None, optimizer_generator=None, + optimizer_discriminator=None, optimizer_kp_detector=None, + optimizer_he_estimator=None, device="cpu"): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if generator is not None: + generator.load_state_dict(checkpoint['generator']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if he_estimator is not None: + he_estimator.load_state_dict(checkpoint['he_estimator']) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint['discriminator']) + except: + print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + except RuntimeError as e: + print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) + if optimizer_he_estimator is not None: + optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) + + return checkpoint['epoch'] + + def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None, + optimizer_mapping=None, optimizer_discriminator=None, device='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if mapping is not None: + mapping.load_state_dict(checkpoint['mapping']) + if discriminator is not None: + discriminator.load_state_dict(checkpoint['discriminator']) + if optimizer_mapping is not None: + optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping']) + if optimizer_discriminator is not None: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + + return checkpoint['epoch'] + + def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256): + + source_image=x['source_image'].type(torch.FloatTensor) + source_semantics=x['source_semantics'].type(torch.FloatTensor) + target_semantics=x['target_semantics_list'].type(torch.FloatTensor) + source_image=source_image.to(self.device) + source_semantics=source_semantics.to(self.device) + target_semantics=target_semantics.to(self.device) + if 'yaw_c_seq' in x: + yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor) + yaw_c_seq = x['yaw_c_seq'].to(self.device) + else: + yaw_c_seq = None + if 'pitch_c_seq' in x: + pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor) + pitch_c_seq = x['pitch_c_seq'].to(self.device) + else: + pitch_c_seq = None + if 'roll_c_seq' in x: + roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) + roll_c_seq = x['roll_c_seq'].to(self.device) + else: + roll_c_seq = None + + frame_num = x['frame_num'] + + predictions_video = make_animation(source_image, source_semantics, target_semantics, + self.generator, self.kp_extractor, self.he_estimator, self.mapping, + yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True) + + predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:]) + predictions_video = predictions_video[:frame_num] + + video = [] + for idx in range(predictions_video.shape[0]): + image = predictions_video[idx] + image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32) + video.append(image) + result = img_as_ubyte(video) + + ### the generated video is 256x256, so we keep the aspect ratio, + original_size = crop_info[0] + if original_size: + result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ] + + video_name = x['video_name'] + '.mp4' + path = os.path.join(video_save_dir, 'temp_'+video_name) + imageio.mimsave(path, result, fps=float(25)) + # try: + # enhanced_images_gen_with_len = enhancer_generator_with_len(path, method='gfpgan', bg_upsampler='realesrgan') + # imageio.mimsave(path, enhanced_images_gen_with_len, fps=float(25)) + # except: + # enhanced_images_gen_with_len = enhancer_list(path, method='gfpgan', bg_upsampler='realesrgan') + # imageio.mimsave(path, enhanced_images_gen_with_len, fps=float(25)) + + av_path = os.path.join(video_save_dir, video_name) + audio_path = x['audio_path'] + audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] + new_audio_path = os.path.join(video_save_dir, audio_name+'.wav') + start_time = 0 + # cog will not keep the .mp3 filename + 
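+        # Resample the audio to 16 kHz and keep only the first frame_num / 25 seconds,
+        # so its length matches the 25 fps rendered clip before they are muxed together.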
sound = AudioSegment.from_file(audio_path) + frames = frame_num + end_time = start_time + frames*1/25*1000 + word1=sound.set_frame_rate(16000) + word = word1[start_time:end_time] + word.export(new_audio_path, format="wav") + print("============================") + print("saved moving images:", path) + # print(f'The generated video is named {video_save_dir}/{video_name}') + if 'full' in preprocess.lower(): + # only add watermark to the full image. + video_name_full = x['video_name'] + '_full.mp4' + full_video_path = os.path.join(video_save_dir, video_name_full) + paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False) + print(f"full video:{full_video_path}") + return_path = full_video_path + else: + save_video_with_watermark(path, new_audio_path, av_path, watermark= False) + return_path = av_path + print(f"crop video:{return_path}") + print("the given temp file:", return_path) + return return_path + # #### paste back then enhancers + # if enhancer: + # video_name_enhancer = x['video_name'] + '_enhanced.mp4' + # enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer) + # av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) + # return_path = av_path_enhancer + + # try: + # enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer) + # imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) + # except: + # enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer) + # imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25)) + + # save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False) + # print(f'The generated video is named {video_save_dir}/{video_name_enhancer}') + # # os.remove(enhanced_path) + + # # os.remove(path) + # os.remove(new_audio_path) + + # return return_path + diff --git a/chat_anything/sad_talker/facerender/modules/dense_motion.py b/chat_anything/sad_talker/facerender/modules/dense_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7214d077e80fb9e0056c74c8031e6b77f3590d --- /dev/null +++ b/chat_anything/sad_talker/facerender/modules/dense_motion.py @@ -0,0 +1,121 @@ +from torch import nn +import torch.nn.functional as F +import torch +from chat_anything.sad_talker.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian + +from chat_anything.sad_talker.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d + + +class DenseMotionNetwork(nn.Module): + """ + Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving + """ + + def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, + estimate_occlusion_map=False): + super(DenseMotionNetwork, self).__init__() + # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks) + self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) + + self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) + + self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) + self.norm = BatchNorm3d(compress, affine=True) + + if estimate_occlusion_map: + # 
self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3) + self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3) + else: + self.occlusion = None + + self.num_kp = num_kp + + + def create_sparse_motions(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape + identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type()) + identity_grid = identity_grid.view(1, 1, d, h, w, 3) + coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3) + + # if 'jacobian' in kp_driving: + if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None: + jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) + jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3) + jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1) + coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) + coordinate_grid = coordinate_grid.squeeze(-1) + + + driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3) + + #adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3 + + # sparse_motions = driving_to_source + + return sparse_motions + + def create_deformed_feature(self, feature, sparse_motions): + bs, _, d, h, w = feature.shape + feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w) + feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3) !!!! 
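`create_deformed_feature` warps one copy of the compressed 3-D feature volume per candidate motion (the K keypoint motions plus the identity/background grid) with `F.grid_sample`, as the next lines show. A toy-shaped sketch of just that warping step (shapes here are made up for clarity, not the module's real dimensions):
```
# Toy sketch: each of the K+1 sparse motions is a dense grid of normalized (x, y, z)
# sampling coordinates; grid_sample pulls the source feature volume through that grid,
# yielding one warped copy of the features per candidate motion.
import torch
import torch.nn.functional as F

bs, c, d, h, w, K = 1, 4, 2, 8, 8, 3
feature = torch.randn(bs, c, d, h, w)
grids = torch.rand(bs * (K + 1), d, h, w, 3) * 2 - 1          # coords in [-1, 1]
feat_rep = feature.unsqueeze(1).expand(bs, K + 1, c, d, h, w)
feat_rep = feat_rep.reshape(bs * (K + 1), c, d, h, w)
warped = F.grid_sample(feat_rep, grids, align_corners=False)  # (bs*(K+1), c, d, h, w)
warped = warped.reshape(bs, K + 1, c, d, h, w)                # one warped volume per motion
```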
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions) + sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w) + return sparse_deformed + + def create_heatmap_representations(self, feature, kp_driving, kp_source): + spatial_size = feature.shape[3:] + gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) + gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) + heatmap = gaussian_driving - gaussian_source + + # adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type()) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) + return heatmap + + def forward(self, feature, kp_driving, kp_source): + bs, _, d, h, w = feature.shape + + feature = self.compress(feature) + feature = self.norm(feature) + feature = F.relu(feature) + + out_dict = dict() + sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) + deformed_feature = self.create_deformed_feature(feature, sparse_motion) + + heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) + + input_ = torch.cat([heatmap, deformed_feature], dim=2) + input_ = input_.view(bs, -1, d, h, w) + + # input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w) + + prediction = self.hourglass(input_) + + + mask = self.mask(prediction) + mask = F.softmax(mask, dim=1) + out_dict['mask'] = mask + mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w) + + zeros_mask = torch.zeros_like(mask) + mask = torch.where(mask < 1e-3, zeros_mask, mask) + + sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w) + deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) + deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3) + + out_dict['deformation'] = deformation + + if self.occlusion: + bs, c, d, h, w = prediction.shape + prediction = prediction.view(bs, -1, h, w) + occlusion_map = torch.sigmoid(self.occlusion(prediction)) + out_dict['occlusion_map'] = occlusion_map + + return out_dict diff --git a/chat_anything/sad_talker/facerender/modules/discriminator.py b/chat_anything/sad_talker/facerender/modules/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..d4459b07cb075c9f9d345f9b3dffc02cd859313b --- /dev/null +++ b/chat_anything/sad_talker/facerender/modules/discriminator.py @@ -0,0 +1,90 @@ +from torch import nn +import torch.nn.functional as F +from facerender.modules.util import kp2gaussian +import torch + + +class DownBlock2d(nn.Module): + """ + Simple block for processing video (encoder). 
+ """ + + def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) + + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + if norm: + self.norm = nn.InstanceNorm2d(out_features, affine=True) + else: + self.norm = None + self.pool = pool + + def forward(self, x): + out = x + out = self.conv(out) + if self.norm: + out = self.norm(out) + out = F.leaky_relu(out, 0.2) + if self.pool: + out = F.avg_pool2d(out, (2, 2)) + return out + + +class Discriminator(nn.Module): + """ + Discriminator similar to Pix2Pix + """ + + def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, + sn=False, **kwargs): + super(Discriminator, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) + + self.down_blocks = nn.ModuleList(down_blocks) + self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + def forward(self, x): + feature_maps = [] + out = x + + for down_block in self.down_blocks: + feature_maps.append(down_block(out)) + out = feature_maps[-1] + prediction_map = self.conv(out) + + return feature_maps, prediction_map + + +class MultiScaleDiscriminator(nn.Module): + """ + Multi-scale (scale) discriminator + """ + + def __init__(self, scales=(), **kwargs): + super(MultiScaleDiscriminator, self).__init__() + self.scales = scales + discs = {} + for scale in scales: + discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) + self.discs = nn.ModuleDict(discs) + + def forward(self, x): + out_dict = {} + for scale, disc in self.discs.items(): + scale = str(scale).replace('-', '.') + key = 'prediction_' + scale + feature_maps, prediction_map = disc(x[key]) + out_dict['feature_maps_' + scale] = feature_maps + out_dict['prediction_map_' + scale] = prediction_map + return out_dict diff --git a/chat_anything/sad_talker/facerender/modules/generator.py b/chat_anything/sad_talker/facerender/modules/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..90d1084d532f9c67d431a5bc0c059f43b2fb9b0f --- /dev/null +++ b/chat_anything/sad_talker/facerender/modules/generator.py @@ -0,0 +1,255 @@ +import torch +from torch import nn +import torch.nn.functional as F +from chat_anything.sad_talker.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock +from chat_anything.sad_talker.facerender.modules.dense_motion import DenseMotionNetwork + + +class OcclusionAwareGenerator(nn.Module): + """ + Generator follows NVIDIA architecture. 
+ """ + + def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, + num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + out_features = block_expansion * (2 ** (num_down_blocks)) + self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) + + self.resblocks_2d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1)) + + up_blocks = [] + for i in range(num_down_blocks): + in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i))) + out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1))) + up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.up_blocks = nn.ModuleList(up_blocks) + + self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3)) + self.estimate_occlusion_map = estimate_occlusion_map + self.image_channel = image_channel + + def deform_input(self, inp, deformation): + _, d_old, h_old, w_old, _ = deformation.shape + _, _, d, h, w = inp.shape + if d_old != d or h_old != h or w_old != w: + deformation = deformation.permute(0, 4, 1, 2, 3) + deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') + deformation = deformation.permute(0, 2, 3, 4, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape + # print(out.shape) + feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) + feature_3d = self.resblocks_3d(feature_3d) + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = 
occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(feature_3d, deformation) + + bs, c, d, h, w = out.shape + out = out.view(bs, c*d, h, w) + out = self.third(out) + out = self.fourth(out) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + # output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image + + # Decoding part + out = self.resblocks_2d(out) + for i in range(len(self.up_blocks)): + out = self.up_blocks[i](out) + out = self.final(out) + out = F.sigmoid(out) + + output_dict["prediction"] = out + + return output_dict + + +class SPADEDecoder(nn.Module): + def __init__(self): + super().__init__() + ic = 256 + oc = 64 + norm_G = 'spadespectralinstance' + label_nc = 256 + + self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1) + self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc) + self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc) + self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc) + self.conv_img = nn.Conv2d(oc, 3, 3, padding=1) + self.up = nn.Upsample(scale_factor=2) + + def forward(self, feature): + seg = feature + x = self.fc(feature) + x = self.G_middle_0(x, seg) + x = self.G_middle_1(x, seg) + x = self.G_middle_2(x, seg) + x = self.G_middle_3(x, seg) + x = self.G_middle_4(x, seg) + x = self.G_middle_5(x, seg) + x = self.up(x) + x = self.up_0(x, seg) # 256, 128, 128 + x = self.up(x) + x = self.up_1(x, seg) # 64, 256, 256 + + x = self.conv_img(F.leaky_relu(x, 2e-1)) + # x = torch.tanh(x) + x = F.sigmoid(x) + + return x + + +class OcclusionAwareSPADEGenerator(nn.Module): + + def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth, + num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareSPADEGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1) + + self.reshape_channel = reshape_channel + self.reshape_depth = reshape_depth + + self.resblocks_3d = torch.nn.Sequential() + for i in range(num_resblocks): + self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)) + + out_features = 
block_expansion * (2 ** (num_down_blocks)) + self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True) + self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1) + + self.estimate_occlusion_map = estimate_occlusion_map + self.image_channel = image_channel + + self.decoder = SPADEDecoder() + + def deform_input(self, inp, deformation): + _, d_old, h_old, w_old, _ = deformation.shape + _, _, d, h, w = inp.shape + if d_old != d or h_old != h or w_old != w: + deformation = deformation.permute(0, 4, 1, 2, 3) + deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear') + deformation = deformation.permute(0, 2, 3, 4, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + out = self.second(out) + bs, c, h, w = out.shape + # print(out.shape) + feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) + feature_3d = self.resblocks_3d(feature_3d) + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + + # import pdb; pdb.set_trace() + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(feature_3d, deformation) + + bs, c, d, h, w = out.shape + out = out.view(bs, c*d, h, w) + out = self.third(out) + out = self.fourth(out) + + # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + # Decoding part + out = self.decoder(out) + + output_dict["prediction"] = out + + return output_dict \ No newline at end of file diff --git a/chat_anything/sad_talker/facerender/modules/keypoint_detector.py b/chat_anything/sad_talker/facerender/modules/keypoint_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc09ebda5e02baf7b08595e6f5c666cb1a375a3 --- /dev/null +++ b/chat_anything/sad_talker/facerender/modules/keypoint_detector.py @@ -0,0 +1,179 @@ +from torch import nn +import torch +import torch.nn.functional as F + +from chat_anything.sad_talker.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from chat_anything.sad_talker.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck + + +class KPDetector(nn.Module): + """ + Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint. 
+ """ + + def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth, + num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False): + super(KPDetector, self).__init__() + + self.predictor = KPHourglass(block_expansion, in_features=image_channel, + max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks) + + # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3) + self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3) + self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1) + ''' + initial as: + [[1 0 0] + [0 1 0] + [0 0 1]] + ''' + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor) + + def gaussian2kp(self, heatmap): + """ + Extract the mean from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) + grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + value = (heatmap * grid).sum(dim=(2, 3, 4)) + kp = {'value': value} + + return kp + + def forward(self, x): + if self.scale_factor != 1: + x = self.down(x) + + feature_map = self.predictor(x) + prediction = self.kp(feature_map) + + final_shape = prediction.shape + heatmap = prediction.view(final_shape[0], final_shape[1], -1) + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) + + out = self.gaussian2kp(heatmap) + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2], + final_shape[3], final_shape[4]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map + jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1) + jacobian = jacobian.sum(dim=-1) + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3) + out['jacobian'] = jacobian + + return out + + +class HEEstimator(nn.Module): + """ + Estimating head pose and expression. 
+ """ + + def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True): + super(HEEstimator, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2) + self.norm1 = BatchNorm2d(block_expansion, affine=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1) + self.norm2 = BatchNorm2d(256, affine=True) + + self.block1 = nn.Sequential() + for i in range(3): + self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1)) + + self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1) + self.norm3 = BatchNorm2d(512, affine=True) + self.block2 = ResBottleneck(in_features=512, stride=2) + + self.block3 = nn.Sequential() + for i in range(3): + self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1)) + + self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1) + self.norm4 = BatchNorm2d(1024, affine=True) + self.block4 = ResBottleneck(in_features=1024, stride=2) + + self.block5 = nn.Sequential() + for i in range(5): + self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1)) + + self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1) + self.norm5 = BatchNorm2d(2048, affine=True) + self.block6 = ResBottleneck(in_features=2048, stride=2) + + self.block7 = nn.Sequential() + for i in range(2): + self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1)) + + self.fc_roll = nn.Linear(2048, num_bins) + self.fc_pitch = nn.Linear(2048, num_bins) + self.fc_yaw = nn.Linear(2048, num_bins) + + self.fc_t = nn.Linear(2048, 3) + + self.fc_exp = nn.Linear(2048, 3*num_kp) + + def forward(self, x): + out = self.conv1(x) + out = self.norm1(out) + out = F.relu(out) + out = self.maxpool(out) + + out = self.conv2(out) + out = self.norm2(out) + out = F.relu(out) + + out = self.block1(out) + + out = self.conv3(out) + out = self.norm3(out) + out = F.relu(out) + out = self.block2(out) + + out = self.block3(out) + + out = self.conv4(out) + out = self.norm4(out) + out = F.relu(out) + out = self.block4(out) + + out = self.block5(out) + + out = self.conv5(out) + out = self.norm5(out) + out = F.relu(out) + out = self.block6(out) + + out = self.block7(out) + + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(out.shape[0], -1) + + yaw = self.fc_roll(out) + pitch = self.fc_pitch(out) + roll = self.fc_yaw(out) + t = self.fc_t(out) + exp = self.fc_exp(out) + + return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} + diff --git a/chat_anything/sad_talker/facerender/modules/make_animation.py b/chat_anything/sad_talker/facerender/modules/make_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..3360c53501a064f35d7db21a5361f89aa9658b42 --- /dev/null +++ b/chat_anything/sad_talker/facerender/modules/make_animation.py @@ -0,0 +1,170 @@ +from scipy.spatial import ConvexHull +import torch +import torch.nn.functional as F +import numpy as np +from tqdm import tqdm + +def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False, + use_relative_movement=False, use_relative_jacobian=False): + if adapt_movement_scale: + source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume + driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume + 
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) + else: + adapt_movement_scale = 1 + + kp_new = {k: v for k, v in kp_driving.items()} + + if use_relative_movement: + kp_value_diff = (kp_driving['value'] - kp_driving_initial['value']) + kp_value_diff *= adapt_movement_scale + kp_new['value'] = kp_value_diff + kp_source['value'] + + if use_relative_jacobian: + jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian'])) + kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian']) + + return kp_new + +def headpose_pred_to_degree(pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device) + pred = F.softmax(pred) + degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 + return degree + +def get_rotation_matrix(yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), + torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch), + torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), + torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw), + -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll), + torch.sin(roll), torch.cos(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat) + + return rot_mat + +def keypoint_transformation(kp_canonical, he, wo_exp=False): + kp = kp_canonical['value'] # (bs, k, 3) + yaw, pitch, roll= he['yaw'], he['pitch'], he['roll'] + yaw = headpose_pred_to_degree(yaw) + pitch = headpose_pred_to_degree(pitch) + roll = headpose_pred_to_degree(roll) + + if 'yaw_in' in he: + yaw = he['yaw_in'] + if 'pitch_in' in he: + pitch = he['pitch_in'] + if 'roll_in' in he: + roll = he['roll_in'] + + rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) + + t, exp = he['t'], he['exp'] + if wo_exp: + exp = exp*0 + + # keypoint rotation + kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) + + # keypoint translation + t[:, 0] = t[:, 0]*0 + t[:, 2] = t[:, 2]*0 + t = t.unsqueeze(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed = kp_t + exp + + return {'value': kp_transformed} + + + +def make_animation(source_image, source_semantics, target_semantics, + generator, kp_detector, he_estimator, mapping, + yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None, + use_exp=True, use_half=False): + with torch.no_grad(): + predictions = [] + + kp_canonical = kp_detector(source_image) + he_source = mapping(source_semantics) + kp_source = keypoint_transformation(kp_canonical, he_source) + + for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'): + # still check the dimension + # print(target_semantics.shape, source_semantics.shape) + target_semantics_frame = target_semantics[:, frame_idx] + he_driving = mapping(target_semantics_frame) + if 
yaw_c_seq is not None: + he_driving['yaw_in'] = yaw_c_seq[:, frame_idx] + if pitch_c_seq is not None: + he_driving['pitch_in'] = pitch_c_seq[:, frame_idx] + if roll_c_seq is not None: + he_driving['roll_in'] = roll_c_seq[:, frame_idx] + + kp_driving = keypoint_transformation(kp_canonical, he_driving) + + kp_norm = kp_driving + out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm) + ''' + source_image_new = out['prediction'].squeeze(1) + kp_canonical_new = kp_detector(source_image_new) + he_source_new = he_estimator(source_image_new) + kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True) + kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True) + out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new) + ''' + predictions.append(out['prediction']) + predictions_ts = torch.stack(predictions, dim=1) + return predictions_ts + +class AnimateModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, generator, kp_extractor, mapping): + super(AnimateModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.mapping = mapping + + self.kp_extractor.eval() + self.generator.eval() + self.mapping.eval() + + def forward(self, x): + + source_image = x['source_image'] + source_semantics = x['source_semantics'] + target_semantics = x['target_semantics'] + yaw_c_seq = x['yaw_c_seq'] + pitch_c_seq = x['pitch_c_seq'] + roll_c_seq = x['roll_c_seq'] + + predictions_video = make_animation(source_image, source_semantics, target_semantics, + self.generator, self.kp_extractor, + self.mapping, use_exp = True, + yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq) + + return predictions_video \ No newline at end of file diff --git a/chat_anything/sad_talker/facerender/modules/mapping.py b/chat_anything/sad_talker/facerender/modules/mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..0e3a1c2d1770996080c08e9daafb346f05d7bcdd --- /dev/null +++ b/chat_anything/sad_talker/facerender/modules/mapping.py @@ -0,0 +1,47 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MappingNet(nn.Module): + def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): + super( MappingNet, self).__init__() + + self.layer = layer + nonlinearity = nn.LeakyReLU(0.1) + + self.first = nn.Sequential( + torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) + + for i in range(layer): + net = nn.Sequential(nonlinearity, + torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) + setattr(self, 'encoder' + str(i), net) + + self.pooling = nn.AdaptiveAvgPool1d(1) + self.output_nc = descriptor_nc + + self.fc_roll = nn.Linear(descriptor_nc, num_bins) + self.fc_pitch = nn.Linear(descriptor_nc, num_bins) + self.fc_yaw = nn.Linear(descriptor_nc, num_bins) + self.fc_t = nn.Linear(descriptor_nc, 3) + self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) + + def forward(self, input_3dmm): + out = self.first(input_3dmm) + for i in range(self.layer): + model = getattr(self, 'encoder' + str(i)) + out = model(out) + out[:,:,3:-3] + out = self.pooling(out) + out = out.view(out.shape[0], -1) + #print('out:', out.shape) + + yaw = self.fc_yaw(out) + pitch = self.fc_pitch(out) + roll = self.fc_roll(out) + t = self.fc_t(out) + exp = self.fc_exp(out) + + return {'yaw': yaw, 'pitch': pitch, 
'roll': roll, 't': t, 'exp': exp} \ No newline at end of file diff --git a/chat_anything/sad_talker/facerender/modules/util.py b/chat_anything/sad_talker/facerender/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..7581ae895c6fe8a4132c4dcb9c06725829564350 --- /dev/null +++ b/chat_anything/sad_talker/facerender/modules/util.py @@ -0,0 +1,564 @@ +from torch import nn + +import torch.nn.functional as F +import torch + +from chat_anything.sad_talker.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d +from chat_anything.sad_talker.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d + +import torch.nn.utils.spectral_norm as spectral_norm + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp['value'] + + coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + +def make_coordinate_grid_2d(spatial_size, type): + """ + Create a meshgrid [-1,1] x [-1,1] of given spatial_size. + """ + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + + return meshed + + +def make_coordinate_grid(spatial_size, type): + d, h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + z = torch.arange(d).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + z = (2 * (z / (d - 1)) - 1) + + yy = y.view(1, -1, 1).repeat(d, 1, w) + xx = x.view(1, 1, -1).repeat(d, h, 1) + zz = z.view(-1, 1, 1).repeat(1, h, w) + + meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3) + + return meshed + + +class ResBottleneck(nn.Module): + def __init__(self, in_features, stride): + super(ResBottleneck, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1) + self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1) + self.norm1 = BatchNorm2d(in_features//4, affine=True) + self.norm2 = BatchNorm2d(in_features//4, affine=True) + self.norm3 = BatchNorm2d(in_features, affine=True) + + self.stride = stride + if self.stride != 1: + self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride) + self.norm4 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.conv1(x) + out = self.norm1(out) + out = F.relu(out) + out = self.conv2(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv3(out) + out = self.norm3(out) + if self.stride != 1: + x = self.skip(x) + x = self.norm4(x) + out += x + out = F.relu(out) + return out + + +class ResBlock2d(nn.Module): + """ + Res block, 
preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = BatchNorm2d(in_features, affine=True) + self.norm2 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class ResBlock3d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock3d, self).__init__() + self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = BatchNorm3d(in_features, affine=True) + self.norm2 = BatchNorm3d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + +class UpBlock3d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock3d, self).__init__() + + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm3d(out_features, affine=True) + + def forward(self, x): + # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') + out = F.interpolate(x, scale_factor=(1, 2, 2)) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class DownBlock3d(nn.Module): + """ + Downsampling block for use in encoder. 
+ """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock3d, self).__init__() + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups, stride=(1, 2, 2)) + ''' + self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm3d(out_features, affine=True) + self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. + """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + if lrelu: + self.ac = nn.LeakyReLU() + else: + self.ac = nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = self.ac(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + # self.out_filters = block_expansion + self.out_filters = block_expansion + in_features + + self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1) + self.norm = BatchNorm3d(self.out_filters, affine=True) + + def forward(self, x): + out = x.pop() + # for up_block in self.up_blocks[:-1]: + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + # out = self.up_blocks[-1](out) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class KPHourglass(nn.Module): + """ + Hourglass architecture. 
+ """ + + def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256): + super(KPHourglass, self).__init__() + + self.down_blocks = nn.Sequential() + for i in range(num_blocks): + self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + + in_filters = min(max_features, block_expansion * (2 ** num_blocks)) + self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1) + + self.up_blocks = nn.Sequential() + for i in range(num_blocks): + in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i))) + out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1))) + self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.reshape_depth = reshape_depth + self.out_filters = out_filters + + def forward(self, x): + out = self.down_blocks(x) + out = self.conv(out) + bs, c, h, w = out.shape + out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w) + out = self.up_blocks(out) + + return out + + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. 
+ kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] + + return out + + +class SPADE(nn.Module): + def __init__(self, norm_nc, label_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1), + nn.ReLU()) + self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1) + + def forward(self, x, segmap): + normalized = self.param_free_norm(x) + segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') + actv = self.mlp_shared(segmap) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + out = normalized * (1 + gamma) + beta + return out + + +class SPADEResnetBlock(nn.Module): + def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1): + super().__init__() + # Attributes + self.learned_shortcut = (fin != fout) + fmiddle = min(fin, fout) + self.use_se = use_se + # create conv layers + self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) + self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) + # apply spectral norm if specified + if 'spectral' in norm_G: + self.conv_0 = spectral_norm(self.conv_0) + self.conv_1 = spectral_norm(self.conv_1) + if self.learned_shortcut: + self.conv_s = spectral_norm(self.conv_s) + # define normalization layers + self.norm_0 = SPADE(fin, label_nc) + self.norm_1 = SPADE(fmiddle, label_nc) + if self.learned_shortcut: + self.norm_s = SPADE(fin, label_nc) + + def forward(self, x, seg1): + x_s = self.shortcut(x, seg1) + dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) + dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) + out = x_s + dx + return out + + def shortcut(self, x, seg1): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg1)) + else: + x_s = x + return x_s + + def actvn(self, x): + return F.leaky_relu(x, 2e-1) + +class audio2image(nn.Module): + def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params): + super().__init__() + # Attributes + self.generator = generator + self.kp_extractor = kp_extractor + self.he_estimator_video = he_estimator_video + self.he_estimator_audio = he_estimator_audio + self.train_params = train_params + + def headpose_pred_to_degree(self, pred): + device = pred.device + idx_tensor = [idx for idx in range(66)] + idx_tensor = torch.FloatTensor(idx_tensor).to(device) + pred = F.softmax(pred) + degree = torch.sum(pred*idx_tensor, 1) * 3 - 99 + + return degree + + def get_rotation_matrix(self, yaw, pitch, roll): + yaw = yaw / 180 * 3.14 + pitch = pitch / 180 * 3.14 + roll = roll / 180 * 3.14 + + roll = roll.unsqueeze(1) + pitch = pitch.unsqueeze(1) + yaw = yaw.unsqueeze(1) + + roll_mat = torch.cat([torch.ones_like(roll), 
torch.zeros_like(roll), torch.zeros_like(roll), + torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll), + torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1) + roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3) + + pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), + torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch), + -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1) + pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3) + + yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw), + torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw), + torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1) + yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3) + + rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat) + + return rot_mat + + def keypoint_transformation(self, kp_canonical, he): + kp = kp_canonical['value'] # (bs, k, 3) + yaw, pitch, roll = he['yaw'], he['pitch'], he['roll'] + t, exp = he['t'], he['exp'] + + yaw = self.headpose_pred_to_degree(yaw) + pitch = self.headpose_pred_to_degree(pitch) + roll = self.headpose_pred_to_degree(roll) + + rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3) + + # keypoint rotation + kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp) + + + + # keypoint translation + t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1) + kp_t = kp_rotated + t + + # add expression deviation + exp = exp.view(exp.shape[0], -1, 3) + kp_transformed = kp_t + exp + + return {'value': kp_transformed} + + def forward(self, source_image, target_audio): + pose_source = self.he_estimator_video(source_image) + pose_generated = self.he_estimator_audio(target_audio) + kp_canonical = self.kp_extractor(source_image) + kp_source = self.keypoint_transformation(kp_canonical, pose_source) + kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated) + generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated) + return generated \ No newline at end of file diff --git a/chat_anything/sad_talker/facerender/sync_batchnorm/__init__.py b/chat_anything/sad_talker/facerender/sync_batchnorm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf --- /dev/null +++ b/chat_anything/sad_talker/facerender/sync_batchnorm/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/chat_anything/sad_talker/facerender/sync_batchnorm/batchnorm.py b/chat_anything/sad_talker/facerender/sync_batchnorm/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4e763f0366dffa10320116413f8c7181a8aeb1 --- /dev/null +++ b/chat_anything/sad_talker/facerender/sync_batchnorm/batchnorm.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. 
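The `SynchronizedBatchNorm` layers in this file all reduce per-device sums and squared sums into a single set of batch statistics. For reference, the arithmetic performed by `_compute_mean_std` further down (unbiased variance for the running buffers, biased variance for normalizing the current batch) in a standalone sketch:
```
# Standalone sketch of the statistics math used by _SynchronizedBatchNorm below: from a
# (possibly cross-device) sum and squared sum, recover the mean, the unbiased variance
# kept in the running buffers, and the inverse std used to normalize the current batch.
import torch

def mean_inv_std(sum_, ssum, size, eps=1e-5):
    mean = sum_ / size
    sumvar = ssum - sum_ * mean              # sum of squared deviations from the mean
    unbias_var = sumvar / (size - 1)         # feeds the running_var update
    bias_var = sumvar / size                 # normalizes the current batch
    return mean, unbias_var, bias_var.clamp(min=eps) ** -0.5

x = torch.randn(32, 8)                       # 32 samples, 8 features
mean, run_var, inv_std = mean_inv_std(x.sum(0), (x ** 2).sum(0), x.shape[0])
```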
+ +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. 
+ # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. 
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/chat_anything/sad_talker/facerender/sync_batchnorm/comm.py b/chat_anything/sad_talker/facerender/sync_batchnorm/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..922f8c4a3adaa9b32fdcaef09583be03b0d7eb2b --- /dev/null +++ b/chat_anything/sad_talker/facerender/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' 
+ self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' 
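(Aside, not part of the diff.) A toy round trip showing the hand-shake `SyncMaster` implements: the master contributes its own message, blocks until every registered slave has pushed one, runs the callback on the collected list, and hands each slave its slice of the reply. All names and values below are illustrative.
```
from threading import Thread
from chat_anything.sad_talker.facerender.sync_batchnorm.comm import SyncMaster

def reduce_callback(intermediates):
    # intermediates is [(identifier, message), ...] with the master's entry first.
    total = sum(msg for _, msg in intermediates)
    return [(ident, total) for ident, _ in intermediates]   # one reply per device, master first

master = SyncMaster(reduce_callback)
pipe = master.register_slave(identifier=1)        # pretend single slave device

t = Thread(target=lambda: print('slave received', pipe.run_slave(2)))
t.start()
print('master received', master.run_master(1))    # both ends see the reduced value 3
t.join()
```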
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/chat_anything/sad_talker/facerender/sync_batchnorm/replicate.py b/chat_anything/sad_talker/facerender/sync_batchnorm/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/chat_anything/sad_talker/facerender/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/chat_anything/sad_talker/facerender/sync_batchnorm/unittest.py b/chat_anything/sad_talker/facerender/sync_batchnorm/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524 --- /dev/null +++ b/chat_anything/sad_talker/facerender/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/chat_anything/sad_talker/generate_batch.py b/chat_anything/sad_talker/generate_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..33944d6565b6963c70e2934b5270caf86fe1f52d --- /dev/null +++ b/chat_anything/sad_talker/generate_batch.py @@ -0,0 +1,120 @@ +import os + +from tqdm import tqdm +import torch +import numpy as np +import random +import scipy.io as scio +import chat_anything.sad_talker.utils.audio as audio + +def crop_pad_audio(wav, audio_length): + if len(wav) > audio_length: + wav = wav[:audio_length] + elif len(wav) < audio_length: + wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) + return wav + +def parse_audio_length(audio_length, sr, fps): + bit_per_frames = sr / fps + + num_frames = int(audio_length / bit_per_frames) + audio_length = int(num_frames * bit_per_frames) + + return audio_length, num_frames + +def generate_blink_seq(num_frames): + ratio = np.zeros((num_frames,1)) + frame_id = 0 + while frame_id in range(num_frames): + start = 80 + if frame_id+start+9<=num_frames - 1: + ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] + frame_id = frame_id+start+9 + else: + break + return ratio + +def generate_blink_seq_randomly(num_frames): + ratio = np.zeros((num_frames,1)) + if num_frames<=20: + return ratio + frame_id = 0 + while frame_id in range(num_frames): + start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) + if frame_id+start+5<=num_frames - 1: + ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] + frame_id = frame_id+start+5 + else: + break + return ratio + +def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, 
idlemode=False, length_of_audio=False, use_blink=True): + + syncnet_mel_step_size = 16 + fps = 25 + + pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] + audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] + + + if idlemode: + num_frames = int(length_of_audio * 25) + indiv_mels = np.zeros((num_frames, 80, 16)) + else: + wav = audio.load_wav(audio_path, 16000) + wav_length, num_frames = parse_audio_length(len(wav), 16000, 25) + wav = crop_pad_audio(wav, wav_length) + orig_mel = audio.melspectrogram(wav).T + spec = orig_mel.copy() # nframes 80 + indiv_mels = [] + + for i in tqdm(range(num_frames), 'mel:'): + start_frame_num = i-2 + start_idx = int(80. * (start_frame_num / float(fps))) + end_idx = start_idx + syncnet_mel_step_size + seq = list(range(start_idx, end_idx)) + seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] + m = spec[seq, :] + indiv_mels.append(m.T) + indiv_mels = np.asarray(indiv_mels) # T 80 16 + + ratio = generate_blink_seq_randomly(num_frames) # T + source_semantics_path = first_coeff_path + source_semantics_dict = scio.loadmat(source_semantics_path) + ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 + ref_coeff = np.repeat(ref_coeff, num_frames, axis=0) + + if ref_eyeblink_coeff_path is not None: + ratio[:num_frames] = 0 + refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path) + refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64] + refeyeblink_num_frames = refeyeblink_coeff.shape[0] + if refeyeblink_num_frames frame_num: + new_degree_list = new_degree_list[:frame_num] + elif len(new_degree_list) < frame_num: + for _ in range(frame_num-len(new_degree_list)): + new_degree_list.append(new_degree_list[-1]) + print(len(new_degree_list)) + print(frame_num) + + remainder = frame_num%batch_size + if remainder!=0: + for _ in range(batch_size-remainder): + new_degree_list.append(new_degree_list[-1]) + new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) + return new_degree_np + \ No newline at end of file diff --git a/chat_anything/sad_talker/sad_talker.py b/chat_anything/sad_talker/sad_talker.py new file mode 100644 index 0000000000000000000000000000000000000000..956a2bd03283aeefd232a5e0b83a938a95e66fdb --- /dev/null +++ b/chat_anything/sad_talker/sad_talker.py @@ -0,0 +1,175 @@ +import torch, uuid +import os, sys, shutil , pdb +from chat_anything.sad_talker.utils.preprocess import CropAndExtract +from chat_anything.sad_talker.test_audio2coeff import Audio2Coeff +from chat_anything.sad_talker.facerender.animate import AnimateFromCoeff +from chat_anything.sad_talker.generate_batch import get_data +from chat_anything.sad_talker.generate_facerender_batch import get_facerender_data + +from chat_anything.sad_talker.utils.init_path import init_path + +from pydub import AudioSegment + +def mp3_to_wav(mp3_filename,wav_filename,frame_rate): + mp3_file = AudioSegment.from_file(file=mp3_filename) + mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav") + + +class SadTalker(): + + def __init__(self, checkpoint_path='checkpoints', config_path='chat_anything/sad_talker/config', lazy_load=False, \ + preprocess='crop', size=256): + + if torch.cuda.is_available() : + device = "cuda" + else: + device = "cpu" + + self.device = device + + os.environ['TORCH_HOME']= checkpoint_path + + self.checkpoint_path = checkpoint_path + self.config_path = config_path + # script_path = os.path.abspath(__file__) + + # root_dir = os.path.dirname(script_path) + # print(root_dir) + # pdb.set_trace() + + # Model 
init + print('=============debugging here===============') + # pdb.set_trace() + self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess) + print(self.sadtalker_paths) + + self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device) + self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device) + self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device) + + + def test(self, source_image, driven_audio, preprocess='crop', + still_mode=False, use_enhancer=False, batch_size=4, size=256, + pose_style = 0, exp_scale=1.0, + use_ref_video = False, + ref_video = None, + ref_info = None, + use_idle_mode = False, + length_of_audio = 0, use_blink=True, uid=None): + + result_dir=os.path.join('./tmp/',uid) + time_tag = str(uuid.uuid4()) + save_dir = os.path.join(result_dir, time_tag) + os.makedirs(save_dir, exist_ok=True) + + input_dir = os.path.join(save_dir, 'input') + os.makedirs(input_dir, exist_ok=True) + + print(source_image) + pic_path = os.path.join(input_dir, os.path.basename(source_image)) + print("move--------------------------------------------") + shutil.copy(source_image, input_dir) + + if driven_audio is not None and os.path.isfile(driven_audio): + audio_path = os.path.join(input_dir, os.path.basename(driven_audio)) + + #### mp3 to wav + if '.mp3' in audio_path: + mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000) + audio_path = audio_path.replace('.mp3', '.wav') + else: + shutil.move(driven_audio, input_dir) + + elif use_idle_mode: + audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path + from pydub import AudioSegment + one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds + one_sec_segment.export(audio_path, format="wav") + else: + print(use_ref_video, ref_info) + assert use_ref_video == True and ref_info == 'all' + + if use_ref_video and ref_info == 'all': # full ref mode + ref_video_videoname = os.path.basename(ref_video) + audio_path = os.path.join(save_dir, ref_video_videoname+'.wav') + print('new audiopath:',audio_path) + # if ref_video contains audio, set the audio from ref_video. 
+ cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path) + os.system(cmd) + + os.makedirs(save_dir, exist_ok=True) + + #crop image and extract 3dmm from image + first_frame_dir = os.path.join(save_dir, 'first_frame_dir') + os.makedirs(first_frame_dir, exist_ok=True) + first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size) + + if first_coeff_path is None: + raise AttributeError("No face is detected") + + # TODO: Preprocess Image, for the init coefficient + if use_ref_video: + print('using ref video for genreation') + ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0] + ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname) + os.makedirs(ref_video_frame_dir, exist_ok=True) + print('3DMM Extraction for the reference video providing pose') + ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False) + else: + ref_video_coeff_path = None + + if use_ref_video: + if ref_info == 'pose': + ref_pose_coeff_path = ref_video_coeff_path + ref_eyeblink_coeff_path = None + elif ref_info == 'blink': + ref_pose_coeff_path = None + ref_eyeblink_coeff_path = ref_video_coeff_path + elif ref_info == 'pose+blink': + ref_pose_coeff_path = ref_video_coeff_path + ref_eyeblink_coeff_path = ref_video_coeff_path + elif ref_info == 'all': + ref_pose_coeff_path = None + ref_eyeblink_coeff_path = None + else: + raise('error in refinfo') + else: + ref_pose_coeff_path = None + ref_eyeblink_coeff_path = None + + #audio2ceoff + # TODO: generate sequence coefficient from audio & init_coefficient + if use_ref_video and ref_info == 'all': + coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) + else: + batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio? + coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path) + + #coeff2video + # TODO: generate the video from produced sequence coefficient. Produce the data for renderer. + data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale = exp_scale) + # print("here is a point-----------------------------------------") + # # print(data) + # print(save_dir) + # print(pic_path) + # TODO: render! 
+ return_value = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size) + # print(save_dir) + # print(pic_path) + video_name = data['video_name'] + # print(f'The generated video is named {video_name} in {save_dir}') + + # del self.preprocess_model + # del self.audio_to_coeff + # del self.animate_from_coeff + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + import gc; gc.collect() + + # return return_path + return return_value + + diff --git a/chat_anything/sad_talker/test_audio2coeff.py b/chat_anything/sad_talker/test_audio2coeff.py new file mode 100644 index 0000000000000000000000000000000000000000..dfde39825c8df0445c083987a5f8e22717d11ac8 --- /dev/null +++ b/chat_anything/sad_talker/test_audio2coeff.py @@ -0,0 +1,123 @@ +import os +import torch +import numpy as np +from scipy.io import savemat, loadmat +from yacs.config import CfgNode as CN +from scipy.signal import savgol_filter + +import safetensors +import safetensors.torch + +from chat_anything.sad_talker.audio2pose_models.audio2pose import Audio2Pose +from chat_anything.sad_talker.audio2exp_models.networks import SimpleWrapperV2 +from chat_anything.sad_talker.audio2exp_models.audio2exp import Audio2Exp +from chat_anything.sad_talker.utils.safetensor_helper import load_x_from_safetensor + +def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"): + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if model is not None: + model.load_state_dict(checkpoint['model']) + if optimizer is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + return checkpoint['epoch'] + +class Audio2Coeff(): + + def __init__(self, sadtalker_path, device): + #load config + fcfg_pose = open(sadtalker_path['audio2pose_yaml_path']) + cfg_pose = CN.load_cfg(fcfg_pose) + cfg_pose.freeze() + fcfg_exp = open(sadtalker_path['audio2exp_yaml_path']) + cfg_exp = CN.load_cfg(fcfg_exp) + cfg_exp.freeze() + + # load audio2pose_model + self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device) + self.audio2pose_model = self.audio2pose_model.to(device) + self.audio2pose_model.eval() + for param in self.audio2pose_model.parameters(): + param.requires_grad = False + + try: + if sadtalker_path['use_safetensor']: + checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) + self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose')) + else: + load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device) + except: + raise Exception("Failed in loading audio2pose_checkpoint") + + # load audio2exp_model + netG = SimpleWrapperV2() + netG = netG.to(device) + for param in netG.parameters(): + netG.requires_grad = False + netG.eval() + try: + if sadtalker_path['use_safetensor']: + checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint']) + netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp')) + else: + load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device) + except: + raise Exception("Failed in loading audio2exp_checkpoint") + self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False) + self.audio2exp_model = self.audio2exp_model.to(device) + for param in self.audio2exp_model.parameters(): + param.requires_grad = False + self.audio2exp_model.eval() + + self.device = device + + def generate(self, batch, coeff_save_dir, pose_style, 
ref_pose_coeff_path=None): + + with torch.no_grad(): + #test + results_dict_exp= self.audio2exp_model.test(batch) + exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64 + + #for class_id in range(1): + #class_id = 0#(i+10)%45 + #class_id = random.randint(0,46) #46 styles can be selected + batch['class'] = torch.LongTensor([pose_style]).to(self.device) + results_dict_pose = self.audio2pose_model.test(batch) + pose_pred = results_dict_pose['pose_pred'] #bs T 6 + + pose_len = pose_pred.shape[1] + if pose_len<13: + pose_len = int((pose_len-1)/2)*2+1 + pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device) + else: + pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) + + coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70 + + coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() + + if ref_pose_coeff_path is not None: + coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path) + + savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])), + {'coeff_3dmm': coeffs_pred_numpy}) + + return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])) + + def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path): + num_frames = coeffs_pred_numpy.shape[0] + refpose_coeff_dict = loadmat(ref_pose_coeff_path) + refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70] + refpose_num_frames = refpose_coeff.shape[0] + if refpose_num_frames= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/chat_anything/sad_talker/utils/croper.py b/chat_anything/sad_talker/utils/croper.py new file mode 100644 index 0000000000000000000000000000000000000000..22a323d544368cf0573ad084c90fc8fc40bc2b73 --- /dev/null +++ b/chat_anything/sad_talker/utils/croper.py @@ -0,0 +1,144 @@ +import os +import cv2 +import time +import glob +import argparse +import scipy +import numpy as np +from PIL import Image +import torch +from tqdm import tqdm +from itertools import cycle + +from chat_anything.sad_talker.face3d.extract_kp_videos_safe import KeypointExtractor +from facexlib.alignment import landmark_98_to_68 + +import numpy as np +from PIL import Image + +class Preprocesser: + def __init__(self, device='cuda'): + self.predictor = KeypointExtractor(device) + + def get_landmark(self, img_np): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + with torch.no_grad(): + dets = self.predictor.det_net.detect_faces(img_np, 0.97) + + if len(dets) == 0: + return None + det = dets[0] + + img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :] + lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0] + + #### keypoints to the original location + lm[:,0] += int(det[0]) + lm[:,1] += int(det[1]) + + 
return lm + + def align_face(self, img, lm, output_size=1024): + """ + :param filepath: str + :return: PIL Image + """ + lm_chin = lm[0: 17] # left-right + lm_eyebrow_left = lm[17: 22] # left-right + lm_eyebrow_right = lm[22: 27] # left-right + lm_nose = lm[27: 31] # top-down + lm_nostrils = lm[31: 36] # top-down + lm_eye_left = lm[36: 42] # left-clockwise + lm_eye_right = lm[42: 48] # left-clockwise + lm_mouth_outer = lm[48: 60] # left-clockwise + lm_mouth_inner = lm[60: 68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] # Addition of binocular difference and double mouth difference + x /= np.hypot(*x) # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化 + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) # 双眼差和眼嘴差,选较大的作为基准尺度 + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点 + qsize = np.hypot(*x) * 2 # 定义四边形的大小(边长),为基准尺度的2倍 + + # Shrink. + # 如果计算出的四边形太大了,就按比例缩小它 + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + else: + rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1])))) + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + # img = img.crop(crop) + quad -= crop[0:2] + + # Pad. + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + # if enable_padding and max(pad) > border - 4: + # pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + # img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # h, w, _ = img.shape + # y, x, _ = np.ogrid[:h, :w, :1] + # mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), + # 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) + # blur = qsize * 0.02 + # img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + # img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + # img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + # quad += pad[:2] + + # Transform. + quad = (quad + 0.5).flatten() + lx = max(min(quad[0], quad[2]), 0) + ly = max(min(quad[1], quad[7]), 0) + rx = min(max(quad[4], quad[6]), img.size[0]) + ry = min(max(quad[3], quad[5]), img.size[0]) + + # Save aligned image. 
+ return rsize, crop, [lx, ly, rx, ry] + + def crop(self, img_np_list, still=False, xsize=512): # first frame for all video + img_np = img_np_list[0] + lm = self.get_landmark(img_np) + + if lm is None: + raise 'can not detect the landmark from source image' + rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize) + clx, cly, crx, cry = crop + lx, ly, rx, ry = quad + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + for _i in range(len(img_np_list)): + _inp = img_np_list[_i] + _inp = cv2.resize(_inp, (rsize[0], rsize[1])) + _inp = _inp[cly:cry, clx:crx] + if not still: + _inp = _inp[ly:ry, lx:rx] + img_np_list[_i] = _inp + return img_np_list, crop, quad + diff --git a/chat_anything/sad_talker/utils/face_enhancer.py b/chat_anything/sad_talker/utils/face_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..12b926940f54657e9c9bf0a03db85ca221f1baa1 --- /dev/null +++ b/chat_anything/sad_talker/utils/face_enhancer.py @@ -0,0 +1,134 @@ +import os +import torch + +from gfpgan import GFPGANer + +from tqdm import tqdm + +from chat_anything.sad_talker.utils.videoio import load_video_to_cv2 + +import cv2 +import math +import concurrent.futures +class GeneratorWithLen(object): + """ From https://stackoverflow.com/a/7460929 """ + + def __init__(self, gen, length): + self.gen = gen + self.length = length + + def __len__(self): + return self.length + + def __iter__(self): + return self.gen + +def process_frame(img,restorer): + # restore faces and background if necessary + cropped_faces, restored_faces, r_img = restorer.module.enhance( + img, + has_aligned=False, + only_center_face=False, + paste_back=True) + + r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) + return r_img + +def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'): + gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) + return list(gen) + +def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'): + """ Provide a generator with a __len__ method so that it can passed to functions that + call len()""" + + if os.path.isfile(images): # handle video to images + # TODO: Create a generator version of load_video_to_cv2 + images = load_video_to_cv2(images) + + gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) + gen_with_len = GeneratorWithLen(gen, len(images)) + return gen_with_len + +def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'): + """ Provide a generator function so that all of the enhanced images don't need + to be stored in memory at the same time. This can save tons of RAM compared to + the enhancer function. 
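(Aside, not part of the diff.) The point of `GeneratorWithLen` / `enhancer_generator_with_len` is that callers which need `len()` (progress bars, preallocated writers) can still consume the enhanced frames lazily instead of materialising them all, as `enhancer_list` does. A consumption sketch; the input and output paths and the 25 fps value are placeholders:
```
import cv2
from tqdm import tqdm
from chat_anything.sad_talker.utils.face_enhancer import enhancer_generator_with_len

frames = enhancer_generator_with_len('crop.mp4', method='gfpgan', bg_upsampler='realesrgan')
writer = None
for rgb in tqdm(frames, 'Face Enhancer:'):   # len(frames) is known, so tqdm can show progress
    if writer is None:
        h, w = rgb.shape[:2]
        writer = cv2.VideoWriter('enhanced.avi', cv2.VideoWriter_fourcc(*'XVID'), 25, (w, h))
    writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))   # generator yields RGB, OpenCV writes BGR
if writer is not None:
    writer.release()
```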
""" + + print('face enhancer....') + if not isinstance(images, list) and os.path.isfile(images): # handle video to images + images = load_video_to_cv2(images) + + # ------------------------ set up GFPGAN restorer ------------------------ + if method == 'gfpgan': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANv1.4' + url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' + elif method == 'RestoreFormer': + arch = 'RestoreFormer' + channel_multiplier = 2 + model_name = 'RestoreFormer' + url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' + elif method == 'codeformer': # TODO: + arch = 'CodeFormer' + channel_multiplier = 2 + model_name = 'CodeFormer' + url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' + else: + raise ValueError(f'Wrong model version {method}.') + + + # ------------------------ set up background upsampler ------------------------ + if bg_upsampler == 'realesrgan': + if not torch.cuda.is_available(): # CPU + import warnings + warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. ' + 'If you really want to use it, please modify the corresponding codes.') + bg_upsampler = None + else: + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + bg_upsampler = RealESRGANer( + scale=2, + model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + model=model, + tile=400, + tile_pad=10, + pre_pad=0, + half=True) # need to set False in CPU mode + else: + bg_upsampler = None + + # determine model paths + model_path = os.path.join('gfpgan/weights', model_name + '.pth') + + + if not os.path.isfile(model_path): + model_path = os.path.join('checkpoints', model_name + '.pth') + + if not os.path.isfile(model_path): + # download pre-trained models from url + model_path = url + + restorer = GFPGANer( + model_path=model_path, + upscale=2, + arch=arch, + channel_multiplier=channel_multiplier, + bg_upsampler=bg_upsampler) + + + for idx in tqdm(range(len(images)), 'Face Enhancer:'): + + img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) + + cropped_faces, restored_faces, r_img = restorer.enhance( + img, + has_aligned=False, + only_center_face=False, + paste_back=True) + + r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) + yield r_img diff --git a/chat_anything/sad_talker/utils/hparams.py b/chat_anything/sad_talker/utils/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..743c5c7d5a5a9e686f1ccd6fb3c2fb5cb382d62b --- /dev/null +++ b/chat_anything/sad_talker/utils/hparams.py @@ -0,0 +1,160 @@ +from glob import glob +import os + +class HParams: + def __init__(self, **kwargs): + self.data = {} + + for key, value in kwargs.items(): + self.data[key] = value + + def __getattr__(self, key): + if key not in self.data: + raise AttributeError("'HParams' object has no attribute %s" % key) + return self.data[key] + + def set_hparam(self, key, value): + self.data[key] = value + + +# Default hyperparameters +hparams = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with 
https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + img_size=96, + fps=25, + + batch_size=16, + initial_learning_rate=1e-4, + nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=20, + checkpoint_interval=3000, + eval_interval=3000, + writer_interval=300, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. + syncnet_batch_size=64, + syncnet_lr=1e-4, + syncnet_eval_interval=1000, + syncnet_checkpoint_interval=10000, + + disc_wt=0.07, + disc_initial_learning_rate=1e-4, +) + + + +# Default hyperparameters +hparamsdebug = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. 
(Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + img_size=96, + fps=25, + + batch_size=2, + initial_learning_rate=1e-3, + nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=0, + checkpoint_interval=10000, + eval_interval=10, + writer_interval=5, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. + syncnet_batch_size=64, + syncnet_lr=1e-4, + syncnet_eval_interval=10000, + syncnet_checkpoint_interval=10000, + + disc_wt=0.07, + disc_initial_learning_rate=1e-4, +) + + +def hparams_debug_string(): + values = hparams.values() + hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] + return "Hyperparameters:\n" + "\n".join(hp) diff --git a/chat_anything/sad_talker/utils/init_path.py b/chat_anything/sad_talker/utils/init_path.py new file mode 100644 index 0000000000000000000000000000000000000000..5f38d11907bd0dc789992062ce7f02d8876c638f --- /dev/null +++ b/chat_anything/sad_talker/utils/init_path.py @@ -0,0 +1,47 @@ +import os +import glob + +def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'): + + if old_version: + #### load all the checkpoint of `pth` + sadtalker_paths = { + 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), + 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), + 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), + 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), + 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') + } + + use_safetensor = False + elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))): + print('using safetensor as default') + sadtalker_paths = { + "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'), + } + use_safetensor = True + else: + print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. 
We run the old version of the checkpoint this time!") + use_safetensor = False + + sadtalker_paths = { + 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), + 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), + 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), + 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), + 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') + } + + sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' + sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') + sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') + sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') + + if 'full' in preprocess: + sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') + sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') + else: + sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') + sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') + + return sadtalker_paths \ No newline at end of file diff --git a/chat_anything/sad_talker/utils/model2safetensor.py b/chat_anything/sad_talker/utils/model2safetensor.py new file mode 100644 index 0000000000000000000000000000000000000000..4751f59ed8b65f3ba9b87483499e5ad57d2c1798 --- /dev/null +++ b/chat_anything/sad_talker/utils/model2safetensor.py @@ -0,0 +1,141 @@ +import torch +import yaml +import os + +import safetensors +from safetensors.torch import save_file +from yacs.config import CfgNode as CN +import sys + +sys.path.append('/apdcephfs/private_shadowcun/SadTalker') + +from chat_anything.sad_talker.face3d.models import networks + +from chat_anything.sad_talker.facerender.modules.keypoint_detector import HEEstimator, KPDetector +from chat_anything.sad_talker.facerender.modules.mapping import MappingNet +from chat_anything.sad_talker.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator + +from chat_anything.sad_talker.audio2pose_models.audio2pose import Audio2Pose +from chat_anything.sad_talker.audio2exp_models.networks import SimpleWrapperV2 +from chat_anything.sad_talker.test_audio2coeff import load_cpk + +size = 256 +############ face vid2vid +config_path = os.path.join('src', 'config', 'facerender.yaml') +current_root_path = '.' 
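(Aside, not part of the diff.) The script below saves one flat `state_dict` in which every key is prefixed by its sub-module name (`kp_extractor.`, `generator.`, `audio2exp.`, `audio2pose.`, `face_3drecon.`); that prefix is exactly what `load_x_from_safetensor` strips when `test_audio2coeff.py` and `preprocess.py` pull a single sub-module back out of the combined checkpoint. A toy illustration with invented keys:
```
import torch
from chat_anything.sad_talker.utils.safetensor_helper import load_x_from_safetensor

# Invented keys, mimicking the prefixes produced by the combined SadTalker module below.
flat = {
    'audio2exp.netG.weight': torch.zeros(1),
    'audio2pose.netG.weight': torch.zeros(1),
    'generator.first.conv.weight': torch.zeros(1),
}
print(load_x_from_safetensor(flat, 'audio2exp'))   # -> {'netG.weight': tensor([0.])}
```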
+ +path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth') +net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='') +checkpoint = torch.load(path_of_net_recon_model, map_location='cpu') +net_recon.load_state_dict(checkpoint['net_recon']) + +with open(config_path) as f: + config = yaml.safe_load(f) + +generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'], + **config['model_params']['common_params']) +kp_extractor = KPDetector(**config['model_params']['kp_detector_params'], + **config['model_params']['common_params']) +he_estimator = HEEstimator(**config['model_params']['he_estimator_params'], + **config['model_params']['common_params']) +mapping = MappingNet(**config['model_params']['mapping_params']) + +def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None, + kp_detector=None, he_estimator=None, optimizer_generator=None, + optimizer_discriminator=None, optimizer_kp_detector=None, + optimizer_he_estimator=None, device="cpu"): + + checkpoint = torch.load(checkpoint_path, map_location=torch.device(device)) + if generator is not None: + generator.load_state_dict(checkpoint['generator']) + if kp_detector is not None: + kp_detector.load_state_dict(checkpoint['kp_detector']) + if he_estimator is not None: + he_estimator.load_state_dict(checkpoint['he_estimator']) + if discriminator is not None: + try: + discriminator.load_state_dict(checkpoint['discriminator']) + except: + print ('No discriminator in the state-dict. Dicriminator will be randomly initialized') + if optimizer_generator is not None: + optimizer_generator.load_state_dict(checkpoint['optimizer_generator']) + if optimizer_discriminator is not None: + try: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) + except RuntimeError as e: + print ('No discriminator optimizer in the state-dict. 
Optimizer will be not initialized') + if optimizer_kp_detector is not None: + optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector']) + if optimizer_he_estimator is not None: + optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator']) + + return checkpoint['epoch'] + + +def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None, + kp_detector=None, he_estimator=None, + device="cpu"): + + checkpoint = safetensors.torch.load_file(checkpoint_path) + + if generator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'generator' in k: + x_generator[k.replace('generator.', '')] = v + generator.load_state_dict(x_generator) + if kp_detector is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'kp_extractor' in k: + x_generator[k.replace('kp_extractor.', '')] = v + kp_detector.load_state_dict(x_generator) + if he_estimator is not None: + x_generator = {} + for k,v in checkpoint.items(): + if 'he_estimator' in k: + x_generator[k.replace('he_estimator.', '')] = v + he_estimator.load_state_dict(x_generator) + + return None + +free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar' +load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator) + +wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth') + +audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth') +audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml') + +audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth') +audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml') + +fcfg_pose = open(audio2pose_yaml_path) +cfg_pose = CN.load_cfg(fcfg_pose) +cfg_pose.freeze() +audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint) +audio2pose_model.eval() +load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu') + +# load audio2exp_model +netG = SimpleWrapperV2() +netG.eval() +load_cpk(audio2exp_checkpoint, model=netG, device='cpu') + +class SadTalker(torch.nn.Module): + def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon): + super(SadTalker, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.audio2exp = netG + self.audio2pose = audio2pose + self.face_3drecon = face_3drecon + + +model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon) + +# here, we want to convert it to safetensor +save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors") + +### test +load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None) \ No newline at end of file diff --git a/chat_anything/sad_talker/utils/paste_pic.py b/chat_anything/sad_talker/utils/paste_pic.py new file mode 100644 index 0000000000000000000000000000000000000000..caa9d85da712a5f851c38f42a4c8e5bd77e37b25 --- /dev/null +++ b/chat_anything/sad_talker/utils/paste_pic.py @@ -0,0 +1,70 @@ +import cv2, os +import numpy as np +from tqdm import tqdm +import uuid + +from chat_anything.sad_talker.utils.videoio import save_video_with_watermark + +def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False): + + if not os.path.isfile(pic_path): + raise ValueError('pic_path must be a valid 
path to video/image file') + elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: + # loader for first frame + full_img = cv2.imread(pic_path) + else: + # loader for videos + video_stream = cv2.VideoCapture(pic_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + break + full_img = frame + frame_h = full_img.shape[0] + frame_w = full_img.shape[1] + + video_stream = cv2.VideoCapture(video_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + print(f"fps:{fps}") + crop_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + crop_frames.append(frame) + + if len(crop_info) != 3: + print("you didn't crop the image") + return + else: + r_w, r_h = crop_info[0] + clx, cly, crx, cry = crop_info[1] + lx, ly, rx, ry = crop_info[2] + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + + if extended_crop: + oy1, oy2, ox1, ox2 = cly, cry, clx, crx + else: + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + dirname = os.path.dirname(new_audio_path) + tmp_path = os.path.join(dirname, str(uuid.uuid4())+'.mp4') + out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (frame_w, frame_h)) + for crop_frame in tqdm(crop_frames, 'seamlessClone:'): + p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) + + mask = 255*np.ones(p.shape, p.dtype) + location = ((ox1+ox2) // 2, (oy1+oy2) // 2) + gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE) + out_tmp.write(gen_img) + + out_tmp.release() + print("paste_pic==================tmp_path") + print(tmp_path) + save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False) diff --git a/chat_anything/sad_talker/utils/preprocess.py b/chat_anything/sad_talker/utils/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..22a8a342cc3533b3c9c63095af8d9029904aa4fc --- /dev/null +++ b/chat_anything/sad_talker/utils/preprocess.py @@ -0,0 +1,175 @@ +import numpy as np +import cv2, os, sys, torch +from tqdm import tqdm +from PIL import Image +import time +# 3dmm extraction +import safetensors +import safetensors.torch +from chat_anything.sad_talker.face3d.util.preprocess import align_img +from chat_anything.sad_talker.face3d.util.load_mats import load_lm3d +from chat_anything.sad_talker.face3d.models import networks + +from scipy.io import loadmat, savemat +from chat_anything.sad_talker.utils.croper import Preprocesser + + +import warnings + +from chat_anything.sad_talker.utils.safetensor_helper import load_x_from_safetensor +warnings.filterwarnings("ignore") + +def split_coeff(coeffs): + """ + Return: + coeffs_dict -- a dict of torch.tensors + + Parameters: + coeffs -- torch.tensor, size (B, 256) + """ + id_coeffs = coeffs[:, :80] + exp_coeffs = coeffs[:, 80: 144] + tex_coeffs = coeffs[:, 144: 224] + angles = coeffs[:, 224: 227] + gammas = coeffs[:, 227: 254] + translations = coeffs[:, 254:] + return { + 'id': id_coeffs, + 'exp': exp_coeffs, + 'tex': tex_coeffs, + 'angle': angles, + 'gamma': gammas, + 'trans': translations + } + + +class CropAndExtract(): + def __init__(self, sadtalker_path, device): + + self.propress = Preprocesser(device) + self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device) + + if 
sadtalker_path['use_safetensor']: + checkpoint = safetensors.torch.load_file(sadtalker_path['checkpoint']) + self.net_recon.load_state_dict(load_x_from_safetensor(checkpoint, 'face_3drecon')) + else: + checkpoint = torch.load(sadtalker_path['path_of_net_recon_model'], map_location=torch.device(device)) + self.net_recon.load_state_dict(checkpoint['net_recon']) + + self.net_recon.eval() + self.lm3d_std = load_lm3d(sadtalker_path['dir_of_BFM_fitting']) + self.device = device + + def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False, pic_size=256): + + pic_name = os.path.splitext(os.path.split(input_path)[-1])[0] + + landmarks_path = os.path.join(save_dir, pic_name+'_landmarks.txt') + coeff_path = os.path.join(save_dir, pic_name+'.mat') + png_path = os.path.join(save_dir, pic_name+'.png') + + #load input + if not os.path.isfile(input_path): + raise ValueError('input_path must be a valid path to video/image file') + elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: + # loader for first frame + full_frames = [cv2.imread(input_path)] + fps = 25 + else: + # loader for videos + video_stream = cv2.VideoCapture(input_path) + fps = video_stream.get(cv2.CAP_PROP_FPS) + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + full_frames.append(frame) + if source_image_flag: + break + + x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in full_frames] + + #### crop images as the + if 'crop' in crop_or_resize.lower(): # default crop + x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) + clx, cly, crx, cry = crop + lx, ly, rx, ry = quad + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) + elif 'full' in crop_or_resize.lower(): + x_full_frames, crop, quad = self.propress.crop(x_full_frames, still=True if 'ext' in crop_or_resize.lower() else False, xsize=512) + clx, cly, crx, cry = crop + lx, ly, rx, ry = quad + lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) + oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx + crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad) + else: # resize mode + oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] + crop_info = ((ox2 - ox1, oy2 - oy1), None, None) + + frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames] + if len(frames_pil) == 0: + print('No face is detected in the input file') + return None, None + + # save crop info + for frame in frames_pil: + cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) + + # 2. get the landmark according to the detected face. + if not os.path.isfile(landmarks_path): + lm = self.propress.predictor.extract_keypoint(frames_pil, landmarks_path) + else: + print(' Using saved landmarks.') + lm = np.loadtxt(landmarks_path).astype(np.float32) + lm = lm.reshape([len(x_full_frames), -1, 2]) + print(len(frames_pil)) + print(frames_pil[0].size) + if not os.path.isfile(coeff_path): + # load 3dmm paramter generator from Deep3DFaceRecon_pytorch + video_coeffs, full_coeffs = [], [] + for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'): + frame = frames_pil[idx] + W,H = frame.size + lm1 = lm[idx].reshape([-1, 2]) + + if np.mean(lm1) == -1: + lm1 = (self.lm3d_std[:, :2]+1)/2. 
+                    lm1 = np.concatenate(
+                        [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+                    )
+                else:
+                    lm1[:, -1] = H - 1 - lm1[:, -1]
+
+                trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+
+                trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
+                im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
+                start_time = time.time()
+                with torch.no_grad():
+                    full_coeff = self.net_recon(im_t)
+                    coeffs = split_coeff(full_coeff)
+                    print(type(coeffs))
+
+                end_time = time.time()
+                ext_time = end_time - start_time
+                print("3DMM extraction time: %.4f s" % ext_time)
+                pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+
+                pred_coeff = np.concatenate([
+                    pred_coeff['exp'],
+                    pred_coeff['angle'],
+                    pred_coeff['trans'],
+                    trans_params[2:][None],
+                ], 1)
+                video_coeffs.append(pred_coeff)
+                full_coeffs.append(full_coeff.cpu().numpy())
+
+            semantic_npy = np.array(video_coeffs)[:,0]
+
+            savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]})
+
+        return coeff_path, png_path, crop_info
diff --git a/chat_anything/sad_talker/utils/safetensor_helper.py b/chat_anything/sad_talker/utils/safetensor_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cdbdd21e4ed656dfe2d31a57360afb3e96480b3
--- /dev/null
+++ b/chat_anything/sad_talker/utils/safetensor_helper.py
@@ -0,0 +1,8 @@
+
+
+def load_x_from_safetensor(checkpoint, key):
+    x_generator = {}
+    for k,v in checkpoint.items():
+        if key in k:
+            x_generator[k.replace(key+'.', '')] = v
+    return x_generator
\ No newline at end of file
diff --git a/chat_anything/sad_talker/utils/text2speech.py b/chat_anything/sad_talker/utils/text2speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..00d165b6cc7774fd200929aafa0ff3b15916111e
--- /dev/null
+++ b/chat_anything/sad_talker/utils/text2speech.py
@@ -0,0 +1,20 @@
+import os
+import tempfile
+from TTS.api import TTS
+
+
+class TTSTalker():
+    def __init__(self) -> None:
+        model_name = TTS.list_models()[0]
+        self.tts = TTS(model_name)
+
+    def test(self, text, language='en'):
+
+        tempf = tempfile.NamedTemporaryFile(
+            delete = False,
+            suffix = ('.'+'wav'),
+        )
+
+        self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name)
+
+        return tempf.name
\ No newline at end of file
diff --git a/chat_anything/sad_talker/utils/videoio.py b/chat_anything/sad_talker/utils/videoio.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8421884a3447031a0ebdf916249da3018e1c109
--- /dev/null
+++ b/chat_anything/sad_talker/utils/videoio.py
@@ -0,0 +1,49 @@
+import shutil
+import uuid
+
+import os
+
+import cv2
+
+def load_video_to_cv2(input_path):
+    video_stream = cv2.VideoCapture(input_path)
+    fps = video_stream.get(cv2.CAP_PROP_FPS)
+    full_frames = []
+    while 1:
+        still_reading, frame = video_stream.read()
+        if not still_reading:
+            video_stream.release()
+            break
+        full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+    return full_frames
+
+
+def save_video(video, audio, save_path):
+    cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, save_path)
+    os.system(cmd)
+    return save_path
+
+def save_video_with_watermark(video, audio, save_path, watermark=False):
+    temp_file = str(uuid.uuid4())+'.mp4'
+    print(temp_file)
+    cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
+    os.system(cmd)
+    shutil.move(temp_file, save_path)
+
+    # if watermark is False:
+    #     shutil.move(temp_file, save_path)
+    # else:
+    #     # watermark
+    #     try:
+    #         ##### check if stable-diffusion-webui
+    #         import webui
+    #         from modules import paths
+    #         watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
+    #     except:
+    #         # get the root path of sadtalker.
+    #         dir_path = os.path.dirname(os.path.realpath(__file__))
+    #         watarmark_path = dir_path+"/../../docs/sadtalker_logo.png"
+
+    #     cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
+    #     os.system(cmd)
+    #     os.remove(temp_file)
\ No newline at end of file
diff --git a/chat_anything/train/models/controlnet.py b/chat_anything/train/models/controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1117d5d119f75132efa136232f8a322470b12453
--- /dev/null
+++ b/chat_anything/train/models/controlnet.py
@@ -0,0 +1,822 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalControlnetMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_2d_blocks import (
+    CrossAttnDownBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    get_down_block,
+)
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+    """
+    The output of [`ControlNetModel`].
+
+    Args:
+        down_block_res_samples (`tuple[torch.Tensor]`):
+            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+            used to condition the original UNet's downsampling activations.
+        mid_block_res_sample (`torch.Tensor`):
+            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+            Output can be used to condition the original UNet's middle block activation.
+ """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ControlNetConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. 
+ global_pool_conditions (`bool`, defaults to `False`): + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads=64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = 
nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + 
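+            # copy matching weights from the pretrained UNet (conv_in, time embeddings, class embedding, down/mid blocks);
+            # the zero-initialized controlnet_down_blocks / controlnet_mid_block projections are left at zero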
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. 
+ + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. 
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
+ t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if "addition_embed_type" in self.config: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + sample = self.conv_in(sample) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + controlnet_cond + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. 
scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/chat_anything/train/train_controlnet_sdxl.py b/chat_anything/train/train_controlnet_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..cf532580277dd042811ce94c18e2f9bb697d2f3a --- /dev/null +++ b/chat_anything/train/train_controlnet_sdxl.py @@ -0,0 +1,1247 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDPMScheduler, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.20.0.dev0") + +logger = get_logger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step): + logger.info("Running validation... ") + + controlnet = accelerator.unwrap_model(controlnet) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline( + prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator + ).images[0] + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def 
import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: openrail++ +base_model: {base_model} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." 
+ ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. 
More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." 
+ ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="sd_xl_train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def get_train_dataset(args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[args.image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = 
torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + + add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples]) + add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "prompt_ids": prompt_ids, + "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, + } + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + 
) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from unet") + controlnet = ControlNetModel.from_unet(unet) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + controlnet.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = controlnet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True): + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + prompt_batch = batch[args.caption_column] + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + train_dataset = get_train_dataset(args, accelerator) + compute_embeddings_fn = functools.partial( + compute_embeddings, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + ) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) + + del text_encoders, tokenizers + gc.collect() + torch.cuda.empty_cache() + + # Then get the training dataset ready to be passed to the dataloader. 
+ train_dataset = prepare_train_dataset(train_dataset, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + else: + pixel_values = batch["pixel_values"] + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # ControlNet conditioning. + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in 
checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + vae, unet, controlnet, args, accelerator, weight_dtype, global_step + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + controlnet.save_pretrained(args.output_dir) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/chat_anything/tts_talker/__init__.py b/chat_anything/tts_talker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/chat_anything/tts_talker/tts_edge.py b/chat_anything/tts_talker/tts_edge.py new file mode 100644 index 0000000000000000000000000000000000000000..6d98af446b5994b114b9c259ca93103608c1d801 --- /dev/null +++ b/chat_anything/tts_talker/tts_edge.py @@ -0,0 +1,77 @@ +import random +import shutil +import os + +import asyncio +import random + +import edge_tts +from edge_tts import VoicesManager +import uuid +import shutil +# # How to use this api +# #!/usr/bin/env python3 +# """ +# Example of dynamic voice selection using VoicesManager. +# """ +# import asyncio +# import random +# import edge_tts +# from edge_tts import VoicesManager +# TEXT = "Hoy es un buen día." 
+# OUTPUT_FILE = "spanish.mp3"
+# async def amain() -> None:
+#     """Main function"""
+#     voices = await VoicesManager.create()
+#     voice = voices.find(Gender="Male", Language="es")
+#     # Also supports Locales
+#     # voice = voices.find(Gender="Female", Locale="es-AR")
+#     communicate = edge_tts.Communicate(TEXT, random.choice(voice)["Name"])
+#     await communicate.save(OUTPUT_FILE)
+# if __name__ == "__main__":
+#     loop = asyncio.get_event_loop_policy().get_event_loop()
+#     try:
+#         loop.run_until_complete(amain())
+#     finally:
+#         loop.close()
+
+
+
+
+class TTSTalker():
+    def __init__(self, selected_voice, gender, language) -> None:
+        self.selected_voice = selected_voice
+        self.gender = gender
+        self.language = language
+        self.voice = asyncio.run(self.get_voice(gender, language))
+
+    async def get_voice(self, gender, language):
+        # pick a random edge-tts voice matching the requested gender and language
+        voices = await VoicesManager.create()
+        voices = voices.find(Gender=gender, Language=language)
+        voice = random.choice(voices)["Name"]
+        return voice
+
+    async def amain(self, text, file, voice) -> None:
+        """Synthesize `text` with the selected edge-tts voice and save it to `file`."""
+        # Locales are also supported, e.g.
+        # voice = voices.find(Gender="Female", Locale="es-AR")
+        communicate = edge_tts.Communicate(text, voice)
+        await communicate.save(file)
+
+
+    def test(self, text, audio_path=None):
+        # make sure the output directory exists (also handles nested paths)
+        os.makedirs(audio_path, exist_ok=True)
+        voice_uuid = str(uuid.uuid4())[:5] + '.wav'  # note: edge-tts saves MP3-encoded audio despite the .wav suffix
+        audio_file = os.path.join(audio_path, voice_uuid)
+        asyncio.run(self.amain(text, audio_file, self.voice))
+        return audio_file
+
+
+if __name__ == "__main__":
+    audio_dir = 'test'
+    TTSTalker('', 'Male', 'en').test('hello', audio_dir)
+    TTSTalker('', 'Male', 'zh').test('hello', audio_dir)
+    TTSTalker('', 'Female', 'en').test('hello', audio_dir)
+    TTSTalker('', 'Female', 'zh').test('hello', audio_dir)
\ No newline at end of file
diff --git a/chat_anything/tts_talker/tts_voicechanger.py b/chat_anything/tts_talker/tts_voicechanger.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e032ec9a8b485c26a08ff771c6d1cec4e42ed35
--- /dev/null
+++ b/chat_anything/tts_talker/tts_voicechanger.py
@@ -0,0 +1,78 @@
+import random
+import shutil
+import os
+from gradio_client import Client
+client = Client("http://127.0.0.1:7860/")
+# How to use this new TTS Client?
I leave the gradio api demo page as a reference +# client = Client("http://127.0.0.1:7860/") +# result = client.predict( +# "Howdy!", # str in '请填写您想要转换的文本(中英皆可)' Textbox component +# "Bilibili - 一清清清,Bilibili - 一清清清", # str (Option from: [('Bilibili - 一清清清', 'Bilibili - 一清清清'), ('ALL - Bob Sponge', 'ALL - Bob Sponge'), ('ALL - Ariana Grande', 'ALL - Ariana Grande'), ('ALL - Stefanie Sun', 'ALL - Stefanie Sun')]) +# in '请选择您的AI歌手(必选)' Dropdown component +# "Microsoft Adri Online (Natural) - Afrikaans (South Africa) (Female),Microsoft Adri Online (Natural) - Afrikaans (South Africa) (Female)", # str (Option from: [('Microsoft Adri Online (Natural) - Afrikaans (South Africa) (Female)', 'Microsoft Adri Online (Natural) - Afrikaans (South Africa) (Female)'), ('Microsoft Willem Online (Natural) - Afrikaans (South Africa) (Male)', 'Microsoft Willem Online (Natural) - Afrikaans (South Africa) (Male)'), ('Microsoft Anila Online (Natural) - Albanian (Albania) (Female)', 'Microsoft Anila Online (Natural) - Albanian (Albania) (Female)'), ('Microsoft Ilir Online (Natural) - Albanian (Albania) (Male)', 'Microsoft Ilir Online (Natural) - Albanian (Albania) (Male)'), ('Microsoft Ameha Online (Natural) - Amharic (Ethiopia) (Male)', 'Microsoft Ameha Online (Natural) - Amharic (Ethiopia) (Male)'), ('Microsoft Mekdes Online (Natural) - Amharic (Ethiopia) (Female)', 'Microsoft Mekdes Online (Natural) - Amharic (Ethiopia) (Female)'), +# ('Microsoft Amina Online (Natural) - Arabic (Algeria) (Female)', 'Microsoft Amina Online (Natural) - Arabic (Algeria) (Female)'), ('Microsoft Ismael Online (Natural) - Arabic (Algeria) (Male)', 'Microsoft Ismael Online (Natural) - Arabic (Algeria) (Male)'), ('Microsoft Ali Online (Natural) - Arabic (Bahrain) (Male)', 'Microsoft Ali Online (Natural) - Arabic (Bahrain) (Male)'), ('Microsoft Laila Online (Natural) - Arabic (Bahrain) (Female)', 'Microsoft Laila Online (Natural) - Arabic (Bahrain) (Female)'), ('Microsoft Salma Online (Natural) - Arabic (Egypt) (Female)', 'Microsoft Salma Online (Natural) - Arabic (Egypt) (Female)'), ('Microsoft Shakir Online (Natural) - Arabic (Egypt) (Male)', 'Microsoft Shakir Online (Natural) - Arabic (Egypt) (Male)'), ...) 
+# +# in '请选择一个相应语言的说话人' Dropdown component +# -24, # int | float (numeric value between -24 and 24) +# in 'Pitch' Slider component +# "pm", # str in 'f0 methods' Radio component +# 0, # int | float (numeric value between 0 and 1) +# in 'Feature ratio' Slider component +# 0, # int | float (numeric value between 0 and 7) +# in 'Filter radius' Slider component +# 0, # int | float (numeric value between 0 and 1) +# in 'Volume envelope mix rate' Slider component +# "Disable resampling,Disable resampling", # str (Option from: [('Disable resampling', 'Disable resampling'), ('16000', '16000'), ('22050', '22050'), ('44100', '44100'), ('48000', '48000')]) +# in 'Resample rate' Dropdown component +# api_name="/tts_conversion" +# ) +# print(result) + +TTS_MODELS = { + "male":{ + "Chinese": "Microsoft Yunyang Online (Natural) - Chinese (Mainland) (Male)", + "English": "Microsoft Eric Online (Natural) - English (United States) (Male)", + "Japanese": "Microsoft Keita Online (Natural) - Japanese (Japan) (Male)", + }, + "female":{ + "Chinese": "Microsoft Xiaoyi Online (Natural) - Chinese (Mainland) (Female)", + "English": "Microsoft Ana Online (Natural) - English (United States) (Female)", + "Japanese": "Microsoft Nanami Online (Natural) - Japanese (Japan) (Female)", + } +} + + +class TTSTalker(): + def __init__(self,selected_voice, gender, language) -> None: + self.selected_voice = selected_voice + self.gender = gender + self.language = language + + def test(self, text, audio_path=None): + self.gender = random.choice(['male', 'female']) if self.gender not in TTS_MODELS else self.gender + languages = TTS_MODELS[self.gender].keys() + self.language = random.choice(languages) if self.language not in languages else self.language + tts_model = TTS_MODELS[self.gender][self.language] + result = client.predict( + text, # str in '请填写您想要转换的文本(中英皆可)' Textbox component + self.selected_voice, # str (Option from: [('Bilibili - 一清清清', 'Bilibili - 一清清清'), ('ALL - Bob Sponge', 'ALL - Bob Sponge'), ('ALL - Ariana Grande', 'ALL - Ariana Grande'), ('ALL - Stefanie Sun', 'ALL - Stefanie Sun')]) in '请选择您的AI歌手(必选)' Dropdown component + tts_model, # in '请选择一个相应语言的说话人' Dropdown component + 0, # int | float (numeric value between -24 and 24) in 'Pitch' Slider component + "pm", # str in 'f0 methods' Radio component + 0, # int | float (numeric value between 0 and 1) in 'Feature ratio' Slider component + 0, # int | float (numeric value between 0 and 7) in 'Filter radius' Slider component + 0, # int | float (numeric value between 0 and 1) in 'Volume envelope mix rate' Slider component + "Disable resampling", # str (Option from: [('Disable resampling', 'Disable resampling'), ('16000','16000'), ('22050', '22050'), ('44100', '44100'), ('48000', '48000')]) in 'Resample rate' Dropdown component + api_name="/tts_conversion" + ) + print(result[1]) + print(result) + if result[1] == 'Success': + if not os.path.exists(audio_path): + os.makedirs(audio_path) + output_path = os.path.join(audio_path, 'tempfile.mp3') + print(output_path) + shutil.copy(result[0], output_path) + return output_path + else: + raise ValueError("failed with SVC") diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c502142fffe919fc54551338638d20a9886f1d87 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,44 @@ +name: chatanything +channels: + - menpo + - conda-forge + - pytorch + - defaults +dependencies: + - python==3.8.10 + - pip + - pip: + - --extra-index-url https://download.pytorch.org/whl/cu117 + - 
torch==2.0.1+cu117 + - torchvision==0.15.2+cu117 + - torchaudio==2.0.2 + - librosa + - gfpgan + - facexlib==0.3.0 + - face_alignment==1.3.5 + - yacs==0.1.8 + - numba + - openai-whisper + - kornia + - diffusers + - transformers + - dlib + - huggingface-hub + - pydantic==1.10.9 + - langchain==0.0.310 + - gradio==3.41.0 + - gradio-client==0.5.0 + - omegaconf + - openai + - opencv-python + - imageio-ffmpeg + # for voice-changer + - moviepy + - edge-tts + - fairseq + - praat-parselmouth + - pyworld + - faiss-cpu + - accelerate + - ffmpeg + - ffmpeg-python \ No newline at end of file diff --git a/python_scripts/convert_original_controlnet_to_diffusers.py b/python_scripts/convert_original_controlnet_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..600d943d1bcb0b4bb2b4ffcd93dc9d36562b9808 --- /dev/null +++ b/python_scripts/convert_original_controlnet_to_diffusers.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for stable diffusion checkpoints which _only_ contain a contrlnet. """ + +import argparse + +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--original_config_file", + type=str, + required=True, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--image_size", + default=512, + type=int, + help=( + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" + " Base. Use 768 for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--upcast_attention", + action="store_true", + help=( + "Whether the attention computation should always be upcasted. This is necessary when running stable" + " diffusion 2.1." 
+ ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + # small workaround to get argparser to parse a boolean input as either true _or_ false + def parse_bool(string): + if string == "True": + return True + elif string == "False": + return False + else: + raise ValueError(f"could not parse string as bool {string}") + + parser.add_argument( + "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool + ) + + parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) + + args = parser.parse_args() + + controlnet = download_controlnet_from_original_ckpt( + checkpoint_path=args.checkpoint_path, + original_config_file=args.original_config_file, + image_size=args.image_size, + extract_ema=args.extract_ema, + num_in_channels=args.num_in_channels, + upcast_attention=args.upcast_attention, + from_safetensors=args.from_safetensors, + device=args.device, + use_linear_projection=args.use_linear_projection, + cross_attention_dim=args.cross_attention_dim, + ) + + controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) \ No newline at end of file diff --git a/python_scripts/convert_original_stable_diffusion_to_diffusers.py b/python_scripts/convert_original_stable_diffusion_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..4130da2b28cc77aa4ebf22953c78e96ec4125a34 --- /dev/null +++ b/python_scripts/convert_original_stable_diffusion_to_diffusers.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the LDM checkpoints. """ + +import argparse +import importlib + +import torch + +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
+ ) + # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--config_files", + default=None, + type=str, + help="The YAML config file corresponding to the architecture.", + ) + parser.add_argument( + "--num_in_channels", + default=None, + type=int, + help="The number of input channels. If `None` number of input channels will be automatically inferred.", + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", + ) + parser.add_argument( + "--pipeline_type", + default=None, + type=str, + help=( + "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" + ". If `None` pipeline will be automatically inferred." + ), + ) + parser.add_argument( + "--image_size", + default=None, + type=int, + help=( + "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" + " Base. Use 768 for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--prediction_type", + default=None, + type=str, + help=( + "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" + " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." + ), + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument( + "--upcast_attention", + action="store_true", + help=( + "Whether the attention computation should always be upcasted. This is necessary when running stable" + " diffusion 2.1." + ), + ) + parser.add_argument( + "--from_safetensors", + action="store_true", + help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", + ) + parser.add_argument( + "--to_safetensors", + action="store_true", + help="Whether to store pipeline in safetensors format or not.", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + parser.add_argument( + "--stable_unclip", + type=str, + default=None, + required=False, + help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", + ) + parser.add_argument( + "--stable_unclip_prior", + type=str, + default=None, + required=False, + help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", + ) + parser.add_argument( + "--clip_stats_path", + type=str, + help="Path to the clip stats file. 
Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", + required=False, + ) + parser.add_argument( + "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." + ) + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--vae_path", + type=str, + default=None, + required=False, + help="Set to a path, hub id to an already converted vae to not convert it again.", + ) + parser.add_argument( + "--pipeline_class_name", + type=str, + default=None, + required=False, + help="Specify the pipeline class name", + ) + + args = parser.parse_args() + + if args.pipeline_class_name is not None: + library = importlib.import_module("diffusers") + class_obj = getattr(library, args.pipeline_class_name) + pipeline_class = class_obj + else: + pipeline_class = None + + pipe = download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict=args.checkpoint_path, + original_config_file=args.original_config_file, + config_files=args.config_files, + image_size=args.image_size, + prediction_type=args.prediction_type, + model_type=args.pipeline_type, + extract_ema=args.extract_ema, + scheduler_type=args.scheduler_type, + num_in_channels=args.num_in_channels, + upcast_attention=args.upcast_attention, + from_safetensors=args.from_safetensors, + device=args.device, + stable_unclip=args.stable_unclip, + stable_unclip_prior=args.stable_unclip_prior, + clip_stats_path=args.clip_stats_path, + controlnet=args.controlnet, + vae_path=args.vae_path, + pipeline_class=pipeline_class, + ) + + if args.half: + pipe.to(torch_dtype=torch.float16) + + if args.controlnet: + # only save the controlnet model + pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) + else: + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) \ No newline at end of file diff --git a/python_scripts/prepare_models.py b/python_scripts/prepare_models.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb7b4af7bd496df0fdda58197a90715e05bcdc7 --- /dev/null +++ b/python_scripts/prepare_models.py @@ -0,0 +1,54 @@ +import os +import os.path as osp +import requests +import shutil +from huggingface_hub import snapshot_download, HfApi +# from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt +from facexlib.utils import load_file_from_url +from facexlib.detection import init_detection_model + +def hf_download_dir(repo_id, dirname): + api = HfApi() + space_list = api.list_repo_files(repo_id=repo_id) + target_list = [target for target in space_list if target.startswith(dirname) ] + + print(target_list) + for filename in target_list: + print(f'downloading {filename}') + api.hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir='.', + local_dir_use_symlinks=True, + ) + + +MODEL_DIR='./MODELS' +os.makedirs(MODEL_DIR, exist_ok=True) + +def prepare_sadtalker_models(): + snapshot_download(repo_id='vinthony/SadTalker', local_dir=osp.join(MODEL_DIR, 'SadTalker'), local_dir_use_symlinks=True) + load_file_from_url( + url='https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth', + model_dir='facexlib/weights', + progress=True, file_name=None, save_dir=osp.join(MODEL_DIR, 'gfpgan/weights',)) + init_detection_model('retinaface_resnet50', half=False,device='cpu', model_rootpath=osp.join(MODEL_DIR, 'gfpgan/weights',)) + +def 
prepare_face_generator_models(): + # from all source repo + # snapshot_download(repo_id="georgefen/Face-Landmark-ControlNet", local_dir=osp.join(MODEL_DIR, "Face-Landmark-ControlNet"), allow_patterns=["models_for_diffusers/*"], local_dir_use_symlinks=True) + # snapshot_download(repo_id="runwayml/stable-diffusion-v1-5", local_dir=osp.join(MODEL_DIR, "stable-diffusion-v1-5"), allow_patterns=["*.bin", '*.json', '*.txt'], ignore_patterns=['safety_checker'],local_dir_use_symlinks=True) + # snapshot_download(repo_id="xiaolxl/GuoFeng3", local_dir=osp.join(MODEL_DIR, "GuoFeng3"), allow_patterns=["*.bin", '*.json', '*.txt'], ignore_patterns=['safety_checker*'],local_dir_use_symlinks=True) + # snapshot_download(repo_id="simhuangxi/MoXin", local_dir=osp.join(MODEL_DIR, "MoXin"),local_dir_use_symlinks=True) + # snapshot_download(repo_id="diffusers/controlnet-canny-sdxl-1.0", local_dir=osp.join(MODEL_DIR, "controlnet-canny-sdxl-1.0"), ignore_patterns=['*.bin'], local_dir_use_symlinks=True) + # snapshot_download(repo_id="stablediffusionapi/anything-v5", local_dir=osp.join(MODEL_DIR, "anything-v5"), ignore_patterns=['*.bin'], local_dir_use_symlinks=True) + # snapshot_download( + # repo_id="ermu2001/ChatAnything", + # local_dir='.', + # local_dir_use_symlinks=True, + # ) + hf_download_dir('ermu2001/ChatAnything', 'MODELS') + +if __name__ == "__main__": + prepare_sadtalker_models() + prepare_face_generator_models() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..72e14897139a817a1a5790450351411563058532 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,174 @@ +absl-py==2.0.0 +addict==2.4.0 +aiofiles==23.2.1 +aiohttp==3.8.6 +aiosignal==1.3.1 +altair==5.1.2 +annotated-types==0.6.0 +antlr4-python3-runtime==4.8 +anyio==3.7.1 +async-timeout==4.0.3 +attrs==23.1.0 +audioread==3.0.1 +basicsr==1.4.2 +bitarray==2.8.2 +cachetools==5.3.1 +certifi==2023.7.22 +cffi==1.16.0 +charset-normalizer==3.3.0 +click==8.1.7 +cmake==3.27.6 +colorama==0.4.6 +contourpy==1.1.1 +cycler==0.12.1 +Cython==3.0.3 +dataclasses-json==0.6.1 +decorator==4.4.2 +diffusers==0.21.4 +dlib==19.24.2 +edge-tts==6.1.8 +exceptiongroup==1.1.3 +face-alignment==1.3.5 +facexlib==0.3.0 +fairseq==0.12.2 +faiss-cpu==1.7.4 +fastapi==0.103.2 +ffmpy==0.3.1 +filelock==3.12.4 +filterpy==1.4.5 +fonttools==4.43.1 +frozenlist==1.4.0 +fsspec==2023.9.2 +future==0.18.3 +gfpgan==1.3.8 +google-auth==2.23.2 +google-auth-oauthlib==1.0.0 +gradio==3.47.1 +gradio_client==0.6.0 +greenlet==3.0.0 +grpcio==1.59.0 +h11==0.14.0 +httpcore==0.18.0 +httpx==0.25.0 +huggingface-hub==0.17.3 +hydra-core==1.0.7 +idna==3.4 +imageio==2.31.5 +imageio-ffmpeg==0.4.9 +importlib-metadata==6.8.0 +importlib-resources==6.1.0 +Jinja2==3.1.2 +joblib==1.3.2 +jsonpatch==1.33 +jsonpointer==2.4 +jsonschema==4.19.1 +jsonschema-specifications==2023.7.1 +kiwisolver==1.4.5 +kornia==0.7.0 +langchain==0.0.311 +langsmith==0.0.43 +lazy_loader==0.3 +librosa==0.10.1 +lit==17.0.2 +llvmlite==0.41.0 +lmdb==1.4.1 +lxml==4.9.3 +Markdown==3.5 +MarkupSafe==2.1.3 +marshmallow==3.20.1 +matplotlib==3.7.3 +more-itertools==10.1.0 +moviepy==1.0.3 +mpmath==1.3.0 +msgpack==1.0.7 +multidict==6.0.4 +mypy-extensions==1.0.0 +networkx==3.1 +numba==0.58.0 +numpy==1.24.4 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cufft-cu11==10.9.0.58 +nvidia-curand-cu11==10.2.10.91 +nvidia-cusolver-cu11==11.4.0.1 
+nvidia-cusparse-cu11==11.7.4.91 +nvidia-nccl-cu11==2.14.3 +nvidia-nvtx-cu11==11.7.91 +oauthlib==3.2.2 +omegaconf==2.0.6 +openai==0.28.1 +openai-whisper==20230918 +opencv-python==4.8.1.78 +orjson==3.9.7 +packaging==23.2 +pandas==2.0.3 +Pillow==10.0.1 +pkgutil_resolve_name==1.3.10 +platformdirs==3.11.0 +pooch==1.7.0 +portalocker==2.8.2 +praat-parselmouth==0.4.3 +proglog==0.1.10 +protobuf==4.24.4 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pycparser==2.21 +pydantic==2.4.2 +pydantic_core==2.10.1 +pydub==0.25.1 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-multipart==0.0.6 +pytz==2023.3.post1 +PyWavelets==1.4.1 +pyworld==0.3.4 +PyYAML==6.0.1 +referencing==0.30.2 +regex==2023.10.3 +requests==2.31.0 +requests-oauthlib==1.3.1 +rpds-py==0.10.4 +rsa==4.9 +sacrebleu==2.3.1 +safetensors==0.4.0 +scikit-image==0.21.0 +scikit-learn==1.3.1 +scipy==1.10.1 +semantic-version==2.10.0 +six==1.16.0 +sniffio==1.3.0 +soundfile==0.12.1 +soxr==0.3.7 +SQLAlchemy==2.0.21 +starlette==0.27.0 +sympy==1.12 +tabulate==0.9.0 +tb-nightly==2.14.0a20230808 +tenacity==8.2.3 +tensorboard-data-server==0.7.1 +threadpoolctl==3.2.0 +tifffile==2023.7.10 +tiktoken==0.3.3 +tokenizers==0.14.1 +tomli==2.0.1 +toolz==0.12.0 +torch==2.0.1 +torchaudio==2.0.2 +torchvision==0.15.2 +tqdm==4.66.1 +transformers==4.34.0 +triton==2.0.0 +typing-inspect==0.9.0 +typing_extensions==4.8.0 +tzdata==2023.3 +urllib3==2.0.6 +uvicorn==0.23.2 +websockets==11.0.3 +Werkzeug==3.0.0 +yacs==0.1.8 +yapf==0.40.2 +yarl==1.9.2 +zipp==3.17.0 diff --git a/resources/images/annie.jpg b/resources/images/annie.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d61c68f6495a75a22bc60661190b5e9f5dd671d7 Binary files /dev/null and b/resources/images/annie.jpg differ diff --git a/resources/images/faces/0.jpg b/resources/images/faces/0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..efbb5c0b0fe97e3f5edef242ff8ae67a2a341479 Binary files /dev/null and b/resources/images/faces/0.jpg differ diff --git a/resources/images/faces/1.jpg b/resources/images/faces/1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..69de396fed4c744daf5d596849bab3a55ca1571f Binary files /dev/null and b/resources/images/faces/1.jpg differ diff --git a/resources/images/faces/2.jpg b/resources/images/faces/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..39b10475b155b95b8a2b49f65b53dee7d01201f0 Binary files /dev/null and b/resources/images/faces/2.jpg differ diff --git a/resources/images/faces/3.jpg b/resources/images/faces/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fa6597e58bf9278dfec78e9a262aaf2f5a3e6123 Binary files /dev/null and b/resources/images/faces/3.jpg differ diff --git a/resources/images/lenna.jpg b/resources/images/lenna.jpg new file mode 100644 index 0000000000000000000000000000000000000000..86ac9c5cac3d0bf7e602cef6533dc2f06a41b688 Binary files /dev/null and b/resources/images/lenna.jpg differ diff --git a/resources/images/watermelon.jpg b/resources/images/watermelon.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a62b424ed5ce949e08e2733ee1612d644e5e2b97 Binary files /dev/null and b/resources/images/watermelon.jpg differ diff --git a/resources/models_personality.yaml b/resources/models_personality.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e561c1ee707a7a067d5b49834792a45424c6987f --- /dev/null +++ b/resources/models_personality.yaml @@ -0,0 +1,64 @@ +prompt: >- + Select one of the following models for image generation 
for the given concept.
+  The user wants to chat with such an object.
+  The concept is presented by the personality of an intended object for the user to chat with.
+  Each model specializes in generating one specific style of images.
+  The goal is to interpret the user's purpose and choose, based on the style that best suits the given concept, the model that helps generate the most appropriate image according to the user's input.
+  Select one of the options based on the relative association between the description of each model and the given concept to satisfy the user.
+
+
+  Personality:
+
+
+models:
+  # sd1.5:
+  #   desc: >-
+  #     A realistic photo style image generator experts at realistic photo generation. The model sucks.
+  #     Only choose when the user wants realistic image outputs.
+  #   lora_path:
+  #   model_dir: MODELS/stable-diffusion-v1-5
+  #   prompt_template: A portrait of {}, fine face, nice looking
+  #   negative_prompt: ""
+
+  GameIconInstitute_mode:
+    desc: >-
+      An artistic cartooning style image generator that excels at generating bizarre concepts. The generated
+      images usually contain exaggerated concepts. The artistic style uses vibrant colors and the images
+      are full of energy. Artistic cartoonists often use visual gags and puns to impress people. Storytelling
+      and narrative are key properties of this style.
+    lora_path:
+    model_dir: MODELS/GameIconInstitute_mode
+    prompt_template: A portrait of a {}, fine face, nice looking
+    negative_prompt: easynegative,Low resolution,Low quality, Opened Mouth
+
+  anything-v5:
+    desc: >-
+      A Japanese anime style image generator. Choose when you imagine the user
+      expects a cute character, especially when the user wants cute anime girls.
+    lora_path:
+    model_dir: MODELS/anything-v5
+    prompt_template: actual 8K portrait photo of {} girl, portrait, happy colors, symmetrical, detailed face, stanley artgerm lau, wlop, rossdraws, concept art, digital painting, looking into camera
+    negative_prompt: painting, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, skinny, glitchy, double torso, extra arms, extra hands, mangled fingers, missing lips, ugly face, distorted face, extra legs, anime
+
+  dream_shaper:
+    desc: >-
+      An artistic realistic beauty image generator that excels at artistic portrait generation.
+      This model emphasizes aesthetics and the portrayal of beauty. Achieving artistic
+      realistic beauty typically requires a high level of technical skill. The model often
+      leaves room for creative interpretation of real objects. Choose when you imagine the user
+      wants a highly aesthetic image of the intended object.
+    lora_path:
+    model_dir: MODELS/DreamShaper
+    prompt_template: >
+      fashion photography portrait of {}, 3d render, cgi, symetrical, octane render, 35mm, bokeh, 9:16, (intricate details:1.12), hdr, (intricate details, hyperdetailed:1.15), (natural skin texture, hyperrealism, soft light, sharp:1.2), detailed
+    negative_prompt: "BadDream, UnrealisticDream"
+
+  3D_Animation_Diffusion:
+    desc: >-
+      A 3D image generator. 3D style images usually look as if they were rendered by a game engine.
+      The style is not limited to specific concepts. Choose when the concept is enthusiastic and energetic.
+ + lora_path: + model_dir: MODELS/3D_Animation_Diffusion + prompt_template: A portrait of {}, fine face, nice looking + negative_prompt: "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, Ugly, deformed, ugly face, low face quality," \ No newline at end of file diff --git a/resources/prompts/animal.txt b/resources/prompts/animal.txt new file mode 100644 index 0000000000000000000000000000000000000000..c5deba518ecc78e1480621f0e961db9481033451 --- /dev/null +++ b/resources/prompts/animal.txt @@ -0,0 +1,50 @@ +Wise owl professor +Majestic lion king +Urban fox detective +Scholarly tortoise librarian +Playful dolphin entertainer +Regal deer monarch +Cunning snake charmer +Philosophical panda thinker +Adventurous kangaroo explorer +Grumpy bear chef +Mysterious cat sorcerer +Energetic squirrel journalist +Brave wolf warrior +Jazz singer bird of paradise +Gentle giraffe healer +Steampunk otter inventor +Artistic zebra painter +Intrepid eagle pilot +Humorous hyena comedian +Traditional elephant elder +Noble stallion knight +Curious monkey scientist +Charming penguin diplomat +Wise buffalo sage +Cheeky raccoon thief +Ancient dragon seer +Exuberant tiger dancer +Stoic rhino guard +Dreamy koala poet +Alert meerkat watchman +Fiery jaguar athlete +Graceful swan ballerina +Passionate leopard rocker +Loyal dog sheriff +Dignified moose chieftain +Humble donkey farmer +Vibrant toucan broadcaster +Grizzled shark sailor +Gentlemanly frog prince +Cheery hummingbird musician +Nomadic camel trader +Ambitious ant queen +Cunning bat detective +Kindhearted rabbit doctor +Noble hawk judge +Relaxed sloth yogi +Joyful seal comedian +Resourceful beaver engineer +Mysterious crow mystic +Sleek panther spy \ No newline at end of file diff --git a/resources/prompts/bags.txt b/resources/prompts/bags.txt new file mode 100644 index 0000000000000000000000000000000000000000..51ca45ef0047d7f265c929e72762689cae898bc0 --- /dev/null +++ b/resources/prompts/bags.txt @@ -0,0 +1,50 @@ +Grandiose leather briefcase executive +Rugged backpack explorer +Dainty clutch ballerina +Mysterious messenger bag courier +Energetic gym bag athlete +Vintage suitcase historian +Suave duffel jetsetter +Loyal tote librarian +Elegant satchel aristocrat +Adventurous fanny pack tourist +Beachy straw tote sunseeker +Luxurious velvet purse diva +Rugged canvas sack miner +Whimsical drawstring pouch magician +Regal beaded clutch empress +Protective armored backpack knight +Casual sling bag student +Flamboyant sequined bag performer +Minimalist laptop bag tech guru +Rustic hobo bag musician +Sporty hydration pack runner +Compact makeup bag beautician +Glossy patent leather bag model +Quirky patchwork bag artist +Majestic chest treasure keeper +Trendy crossbody bag fashionista +Classic doctor's bag healer +Multi-pocketed fishing bag angler +Flamboyant fringe bag dancer +Durable tool bag mechanic +Posh evening bag socialite +Stealthy camouflage bag ranger +Whimsical coin purse collector +Heavy-duty cargo bag shipper +Playful diaper bag nanny +Lightweight parachute bag skydiver +Vibrant printed pouch traveler +Timeless brown bag philosopher +Elegantly embroidered bag artisan +Sturdy belt bag warrior +Eco-friendly reusable shopper environmentalist +Quaint knapsack writer +Lustrous silk bag diviner +Casual cooler bag beachcomber +Refined attache case diplomat +Delicate reticule historian +Geometric patterned 
bag mathematician +Nostalgic school bag scholar +Intricately woven bag weaver +Transparent clear bag visionary \ No newline at end of file diff --git a/resources/prompts/cartoon.txt b/resources/prompts/cartoon.txt new file mode 100644 index 0000000000000000000000000000000000000000..abc2d49d315631cded0c631ba0a30c0ae4587daa --- /dev/null +++ b/resources/prompts/cartoon.txt @@ -0,0 +1,50 @@ +Bumbling professor owl +Intrepid explorer mouse +Brooding superhero bat +Whimsical fairy squirrel +Mischievous trickster fox +Confident leader lion +Laid-back surfer turtle +Clumsy robotic dog +Mysterious sorcerer cat +Space-traveling astronaut bunny +Jovial baker bear +Spunky skateboarder duck +Gritty detective raccoon +Shy library bookworm +Boisterous pirate parrot +Athletic soccer-playing kangaroo +Dashing adventurer wolf +Sci-fi time traveler otter +Serene monk panda +Steampunk engineer giraffe +Fearless warrior princess tiger +Genius scientist monkey +Dreamy star-gazing elephant +Sassy jazz singer peacock +Outlandish alien deer +Old west sheriff armadillo +Regal king cobra +Booming opera singer walrus +Eco-warrior recycle frog +Glamorous movie star flamingo +Urbane city dweller zebra +Spicy salsa dancer iguana +Mythical griffin librarian +Enigmatic spy chameleon +Breezy beachcomber seagull +Brawny strongman rhino +Medieval jester jackalope +Upbeat DJ hedgehog +Elusive ninja octopus +Futuristic racer penguin +Gentle healer dove +Kooky inventor platypus +Impulsive daredevil squirrel +Mystic guru manatee +Small-town sheriff snail +Feisty rock star rooster +Regal emperor pufferfish +Quirky chef mole +Galactic guardian orca +Chilled-out yogi sloth diff --git a/resources/prompts/clothes.txt b/resources/prompts/clothes.txt new file mode 100644 index 0000000000000000000000000000000000000000..51ca45ef0047d7f265c929e72762689cae898bc0 --- /dev/null +++ b/resources/prompts/clothes.txt @@ -0,0 +1,50 @@ +Grandiose leather briefcase executive +Rugged backpack explorer +Dainty clutch ballerina +Mysterious messenger bag courier +Energetic gym bag athlete +Vintage suitcase historian +Suave duffel jetsetter +Loyal tote librarian +Elegant satchel aristocrat +Adventurous fanny pack tourist +Beachy straw tote sunseeker +Luxurious velvet purse diva +Rugged canvas sack miner +Whimsical drawstring pouch magician +Regal beaded clutch empress +Protective armored backpack knight +Casual sling bag student +Flamboyant sequined bag performer +Minimalist laptop bag tech guru +Rustic hobo bag musician +Sporty hydration pack runner +Compact makeup bag beautician +Glossy patent leather bag model +Quirky patchwork bag artist +Majestic chest treasure keeper +Trendy crossbody bag fashionista +Classic doctor's bag healer +Multi-pocketed fishing bag angler +Flamboyant fringe bag dancer +Durable tool bag mechanic +Posh evening bag socialite +Stealthy camouflage bag ranger +Whimsical coin purse collector +Heavy-duty cargo bag shipper +Playful diaper bag nanny +Lightweight parachute bag skydiver +Vibrant printed pouch traveler +Timeless brown bag philosopher +Elegantly embroidered bag artisan +Sturdy belt bag warrior +Eco-friendly reusable shopper environmentalist +Quaint knapsack writer +Lustrous silk bag diviner +Casual cooler bag beachcomber +Refined attache case diplomat +Delicate reticule historian +Geometric patterned bag mathematician +Nostalgic school bag scholar +Intricately woven bag weaver +Transparent clear bag visionary \ No newline at end of file diff --git a/resources/prompts/fruits.txt b/resources/prompts/fruits.txt new file 
mode 100644 index 0000000000000000000000000000000000000000..86e19a3308d7fd549d64f56a7bc42d3d92305f98 --- /dev/null +++ b/resources/prompts/fruits.txt @@ -0,0 +1,50 @@ +Wise apple elder +Confident banana leader +Mysterious blueberry mystic +Gentle peach poet +Charismatic grape entertainer +Regal pineapple monarch +Scholarly lemon professor +Energetic strawberry athlete +Philosopher kiwi thinker +Urban orange detective +Charming cherry diplomat +Ambitious passionfruit entrepreneur +Elegant pear dancer +Brave watermelon warrior +Dreamy lychee artist +Jubilant raspberry musician +Steampunk blackberry inventor +Cheerful melon comedian +Serene coconut guru +Vibrant mango broadcaster +Stoic fig guard +Noble pomegranate knight +Exuberant plum performer +Graceful apricot ballerina +Humorous papaya clown +Mysterious currant sorcerer +Loyal blueberry counselor +Traditional elderberry historian +Curious gooseberry scientist +Dapper cranberry gentleman +Joyful tangerine singer +Gentle guava healer +Wise date sage +Sleek avocado explorer +Artistic dragon fruit painter +Dignified olive statesman +Graceful persimmon figure skater +Fiery chili pepper activist +Passionate blackcurrant rockstar +Resourceful lime engineer +Mystical starfruit astronomer +Bold grapefruit pilot +Serendipitous passionfruit dreamer +Scholarly pomegranate librarian +Confident kiwi coach +Humble cherry tomato gardener +Regal durian empress +Vibrant boysenberry radio host +Gentle fig monk +Jubilant clementine ringmaster \ No newline at end of file diff --git a/resources/prompts/office_accessories.txt b/resources/prompts/office_accessories.txt new file mode 100644 index 0000000000000000000000000000000000000000..38358e6167f4a67ef1759077d71e3c63cbc56f98 --- /dev/null +++ b/resources/prompts/office_accessories.txt @@ -0,0 +1,50 @@ +Scholarly bookshelf organizer +Charismatic pen presenter +Wise calendar elder +Energetic highlighter broadcaster +Majestic desk chair monarch +Intrepid paperclip navigator +Artistic paintbrush designer +Resilient stapler guardian +Vibrant post-it note entertainer +Mysterious inkwell sage +Noble binder overseer +Loyal clipboard manager +Curious magnifying glass detective +Charming rubber band diplomat +Graceful quill dancer +Harmonious musical ruler +Jubilant tape dispenser comedian +Mellow coffee mug thinker +Bold scissors strategist +Steampunk typewriter storyteller +Radiant lamp illuminator +Wise old leather journal historian +Trustworthy lockbox treasurer +Playful eraser trickster +Gentle tissue box comforter +Enigmatic calculator oracle +Resolute file cabinet archivist +Jubilant stamp perforator +Sleek mouse navigator +Adventurous envelope traveler +Regal trophy leader +Whimsical floating pencil magician +Elegant letter opener fencer +Enchanting snow globe dreamer +Ambitious printer publisher +Refined ink blotter gentleman +Confident pushpin point person +Passionate fiery candle illuminator +Loyal telephone communicator +Cheerful thumbtack decorator +Patient hourglass timekeeper +Resilient hole punch warrior +Serene plant pot nurturer +Humorous spinning top jester +Regenerative whiteboard brainstormer +Versatile multi-tool problem solver +Mysterious hidden compartment keeper +Sturdy shelf supporter +Dynamic rolling chair racer +Swift express courier envelope \ No newline at end of file diff --git a/resources/prompts/plants.txt b/resources/prompts/plants.txt new file mode 100644 index 0000000000000000000000000000000000000000..b49950ac03ad3081374bab43213b3bc6122454e5 --- /dev/null +++ 
b/resources/prompts/plants.txt @@ -0,0 +1,50 @@ +Majestic oak tree elder +Enigmatic fern oracle +Charismatic sunflower broadcaster +Artistic rose poet +Eloquent orchid diplomat +Adventurous ivy explorer +Gracious lily dancer +Scholarly bamboo philosopher +Vibrant tulip entertainer +Resilient cactus cowboy +Noble cedar monarch +Mysterious moss shaman +Intrepid daisy pilot +Gentle lavender healer +Cheerful dandelion comedian +Stoic pine sentinel +Radiant marigold performer +Urban jungle ficus therapist +Curious fern botanist +Mellow chamomile yogi +Brave thistle warrior +Graceful chrysanthemum ballerina +Harmonious bonsai conductor +Wise olive tree sage +Passionate poppy singer +Scholarly maple librarian +Enchanting peony enchantress +Regal magnolia duchess +Loyal willow guardian +Joyful hydrangea artist +Mystical lotus meditator +Dreamy wisteria novelist +Ambitious redwood strategist +Urban sprout entrepreneur +Regenerative aloe doctor +Steampunk venus flytrap inventor +Vibrant hibiscus broadcaster +Serene fern forest monk +Playful clover jester +Gentle cornflower therapist +Insightful birch counselor +Resolute horsetail architect +Fiery chili plant activist +Refined tea plant connoisseur +Ethereal ghost plant mystic +Dapper cattail gentleman +Whimsical seaweed mermaid +Grizzled bristlecone elder +Charismatic coffee plant barista +Harmonious grasshopper grass musician \ No newline at end of file diff --git a/resources/prompts/realistic.txt b/resources/prompts/realistic.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2f8a86ac4bb91254f5805af34a5f6d90f0dc929 --- /dev/null +++ b/resources/prompts/realistic.txt @@ -0,0 +1,50 @@ +Renaissance artist +Cyberpunk hacker +Medieval knight +1920s jazz singer +Astronaut explorer +Tribal elder +Victorian detective +1980s pop star +Desert nomad +Pirate captain +Samurai philosopher +Regency-era novelist +Steampunk inventor +WWII pilot +Film noir actress +Olympic champion +1970s activist +Ancient Egyptian priest +Roaring '20s flapper +Hollywood golden age director +Native American chief +Gothic novelist +Cold War spy +Elizabethan playwright +1950s rockabilly +Silk Road trader +Roman gladiator +Harlem Renaissance poet +Prohibition gangster +Revolutionary war general +Belle Époque ballerina +1990s tech entrepreneur +Viking chieftain +Mongolian horse archer +Edwardian suffragette +Jazz Age novelist +Disco-era DJ +Ancient Greek philosopher +Civil Rights leader +Wild West sheriff +Space colony pioneer +Colonial tea merchant +1960s counterculture icon +Medieval alchemist +Golden age of piracy navigator +1930s Hollywood starlet +Renaissance sculptor +Post-apocalyptic survivor +WWII resistance fighter +Industrial revolution entrepreneur diff --git a/resources/readme/show.png b/resources/readme/show.png new file mode 100644 index 0000000000000000000000000000000000000000..f6e3813a1772f29d2ad7cfad96b47ea6555b5f88 --- /dev/null +++ b/resources/readme/show.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5866539940cedcb6305ddd4af372012e5cdaa3e3e4f3a60e53d65e7aed7554da +size 6473175 diff --git a/resources/voices_edge.yaml b/resources/voices_edge.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b908d11144bf16b65fee2e47e4eecead0d2617ad --- /dev/null +++ b/resources/voices_edge.yaml @@ -0,0 +1,33 @@ +prompt: >- + Select one of the following voice based on the given concept's personality. + The user wants to chat with such concept. 
+    Notice the language of the user's input and consider mainly the language and gender of the voice.
+    You must choose one voice name based on the description of each model and the concept.
+
+    Personality:
+
+
+models:
+  chinese_male:
+    desc: >
+      This is a male voice speaking Chinese.
+    gender: Male
+    language: zh
+
+  english_male:
+    desc: >
+      This is a male voice speaking English.
+    gender: Male
+    language: en
+
+  chinese_female:
+    desc: >
+      This is a female voice speaking Chinese.
+    gender: Female
+    language: zh
+
+  english_female:
+    desc: >
+      This is a female voice speaking English.
+    gender: Female
+    language: en
\ No newline at end of file
diff --git a/resources/voices_voicechanger.yaml b/resources/voices_voicechanger.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9f86739c4545643678ac0b9fc016c6c8fafa22d5
--- /dev/null
+++ b/resources/voices_voicechanger.yaml
@@ -0,0 +1,13 @@
+models:
+  ALL - Stefanie Sun:
+  # model_A:
+    desc: >
+      This is Stefanie Sun's voice, a woman's voice, a captivating singer's voice that enchants listeners. Stefanie Sun's voice can be described as a melodic and emotive masterpiece. It possesses a velvety smoothness with a hint of warmth, allowing her vocals to effortlessly glide through the air. Her voice carries a unique timbre that is instantly recognizable, captivating listeners with its rich and enchanting quality.
+    model_name: ALL - Stefanie Sun
+    sex: female
+
+  ALL - Bob Sponge:
+  # model_A:
+    desc: >
+      This is SpongeBob's voice, a man's voice, a delightful and distinctive representation of a cute cartoon character. It is characterized by a high-pitched and cheery tone that instantly grabs attention. The voice is incredibly animated and expressive, capturing SpongeBob's bubbly and upbeat personality flawlessly. With each word, it exudes a sense of infectious joy and childlike wonder, creating an immediate connection with the audience.
+ model_name: ALL - Bob Sponge diff --git a/scripts/convert_from_controlnet.sh b/scripts/convert_from_controlnet.sh new file mode 100644 index 0000000000000000000000000000000000000000..bfeb1bcf1013c5c6802f88e24ae5236ce69915fc --- /dev/null +++ b/scripts/convert_from_controlnet.sh @@ -0,0 +1,8 @@ +ckpt_path=$1 +dump_path=$2 +original_config_path=" " +python python_scripts/convert_original_controlnet_to_diffusers.py \ + --checkpoint_path $ckpt_path \ + --original_config_file "None" \ + --dump_path $dump_path \ + --device cpu \ No newline at end of file diff --git a/scripts/convert_from_safetensors.sh b/scripts/convert_from_safetensors.sh new file mode 100644 index 0000000000000000000000000000000000000000..16828ca3b97fdfb5aa58f36e1eae430ecc9acd73 --- /dev/null +++ b/scripts/convert_from_safetensors.sh @@ -0,0 +1,6 @@ +safetensors_path=$1 +dump_path=$2 +python3 python_scripts/convert_original_stable_diffusion_to_diffusers.py \ + --checkpoint_path $safetensors_path \ + --dump_path $dump_path \ + --from_safetensors diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb226b4ff5f235f9a310ca33cfff996566d9048 --- /dev/null +++ b/utils.py @@ -0,0 +1,300 @@ + +import os +from PIL import Image +import random +import shutil +import datetime +import torchvision.transforms.functional as f +import torch + +from typing import Optional, Tuple +from threading import Lock +from langchain import ConversationChain + +from chat_anything.tts_talker.tts_edge import TTSTalker +from chat_anything.sad_talker.sad_talker import SadTalker +from chat_anything.chatbot.chat import load_chain +from chat_anything.chatbot.select import model_selection_chain +from chat_anything.chatbot.voice_select import voice_selection_chain +import gradio as gr + + +TALKING_HEAD_WIDTH = "350" +sadtalker_checkpoint_path = "MODELS/SadTalker" +config_path = "chat_anything/sad_talker/config" + +class ChatWrapper: + def __init__(self): + self.lock = Lock() + self.sad_talker = SadTalker( + sadtalker_checkpoint_path, config_path, lazy_load=True) + + def __call__( + self, + api_key: str, + inp: str, + history: Optional[Tuple[str, str]], + chain: Optional[ConversationChain], + speak_text: bool, talking_head: bool, + uid: str, + talker : None, + fullbody : str, + ): + """Execute the chat functionality.""" + self.lock.acquire() + if chain is None: + history.append((inp, "Please register with your API key first!")) + else: + try: + print("\n==== date/time: " + str(datetime.datetime.now()) + " ====") + print("inp: " + inp) + print("speak_text: ", speak_text) + print("talking_head: ", talking_head) + history = history or [] + # If chain is None, that is because no API key was provided. + output = "Please paste your OpenAI key from openai.com to use this app. 
" + \ + str(datetime.datetime.now()) + + output = chain.predict(input=inp).strip() + output = output.replace("\n", "\n\n") + + text_to_display = output + + # #预定义一个talker + # talker = MaleEn() + history.append((inp, text_to_display)) + + html_video, temp_file, html_audio, temp_aud_file = None, None, None, None + if speak_text: + if talking_head: + html_video, temp_file = self.do_html_video_speak( + talker, output, fullbody, uid) + else: + html_audio, temp_aud_file = self.do_html_audio_speak( + talker, output,uid) + else: + if talking_head: + temp_file = os.path.join('tmp', uid, 'videos') + html_video = create_html_video( + temp_file, TALKING_HEAD_WIDTH) + else: + pass + + except Exception as e: + raise e + finally: + self.lock.release() + return history, history, html_video, temp_file, html_audio, temp_aud_file, "" + + + def do_html_audio_speak(self,talker, words_to_speak, uid): + audio_path = os.path.join('tmp', uid, 'audios') + print('uid:', uid, ":", words_to_speak) + audo_file_path = talker.test(text=words_to_speak, audio_path=audio_path) + html_audio = '
<p>no audio</p>
' + try: + temp_aud_file = gr.File(audo_file_path) + print("audio-----------------------------------------------------success") + temp_aud_file_url = "/file=" + temp_aud_file.value['name'] + html_audio = f'' + except IOError as error: + # Could not write to file, exit gracefully + print(error) + return None, None + + return html_audio, audo_file_path + + def do_html_video_speak(self,talker,words_to_speak,fullbody, uid): + if fullbody: + # preprocess='somthing' + preprocess='full' + else: + preprocess='crop' + print("success") + video_path = os.path.join('tmp', uid, 'videos') + if not os.path.exists(video_path): + os.makedirs(video_path) + video_file_path = os.path.join(video_path, 'tempfile.mp4') + _, audio_path = self.do_html_audio_speak( + talker,words_to_speak,uid) + face_file_path = os.path.join('tmp', uid, 'images', 'test.jpg') + + video = self.sad_talker.test(face_file_path, audio_path,preprocess, uid=uid) #video_file_path + print("---------------------------------------------------------success") + print(f"moving {video} -> {video_file_path}") + shutil.move(video, video_file_path) + + return video_file_path, video_file_path + + + def generate_init_face_video(self,class_concept="clock", llm=None,uid=None,fullbody=None, ref_image=None, seed=None): + """ + """ + print('generate concept of', class_concept) + print("=================================================") + print('fullbody:', fullbody) + print('uid:', uid) + print("==================================================") + chain, memory, personality_text = load_chain(llm, class_concept) + model_conf, selected_model = model_selection_chain(llm, class_concept, conf_file='resources/models.yaml') # use class concept to choose a generating model, otherwise crack down + # model_conf, selected_model = model_selection_chain(llm, personality_text, conf_file='resources/models_personality.yaml') # use class concept to choose a generating model, otherwise crack down + voice_conf, selected_voice = model_selection_chain(llm, personality_text, conf_file='resources/voices_edge.yaml') + + # added for safe face generation + print('generate concept of', class_concept) + augment_word_list = ["Female ", "female ", "beautiful ", "small ", "cute "] + first_sentence = "Hello, how are you doing today?" 
+ voice_conf, selected_voice = model_selection_chain(llm, personality_text, conf_file='resources/voices_edge.yaml') + talker = TTSTalker(selected_voice=selected_voice, gender=voice_conf['gender'], language=voice_conf['language']) + model_conf, selected_model = model_selection_chain(llm, class_concept, conf_file='resources/models.yaml') # use class concept to choose a generating model, otherwise crack down + retry_cnt = 4 + if ref_image is None: + face_files = os.listdir(FACE_DIR) + face_img_path = os.path.join(FACE_DIR, random.choice(face_files)) + ref_image = Image.open(face_img_path) + + print('loading face generating model') + anything_facemaker = load_face_generator( + model_dir=model_conf['model_dir'], + lora_path=model_conf['lora_path'], + prompt_template=model_conf['prompt_template'], + negative_prompt=model_conf['negative_prompt'], + ) + retry_cnt = 0 + has_face = anything_facemaker.has_face(ref_image) + init_strength = 1.0 if has_face else 0.85 + strength_retry_step = -0.04 if has_face else 0.04 + while retry_cnt < 8: + try: + generate_face_image( + anything_facemaker, + class_concept, + ref_image, + uid=uid, + strength=init_strength if (retry_cnt==0 and has_face) else init_strength + retry_cnt * strength_retry_step, + controlnet_conditioning_scale=0.5 if retry_cnt == 8 else 0.3, + seed=seed, + ) + self.do_html_video_speak(talker, first_sentence, fullbody, uid=uid) + video_file_path = os.path.join('tmp', uid, 'videos/tempfile.mp4') + htm_video = create_html_video( + video_file_path, TALKING_HEAD_WIDTH) + break + except Exception as e: + retry_cnt += 1 + class_concept = random.choice(augment_word_list) + class_concept + print(e) + # end of repeat block + + return chain, memory, htm_video, talker + + + def update_talking_head(self, widget, uid, state): + print("success----------------") + if widget: + state = widget + temp_file = os.path.join('tmp', uid, 'videos') + video_html_talking_head = create_html_video( + temp_file, TALKING_HEAD_WIDTH) + return state, video_html_talking_head + else: + return None, "
"
+
+
+def reset_memory(history, memory):
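+    # Clear the LangChain conversation memory and the chat history for a fresh session.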
+    memory.clear()
+    history = []
+    return history, history, memory
+
+
+def create_html_video(file_name, width):
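+    # NOTE: currently just returns the raw video file path; the width argument is unused.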
+    return file_name
+
+
+def create_html_audio(file_name):
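+    # Build the HTML snippet used to embed the generated audio file in the gradio chat UI.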
+    if os.path.exists(file_name):
+        tmp_audio_file = gr.File(file_name, visible=False)
+        tmp_aud_file_url = "/file=" + tmp_audio_file.value['name']
+        # embed the gradio-served audio file as an inline, autoplaying player
+        html_audio = f'<audio src="{tmp_aud_file_url}" controls autoplay></audio>'
+        del tmp_aud_file_url
+    else:
+        html_audio = f''
+
+    return html_audio
+
+
+def update_foo(widget, state):
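+    # Generic gradio callback: when the widget has a value, copy it into the session state.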
+    if widget:
+        state = widget
+        return state
+
+
+# Pertains to question answering functionality
+def update_use_embeddings(widget, state):
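+    # Same pattern as update_foo: store the widget value in the session state when it is set.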
+    if widget:
+        state = widget
+        return state
+
+# The code below handles the initial face image generation.
+
+
+def load_face_generator(model_dir, lora_path, prompt_template, negative_prompt):
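+    # Assemble the ControlNet-guided face generator: the chosen base diffusion model (plus an
+    # optional LoRA) together with the local face-landmark ControlNet and landmark detector weights.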
+    from chat_anything.face_generator.long_prompt_control_generator import LongPromptControlGenerator
+    # paths of the locally downloaded ControlNet and face-landmark models
+    model_zoo = "MODELS"
+    face_control_dir = os.path.join(
+        model_zoo, "Face-Landmark-ControlNet", "models_for_diffusers")
+    face_detect_path = os.path.join(
+        model_zoo, "SadTalker", "shape_predictor_68_face_landmarks.dat")
+    # Alternatively, pass a remote repo id and let huggingface auto-download it.
+    # model_dir has to point to a model derived from stable-diffusion-v1-5.
+    anything_facemaker = LongPromptControlGenerator(
+        model_dir=model_dir,
+        lora_path=lora_path,
+        prompt_template=prompt_template,
+        negative_prompt=negative_prompt,
+        face_control_dir=face_control_dir,
+        face_detect_path=face_detect_path,
+    )
+    anything_facemaker.load_model(safety_checker=None)
+    return anything_facemaker
+
+
+
+# Reference face images; one is picked at random when the user does not provide one.
+FACE_DIR = "resources/images/faces"
+def generate_face_image(
+        anything_facemaker,
+        class_concept, 
+        face_img_pil,
+        uid=None,
+        controlnet_conditioning_scale=1.0,
+        strength=0.95,
+        seed=42,
+    ):
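+    """Generate the initial portrait image for the given concept.
+
+    The reference face is center-cropped to 512x512, then run through the
+    face-landmark ControlNet pipeline (with inversion when a denoising
+    strength is given), and the result is saved to tmp/<uid>/images/test.jpg.
+    """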
+    face_img_pil = f.center_crop(
+        f.resize(face_img_pil, 512), 512).convert('RGB')
+    prompt = anything_facemaker.prompt_template.format(class_concept)
+    # There are four ways to generate an image at the moment; two are kept below for reference.
+    # pure_generate = anything_facemaker.generate(prompt=prompt, image=face_img_pil, do_inversion=False)
+    # inversion = anything_facemaker.generate(prompt=prompt, image=face_img_pil, strength=strength, do_inversion=True)
+
+    print('USING SEED:', seed)
+    generator = torch.Generator(device=anything_facemaker.face_control_pipe.device)
+    generator.manual_seed(seed)
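+    # strength is None  -> pure ControlNet generation conditioned on the face landmarks
+    # strength is given -> inversion-based generation with that denoising strength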
+    if strength is None:
+        pure_control = anything_facemaker.face_control_generate(prompt=prompt, face_img_pil=face_img_pil, do_inversion=False,
+                                                                 controlnet_conditioning_scale=controlnet_conditioning_scale, generator=generator)
+        init_face_pil = pure_control
+    else:
+        control_inversion = anything_facemaker.face_control_generate(prompt=prompt, face_img_pil=face_img_pil, do_inversion=True, 
+                                                                 strength=strength,
+                                                                 controlnet_conditioning_scale=controlnet_conditioning_scale, generator=generator)
+        init_face_pil = control_inversion
+    print('succeeded generating face image')
+    face_path = os.path.join('tmp', uid, 'images')
+    if not os.path.exists(face_path):
+        os.makedirs(face_path)
+    # TODO: return the generated image in memory instead of going through the filesystem
+    face_file_path = os.path.join(face_path, 'test.jpg')
+    init_face_pil.save(face_file_path)
+    return init_face_pil