diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4e0bbdfaa692aff943e039a75fa6203770ea5ae7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +venv/ +.idea/ diff --git a/README.md b/README.md index fb10e2534b090242b9f5d8fe0abdb092a7bd152a..e692edc79d0c9be045aab6fd3e8f9014d7454aec 100644 --- a/README.md +++ b/README.md @@ -10,3 +10,19 @@ pinned: false --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ajay-sainy/Wav2Lip-GFPGAN/blob/main/Wav2Lip-GFPGAN.ipynb) + +Combine Lip Sync AI and Face Restoration AI to get ultra high quality videos. + +[Demo Video](https://www.youtube.com/watch?v=jArkTgAMA4g) + +[![Demo Video](https://img.youtube.com/vi/jArkTgAMA4g/default.jpg)](https://youtu.be/jArkTgAMA4g) + +Projects referred: +1. https://github.com/Rudrabha/Wav2Lip +2. https://github.com/TencentARC/GFPGAN + +Video sources: +1. https://www.youtube.com/watch?v=39w_zYB7AVM&t=0s +2. https://www.youtube.com/watch?v=LQCQym6hVMo&t=0s diff --git a/Wav2Lip-GFPGAN.ipynb b/Wav2Lip-GFPGAN.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..97118cbce46c937e88e3a2dd57c91db73070c5d7 --- /dev/null +++ b/Wav2Lip-GFPGAN.ipynb @@ -0,0 +1,208 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Wav2Lip.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/ajay-sainy/Wav2Lip-GFPGAN.git\n", + "basePath = \"/content/Wav2Lip-GFPGAN\"\n", + "%cd {basePath}" + ], + "metadata": { + "id": "YhFe3CJGAIiV" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "wav2lipFolderName = 'Wav2Lip-master'\n", + "gfpganFolderName = 'GFPGAN-master'\n", + "wav2lipPath = basePath + '/' + wav2lipFolderName\n", + "gfpganPath = basePath + '/' + gfpganFolderName\n", + "\n", + "!wget 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth' -O {wav2lipPath}'/face_detection/detection/sfd/s3fd.pth'\n", + "!gdown https://drive.google.com/uc?id=1fQtBSYEyuai9MjBOF8j7zZ4oQ9W2N64q --output {wav2lipPath}'/checkpoints/'" + ], + "metadata": { + "id": "mH7A_OaFUs8U" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install -r requirements.txt" + ], + "metadata": { + "id": "CAJqWQS17Qk1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EqX_2YtkUjRI" + }, + "outputs": [], + "source": [ + "import os\n", + "outputPath = basePath+'/outputs'\n", + "inputAudioPath = basePath + '/inputs/kim_audio.mp3'\n", + "inputVideoPath = basePath + '/inputs/kimk_7s_raw.mp4'\n", + "lipSyncedOutputPath = basePath + '/outputs/result.mp4'\n", + "\n", + "if not os.path.exists(outputPath):\n", + " os.makedirs(outputPath)\n", + "\n", + "!cd $wav2lipFolderName && python inference.py \\\n", + "--checkpoint_path checkpoints/wav2lip.pth \\\n", + "--face {inputVideoPath} \\\n", + "--audio {inputAudioPath} \\\n", + "--outfile {lipSyncedOutputPath}" + ] + }, + { + "cell_type": "code", + "source": [ + "!cd $gfpganFolderName && python setup.py develop\n", + "!wget 
https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P {gfpganFolderName}'/experiments/pretrained_models'" + ], + "metadata": { + "id": "PPBew5FGGvP9" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import cv2\n", + "from tqdm import tqdm\n", + "from os import path\n", + "\n", + "import os\n", + "\n", + "inputVideoPath = outputPath+'/result.mp4'\n", + "unProcessedFramesFolderPath = outputPath+'/frames'\n", + "\n", + "if not os.path.exists(unProcessedFramesFolderPath):\n", + " os.makedirs(unProcessedFramesFolderPath)\n", + "\n", + "vidcap = cv2.VideoCapture(inputVideoPath)\n", + "numberOfFrames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))\n", + "fps = vidcap.get(cv2.CAP_PROP_FPS)\n", + "print(\"FPS: \", fps, \"Frames: \", numberOfFrames)\n", + "\n", + "for frameNumber in tqdm(range(numberOfFrames)):\n", + " _,image = vidcap.read()\n", + " cv2.imwrite(path.join(unProcessedFramesFolderPath, str(frameNumber).zfill(4)+'.jpg'), image)\n" + ], + "metadata": { + "id": "X_RNegAcISU2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!cd $gfpganFolderName && \\\n", + " python inference_gfpgan.py -i $unProcessedFramesFolderPath -o $outputPath -v 1.3 -s 2 --only_center_face --bg_upsampler None" + ], + "metadata": { + "id": "k6krjfxTJYlu" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "restoredFramesPath = outputPath + '/restored_imgs/'\n", + "processedVideoOutputPath = outputPath\n", + "\n", + "dir_list = os.listdir(restoredFramesPath)\n", + "dir_list.sort()\n", + "\n", + "import cv2\n", + "import numpy as np\n", + "\n", + "batch = 0\n", + "batchSize = 300\n", + "from tqdm import tqdm\n", + "for i in tqdm(range(0, len(dir_list), batchSize)):\n", + " img_array = []\n", + " start, end = i, i+batchSize\n", + " print(\"processing \", start, end)\n", + " for filename in tqdm(dir_list[start:end]):\n", + " filename = restoredFramesPath+filename;\n", + " img = cv2.imread(filename)\n", + " if img is None:\n", + " continue\n", + " height, width, layers = img.shape\n", + " size = (width,height)\n", + " img_array.append(img)\n", + "\n", + "\n", + " out = cv2.VideoWriter(processedVideoOutputPath+'/batch_'+str(batch).zfill(4)+'.avi',cv2.VideoWriter_fourcc(*'DIVX'), 30, size)\n", + " batch = batch + 1\n", + " \n", + " for i in range(len(img_array)):\n", + " out.write(img_array[i])\n", + " out.release()\n" + ], + "metadata": { + "id": "XibzGPIVJfvP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "concatTextFilePath = outputPath + \"/concat.txt\"\n", + "concatTextFile=open(concatTextFilePath,\"w\")\n", + "for ips in range(batch):\n", + " concatTextFile.write(\"file batch_\" + str(ips).zfill(4) + \".avi\\n\")\n", + "concatTextFile.close()\n", + "\n", + "concatedVideoOutputPath = outputPath + \"/concated_output.avi\"\n", + "!ffmpeg -y -f concat -i {concatTextFilePath} -c copy {concatedVideoOutputPath} \n", + "\n", + "finalProcessedOuputVideo = processedVideoOutputPath+'/final_with_audio.avi'\n", + "!ffmpeg -y -i {concatedVideoOutputPath} -i {inputAudioPath} -map 0 -map 1:a -c:v copy -shortest {finalProcessedOuputVideo}\n", + "\n", + "from google.colab import files\n", + "files.download(finalProcessedOuputVideo)" + ], + "metadata": { + "id": "jtde28qwpDd6" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 
index 0000000000000000000000000000000000000000..b6a4f308bbfd5078e7cee63131f1d3667ea5ae04 --- /dev/null +++ b/app.py @@ -0,0 +1,66 @@ +import os + +import gradio as gr +import subprocess +from subprocess import call + +basePath = os.path.dirname(os.path.realpath(__file__)) + +outputPath = os.path.join(basePath, 'outputs') +inputAudioPath = basePath + '/inputs/kim_audio.mp3' +inputVideoPath = basePath + '/inputs/kimk_7s_raw.mp4' +lipSyncedOutputPath = basePath + '/outputs/result.mp4' + +with gr.Blocks() as ui: + with gr.Row(): + video = gr.File(label="Video or Image", info="Filepath of video/image that contains faces to use") + audio = gr.File(label="Audio", info="Filepath of video/audio file to use as raw audio source") + with gr.Column(): + checkpoint = gr.Radio(["wav2lip", "wav2lip_gan"], label="Checkpoint", + info="Name of saved checkpoint to load weights from") + no_smooth = gr.Checkbox(label="No Smooth", + info="Prevent smoothing face detections over a short temporal window") + resize_factor = gr.Slider(minimum=1, maximum=4, step=1, label="Resize Factor", + info="Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p") + with gr.Row(): + with gr.Column(): + pad_top = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Top", info="Padding above") + pad_bottom = gr.Slider(minimum=0, maximum=50, step=1, value=10, + label="Pad Bottom (Often increasing this to 20 allows chin to be included)", + info="Padding below lips") + pad_left = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Left", + info="Padding to the left of lips") + pad_right = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Pad Right", + info="Padding to the right of lips") + generate_btn = gr.Button("Generate") + with gr.Column(): + result = gr.Video() + + + def generate(video, audio, checkpoint, no_smooth, resize_factor, pad_top, pad_bottom, pad_left, pad_right): + if video is None or audio is None or checkpoint is None: + return + + smooth = "--nosmooth" if no_smooth else "" + + cmd = [ + "python", + "inference.py", + "--checkpoint_path", f"checkpoints/{checkpoint}.pth", + "--segmentation_path", "checkpoints/face_segmentation.pth", + "--enhance_face", "gfpgan", + "--face", video.name, + "--audio", audio.name, + "--outfile", "results/output.mp4", + ] + + call(cmd) + return "results/output.mp4" + + + generate_btn.click( + generate, + [video, audio, checkpoint, pad_top, pad_bottom, pad_left, pad_right, resize_factor], + result) + +ui.queue().launch(debug=True) diff --git a/gfpgan/.gitignore b/gfpgan/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8151890ed0f735bc3db37b3900616e1657dee170 --- /dev/null +++ b/gfpgan/.gitignore @@ -0,0 +1,139 @@ +# ignored folders +datasets/* +experiments/* +results/* +tb_logger/* +wandb/* +tmp/* + +version.py + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/gfpgan/.pre-commit-config.yaml b/gfpgan/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d221d29fbaac74bef1c0cd910ce8d8b6526181b8 --- /dev/null +++ b/gfpgan/.pre-commit-config.yaml @@ -0,0 +1,46 @@ +repos: + # flake8 + - repo: https://github.com/PyCQA/flake8 + rev: 3.8.3 + hooks: + - id: flake8 + args: ["--config=setup.cfg", "--ignore=W504, W503"] + + # modify known_third_party + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + + # isort + - repo: https://github.com/timothycrosley/isort + rev: 5.2.2 + hooks: + - id: isort + + # yapf + - repo: https://github.com/pre-commit/mirrors-yapf + rev: v0.30.0 + hooks: + - id: yapf + + # codespell + - repo: https://github.com/codespell-project/codespell + rev: v2.1.0 + hooks: + - id: codespell + + # pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace # Trim trailing whitespace + - id: check-yaml # Attempt to load all yaml files to verify syntax + - id: check-merge-conflict # Check for files that contain merge conflict strings + - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings + - id: end-of-file-fixer # Make sure files end in a newline and only a newline + - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 + - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- + args: ["--remove"] + - id: mixed-line-ending # Replace or check mixed line ending + args: ["--fix=lf"] diff --git a/gfpgan/CODE_OF_CONDUCT.md b/gfpgan/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..e8cc4daa4345590464314889b187d6a2d7a8e20f --- /dev/null +++ b/gfpgan/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible 
disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +xintao.wang@outlook.com or xintaowang@tencent.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. 
This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/gfpgan/Comparisons.md b/gfpgan/Comparisons.md new file mode 100644 index 0000000000000000000000000000000000000000..1542d4c0c0a04ceeba42f24d5277351c2514214a --- /dev/null +++ b/gfpgan/Comparisons.md @@ -0,0 +1,24 @@ +# Comparisons + +## Comparisons among different model versions + +Note that V1.3 is not always better than V1.2. You may need to try different models based on your purpose and inputs. + +| Version | Strengths | Weaknesses | +| :---: | :---: | :---: | +|V1.3 | ✓ natural outputs
✓ better results on very low-quality inputs
✓ works on relatively high-quality inputs
✓ tolerates repeated (twice) restoration | ✗ not very sharp
✗ slight change in identity | +|V1.2 | ✓ sharper output
✓ with beauty makeup | ✗ some outputs are unnatural| + +For the following images, you may need to **zoom in** for comparing details, or **click the image** to see in the full size. + +| Input | V1 | V1.2 | V1.3 +| :---: | :---: | :---: | :---: | +|![019_Anne_Hathaway_01_00](https://user-images.githubusercontent.com/17445847/153762146-96b25999-4ddd-42a5-a3fe-bb90565f4c4f.png)| ![](https://user-images.githubusercontent.com/17445847/153762256-ef41e749-5a27-495c-8a9c-d8403be55869.png) | ![](https://user-images.githubusercontent.com/17445847/153762297-d41582fc-6253-4e7e-a1ce-4dc237ae3bf3.png) | ![](https://user-images.githubusercontent.com/17445847/153762215-e0535e94-b5ba-426e-97b5-35c00873604d.png) | +| ![106_Harry_Styles_00_00](https://user-images.githubusercontent.com/17445847/153789040-632c0eda-c15a-43e9-a63c-9ead64f92d4a.png) | ![](https://user-images.githubusercontent.com/17445847/153789172-93cd4980-5318-4633-a07e-1c8f8064ff89.png) | ![](https://user-images.githubusercontent.com/17445847/153789185-f7b268a7-d1db-47b0-ae4a-335e5d657a18.png) | ![](https://user-images.githubusercontent.com/17445847/153789198-7c7f3bca-0ef0-4494-92f0-20aa6f7d7464.png)| +| ![076_Paris_Hilton_00_00](https://user-images.githubusercontent.com/17445847/153789607-86387770-9db8-441f-b08a-c9679b121b85.png) | ![](https://user-images.githubusercontent.com/17445847/153789619-e56b438a-78a0-425d-8f44-ec4692a43dda.png) | ![](https://user-images.githubusercontent.com/17445847/153789633-5b28f778-3b7f-4e08-8a1d-740ca6e82d8a.png) | ![](https://user-images.githubusercontent.com/17445847/153789645-bc623f21-b32d-4fc3-bfe9-61203407a180.png)| +| ![008_George_Clooney_00_00](https://user-images.githubusercontent.com/17445847/153790017-0c3ca94d-1c9d-4a0e-b539-ab12d4da98ff.png) | ![](https://user-images.githubusercontent.com/17445847/153790028-fb0d38ab-399d-4a30-8154-2dcd72ca90e8.png) | ![](https://user-images.githubusercontent.com/17445847/153790044-1ef68e34-6120-4439-a5d9-0b6cdbe9c3d0.png) | ![](https://user-images.githubusercontent.com/17445847/153790059-a8d3cece-8989-4e9a-9ffe-903e1690cfd6.png)| +| ![057_Madonna_01_00](https://user-images.githubusercontent.com/17445847/153790624-2d0751d0-8fb4-4806-be9d-71b833c2c226.png) | ![](https://user-images.githubusercontent.com/17445847/153790639-7eb870e5-26b2-41dc-b139-b698bb40e6e6.png) | ![](https://user-images.githubusercontent.com/17445847/153790651-86899b7a-a1b6-4242-9e8a-77b462004998.png) | ![](https://user-images.githubusercontent.com/17445847/153790655-c8f6c25b-9b4e-4633-b16f-c43da86cff8f.png)| +| ![044_Amy_Schumer_01_00](https://user-images.githubusercontent.com/17445847/153790811-3fb4fc46-5b4f-45fe-8fcb-a128de2bfa60.png) | ![](https://user-images.githubusercontent.com/17445847/153790817-d45aa4ff-bfc4-4163-b462-75eef9426fab.png) | ![](https://user-images.githubusercontent.com/17445847/153790824-5f93c3a0-fe5a-42f6-8b4b-5a5de8cd0ac3.png) | ![](https://user-images.githubusercontent.com/17445847/153790835-0edf9944-05c7-41c4-8581-4dc5ffc56c9d.png)| +| ![012_Jackie_Chan_01_00](https://user-images.githubusercontent.com/17445847/153791176-737b016a-e94f-4898-8db7-43e7762141c9.png) | ![](https://user-images.githubusercontent.com/17445847/153791183-2f25a723-56bf-4cd5-aafe-a35513a6d1c5.png) | ![](https://user-images.githubusercontent.com/17445847/153791194-93416cf9-2b58-4e70-b806-27e14c58d4fd.png) | ![](https://user-images.githubusercontent.com/17445847/153791202-aa98659c-b702-4bce-9c47-a2fa5eccc5ae.png)| + + diff --git a/gfpgan/FAQ.md b/gfpgan/FAQ.md new file mode 100644 index 
0000000000000000000000000000000000000000..e4d5a49cc216ffe987c7ab195a430f463f375425 --- /dev/null +++ b/gfpgan/FAQ.md @@ -0,0 +1,7 @@ +# FAQ + +1. **How to finetune the GFPGANCleanv1-NoCE-C2 (v1.2) model** + +**A:** 1) The GFPGANCleanv1-NoCE-C2 (v1.2) model uses the *clean* architecture, which is more friendly for deploying. +2) This model is not directly trained. Instead, it is converted from another *bilinear* model. +3) If you want to finetune the GFPGANCleanv1-NoCE-C2 (v1.2), you need to finetune its original *bilinear* model, and then do the conversion. diff --git a/gfpgan/LICENSE b/gfpgan/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..24384c0728442ace20d180d784cb3ea714413923 --- /dev/null +++ b/gfpgan/LICENSE @@ -0,0 +1,351 @@ +Tencent is pleased to support the open source community by making GFPGAN available. + +Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. + +GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below. + + +Terms of the Apache License Version 2.0: +--------------------------------------------- +Apache License + +Version 2.0, January 2004 + +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +1. Definitions. + +“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License. + +“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
+ +“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.” + +“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of this License; and + +You must cause any modified files to carry prominent notices stating that You changed the files; and + +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + +If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + + + +Other dependencies and licenses: + + +Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein: +--------------------------------------------- +1. basicsr +Copyright 2018-2020 BasicSR Authors + + +This BasicSR project is released under the Apache 2.0 license. + +A copy of Apache 2.0 is included in this file. + +StyleGAN2 +The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch. +The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license. +DFDNet +The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. + +Terms of the Nvidia License: +--------------------------------------------- + +1. Definitions + +"Licensor" means any person or entity that distributes its Work. + +"Software" means the original work of authorship made available under +this License. + +"Work" means the Software and any additions to or derivative works of +the Software that are made available under this License. + +"Nvidia Processors" means any central processing unit (CPU), graphics +processing unit (GPU), field-programmable gate array (FPGA), +application-specific integrated circuit (ASIC) or any combination +thereof designed, made, sold, or provided by Nvidia or its affiliates. + +The terms "reproduce," "reproduction," "derivative works," and +"distribution" have the meaning as provided under U.S. copyright law; +provided, however, that for the purposes of this License, derivative +works shall not include works that remain separable from, or merely +link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are "made available" under this License +by including in or with the Work either (a) a copyright notice +referencing the applicability of this License to the Work, or (b) a +copy of this License. + +2. License Grants + + 2.1 Copyright Grant. 
Subject to the terms and conditions of this + License, each Licensor grants to you a perpetual, worldwide, + non-exclusive, royalty-free, copyright license to reproduce, + prepare derivative works of, publicly display, publicly perform, + sublicense and distribute its Work and any resulting derivative + works in any form. + +3. Limitations + + 3.1 Redistribution. You may reproduce or distribute the Work only + if (a) you do so under this License, (b) you include a complete + copy of this License with your distribution, and (c) you retain + without modification any copyright, patent, trademark, or + attribution notices that are present in the Work. + + 3.2 Derivative Works. You may specify that additional or different + terms apply to the use, reproduction, and distribution of your + derivative works of the Work ("Your Terms") only if (a) Your Terms + provide that the use limitation in Section 3.3 applies to your + derivative works, and (b) you identify the specific derivative + works that are subject to Your Terms. Notwithstanding Your Terms, + this License (including the redistribution requirements in Section + 3.1) will continue to apply to the Work itself. + + 3.3 Use Limitation. The Work and any derivative works thereof only + may be used or intended for use non-commercially. The Work or + derivative works thereof may be used or intended for use by Nvidia + or its affiliates commercially or non-commercially. As used herein, + "non-commercially" means for research or evaluation purposes only. + + 3.4 Patent Claims. If you bring or threaten to bring a patent claim + against any Licensor (including any claim, cross-claim or + counterclaim in a lawsuit) to enforce any patents that you allege + are infringed by any Work, then your rights under this License from + such Licensor (including the grants in Sections 2.1 and 2.2) will + terminate immediately. + + 3.5 Trademarks. This License does not grant any rights to use any + Licensor's or its affiliates' names, logos, or trademarks, except + as necessary to reproduce the notices described in this License. + + 3.6 Termination. If you violate any term of this License, then your + rights under this License (including the grants in Sections 2.1 and + 2.2) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR +NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER +THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL +THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE +SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF +OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK +(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, +LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER +COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. 
+ +MIT License + +Copyright (c) 2019 Kim Seonghyeon + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + + +Open Source Software licensed under the BSD 3-Clause license: +--------------------------------------------- +1. torchvision +Copyright (c) Soumith Chintala 2016, +All rights reserved. + +2. torch +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + + +Terms of the BSD 3-Clause License: +--------------------------------------------- +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + + +Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: +--------------------------------------------- +1. numpy +Copyright (c) 2005-2020, NumPy Developers. +All rights reserved. + +A copy of BSD 3-Clause License is included in this file. + +The NumPy repository and source distributions bundle several libraries that are +compatibly licensed. We list these here. + +Name: Numpydoc +Files: doc/sphinxext/numpydoc/* +License: BSD-2-Clause + For details, see doc/sphinxext/LICENSE.txt + +Name: scipy-sphinx-theme +Files: doc/scipy-sphinx-theme/* +License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0 + For details, see doc/scipy-sphinx-theme/LICENSE.txt + +Name: lapack-lite +Files: numpy/linalg/lapack_lite/* +License: BSD-3-Clause + For details, see numpy/linalg/lapack_lite/LICENSE.txt + +Name: tempita +Files: tools/npy_tempita/* +License: MIT + For details, see tools/npy_tempita/license.txt + +Name: dragon4 +Files: numpy/core/src/multiarray/dragon4.c +License: MIT + For license text, see numpy/core/src/multiarray/dragon4.c + + + +Open Source Software licensed under the MIT license: +--------------------------------------------- +1. facexlib +Copyright (c) 2020 Xintao Wang + +2. opencv-python +Copyright (c) Olli-Pekka Heinisuo +Please note that only files in cv2 package are used. + + +Terms of the MIT License: +--------------------------------------------- +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + +Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein: +--------------------------------------------- +1. tqdm +Copyright (c) 2013 noamraph + +`tqdm` is a product of collaborative work. +Unless otherwise stated, all authors (see commit logs) retain copyright +for their respective work, and release the work under the MIT licence +(text below). + +Exceptions or notable authors are listed below +in reverse chronological order: + +* files: * + MPLv2.0 2015-2020 (c) Casper da Costa-Luis + [casperdcl](https://github.com/casperdcl). +* files: tqdm/_tqdm.py + MIT 2016 (c) [PR #96] on behalf of Google Inc. +* files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore + MIT 2013 (c) Noam Yorav-Raphael, original author. + +[PR #96]: https://github.com/tqdm/tqdm/pull/96 + + +Mozilla Public Licence (MPL) v. 2.0 - Exhibit A +----------------------------------------------- + +This Source Code Form is subject to the terms of the +Mozilla Public License, v. 2.0. 
+If a copy of the MPL was not distributed with this file, +You can obtain one at https://mozilla.org/MPL/2.0/. + + +MIT License (MIT) +----------------- + +Copyright (c) 2013 noamraph + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/gfpgan/MANIFEST.in b/gfpgan/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..bcaa7179b82f6f0eebace30fa7e4ebea88408f52 --- /dev/null +++ b/gfpgan/MANIFEST.in @@ -0,0 +1,8 @@ +include assets/* +include inputs/* +include scripts/*.py +include inference_gfpgan.py +include VERSION +include LICENSE +include requirements.txt +include gfpgan/weights/README.md diff --git a/gfpgan/PaperModel.md b/gfpgan/PaperModel.md new file mode 100644 index 0000000000000000000000000000000000000000..e9c8bdc4e757a9818f18d1926b7452172486ec92 --- /dev/null +++ b/gfpgan/PaperModel.md @@ -0,0 +1,76 @@ +# Installation + +We now provide a *clean* version of GFPGAN, which does not require customized CUDA extensions. See [here](README.md#installation) for this easier installation.
+If you want want to use the original model in our paper, please follow the instructions below. + +1. Clone repo + + ```bash + git clone https://github.com/xinntao/GFPGAN.git + cd GFPGAN + ``` + +1. Install dependent packages + + As StyleGAN2 uses customized PyTorch C++ extensions, you need to **compile them during installation** or **load them just-in-time(JIT)**. + You can refer to [BasicSR-INSTALL.md](https://github.com/xinntao/BasicSR/blob/master/INSTALL.md) for more details. + + **Option 1: Load extensions just-in-time(JIT)** (For those just want to do simple inferences, may have less issues) + + ```bash + # Install basicsr - https://github.com/xinntao/BasicSR + # We use BasicSR for both training and inference + pip install basicsr + + # Install facexlib - https://github.com/xinntao/facexlib + # We use face detection and face restoration helper in the facexlib package + pip install facexlib + + pip install -r requirements.txt + python setup.py develop + + # remember to set BASICSR_JIT=True before your running commands + ``` + + **Option 2: Compile extensions during installation** (For those need to train/inference for many times) + + ```bash + # Install basicsr - https://github.com/xinntao/BasicSR + # We use BasicSR for both training and inference + # Set BASICSR_EXT=True to compile the cuda extensions in the BasicSR - It may take several minutes to compile, please be patient + # Add -vvv for detailed log prints + BASICSR_EXT=True pip install basicsr -vvv + + # Install facexlib - https://github.com/xinntao/facexlib + # We use face detection and face restoration helper in the facexlib package + pip install facexlib + + pip install -r requirements.txt + python setup.py develop + ``` + +## :zap: Quick Inference + +Download pre-trained models: [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth) + +```bash +wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth -P experiments/pretrained_models +``` + +- Option 1: Load extensions just-in-time(JIT) + + ```bash + BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 + + # for aligned images + BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned + ``` + +- Option 2: Have successfully compiled extensions during installation + + ```bash + python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 + + # for aligned images + python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned + ``` diff --git a/gfpgan/README.md b/gfpgan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b6296252fe04dae9116817552c3620ea937b7287 --- /dev/null +++ b/gfpgan/README.md @@ -0,0 +1,192 @@ +

+ +

+ +##
English | 简体中文
+ +[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases) +[![PyPI](https://img.shields.io/pypi/v/gfpgan)](https://pypi.org/project/gfpgan/) +[![Open issue](https://img.shields.io/github/issues/TencentARC/GFPGAN)](https://github.com/TencentARC/GFPGAN/issues) +[![Closed issue](https://img.shields.io/github/issues-closed/TencentARC/GFPGAN)](https://github.com/TencentARC/GFPGAN/issues) +[![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE) +[![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml) +[![Publish-pip](https://github.com/TencentARC/GFPGAN/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml) + +1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN google colab logo; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model) +2. Online demo: [Huggingface](https://huggingface.co/spaces/akhaliq/GFPGAN) (return only the cropped face) +3. Online demo: [Replicate.ai](https://replicate.com/xinntao/gfpgan) (may need to sign in, return the whole image) +4. Online demo: [Baseten.co](https://app.baseten.co/applications/Q04Lz0d/operator_views/8qZG6Bg) (backed by GPU, returns the whole image) +5. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**. + +> :rocket: **Thanks for your interest in our work. You may also want to check our new updates on the *tiny models* for *anime images and videos* in [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN/blob/master/docs/anime_video_model.md)** :blush: + +GFPGAN aims at developing a **Practical Algorithm for Real-world Face Restoration**.
+It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g.*, StyleGAN2) for blind face restoration. + +:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md). + +:triangular_flag_on_post: **Updates** + +- :fire::fire::white_check_mark: Add **[V1.3 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)**, which produces **more natural** restoration results, and better results on *very low-quality* / *high-quality* inputs. See more in [Model zoo](#european_castle-model-zoo), [Comparisons.md](Comparisons.md) +- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/GFPGAN). +- :white_check_mark: Support enhancing non-face regions (background) with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN). +- :white_check_mark: We provide a *clean* version of GFPGAN, which does not require CUDA extensions. +- :white_check_mark: We provide an updated model without colorizing faces. + +--- + +If GFPGAN is helpful in your photos/projects, please help to :star: this repo or recommend it to your friends. Thanks:blush: +Other recommended projects:
+:arrow_forward: [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration
+:arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR): An open-source image and video restoration toolbox
+:arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-related functions
+:arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for viewing and comparison
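+
+If you prefer to call GFPGAN from Python rather than through the `inference_gfpgan.py` script shown below, the pip-installable `gfpgan` package exposes a `GFPGANer` helper. The snippet below is only a minimal sketch: it assumes the V1.3 weights have already been downloaded to `experiments/pretrained_models` (see Quick Inference), and the input/output paths are illustrative.
+
+```python
+import os
+import cv2
+from gfpgan import GFPGANer  # pip install gfpgan
+
+# Clean (CUDA-extension-free) V1.3 model; bg_upsampler may instead be a
+# RealESRGANer instance if the background should be enhanced as well.
+restorer = GFPGANer(
+    model_path='experiments/pretrained_models/GFPGANv1.3.pth',
+    upscale=2,
+    arch='clean',
+    channel_multiplier=2,
+    bg_upsampler=None)
+
+img = cv2.imread('inputs/whole_imgs/00.jpg', cv2.IMREAD_COLOR)  # any low-quality photo
+# enhance() returns the cropped faces, the restored faces and the whole restored image.
+cropped_faces, restored_faces, restored_img = restorer.enhance(
+    img, has_aligned=False, only_center_face=False, paste_back=True)
+
+os.makedirs('results', exist_ok=True)
+cv2.imwrite('results/restored.jpg', restored_img)
+```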
+ +--- + +### :book: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior + +> [[Paper](https://arxiv.org/abs/2101.04061)]   [[Project Page](https://xinntao.github.io/projects/gfpgan)]   [Demo]
+> [Xintao Wang](https://xinntao.github.io/), [Yu Li](https://yu-li.github.io/), [Honglun Zhang](https://scholar.google.com/citations?hl=en&user=KjQLROoAAAAJ), [Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en)
+> Applied Research Center (ARC), Tencent PCG + +

+ +

+ +--- + +## :wrench: Dependencies and Installation + +- Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) +- [PyTorch >= 1.7](https://pytorch.org/) +- Option: NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) +- Option: Linux + +### Installation + +We now provide a *clean* version of GFPGAN, which does not require customized CUDA extensions.
+If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation. + +1. Clone repo + + ```bash + git clone https://github.com/TencentARC/GFPGAN.git + cd GFPGAN + ``` + +1. Install dependent packages + + ```bash + # Install basicsr - https://github.com/xinntao/BasicSR + # We use BasicSR for both training and inference + pip install basicsr + + # Install facexlib - https://github.com/xinntao/facexlib + # We use face detection and face restoration helper in the facexlib package + pip install facexlib + + pip install -r requirements.txt + python setup.py develop + + # If you want to enhance the background (non-face) regions with Real-ESRGAN, + # you also need to install the realesrgan package + pip install realesrgan + ``` + +## :zap: Quick Inference + +We take the v1.3 version for an example. More models can be found [here](#european_castle-model-zoo). + +Download pre-trained models: [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) + +```bash +wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models +``` + +**Inference!** + +```bash +python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 +``` + +```console +Usage: python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 [options]... + + -h show this help + -i input Input image or folder. Default: inputs/whole_imgs + -o output Output folder. Default: results + -v version GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3 + -s upscale The final upsampling scale of the image. Default: 2 + -bg_upsampler background upsampler. Default: realesrgan + -bg_tile Tile size for background sampler, 0 for no tile during testing. Default: 400 + -suffix Suffix of the restored faces + -only_center_face Only restore the center face + -aligned Input are aligned faces + -ext Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto +``` + +If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation and inference. + +## :european_castle: Model Zoo + +| Version | Model Name | Description | +| :---: | :---: | :---: | +| V1.3 | [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) | Based on V1.2; **more natural** restoration results; better results on very low-quality / high-quality inputs. | +| V1.2 | [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) | No colorization; no CUDA extensions are required. Trained with more data with pre-processing. | +| V1 | [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth) | The paper model, with colorization. | + +The comparisons are in [Comparisons.md](Comparisons.md). + +Note that V1.3 is not always better than V1.2. You may need to select different models based on your purpose and inputs. + +| Version | Strengths | Weaknesses | +| :---: | :---: | :---: | +|V1.3 | ✓ natural outputs
✓ better results on very low-quality inputs <br> ✓ works on relatively high-quality inputs <br> ✓ supports repeated (twice) restoration | ✗ not very sharp <br> ✗ slight change in identity |
+|V1.2 | ✓ sharper output <br> ✓ outputs with a beauty-makeup effect | ✗ some outputs are unnatural |
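+
+For example, to compare V1.2 and V1.3 on the same inputs, you can run the inference script once per version. The output folder names below are arbitrary, and each version's checkpoint should first be downloaded to `experiments/pretrained_models` as shown above:
+
+```bash
+# V1.2: sharper, but some outputs can look less natural
+python inference_gfpgan.py -i inputs/whole_imgs -o results_v1.2 -v 1.2 -s 2
+
+# V1.3: more natural, slightly softer
+python inference_gfpgan.py -i inputs/whole_imgs -o results_v1.3 -v 1.3 -s 2
+```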
+
+You can find **more models (such as the discriminators)** here: [[Google Drive](https://drive.google.com/drive/folders/17rLiFzcUMoQuhLnptDsKolegHWwJOnHu?usp=sharing)] or [[Tencent Cloud 腾讯微云](https://share.weiyun.com/ShYoCCoc)].
+
+## :computer: Training
+
+We provide the training code for GFPGAN (used in our paper).
+You could improve it according to your own needs. + +**Tips** + +1. More high quality faces can improve the restoration quality. +2. You may need to perform some pre-processing, such as beauty makeup. + +**Procedures** + +(You can try a simple version ( `options/train_gfpgan_v1_simple.yml`) that does not require face component landmarks.) + +1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset) + +1. Download pre-trained models and other data. Put them in the `experiments/pretrained_models` folder. + 1. [Pre-trained StyleGAN2 model: StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth) + 1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth) + 1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth) + +1. Modify the configuration file `options/train_gfpgan_v1.yml` accordingly. + +1. Training + +> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 gfpgan/train.py -opt options/train_gfpgan_v1.yml --launcher pytorch + +## :scroll: License and Acknowledgement + +GFPGAN is released under Apache License Version 2.0. + +## BibTeX + + @InProceedings{wang2021gfpgan, + author = {Xintao Wang and Yu Li and Honglun Zhang and Ying Shan}, + title = {Towards Real-World Blind Face Restoration with Generative Facial Prior}, + booktitle={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2021} + } + +## :e-mail: Contact + +If you have any question, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`. diff --git a/gfpgan/README_CN.md b/gfpgan/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..880f20631cb199fb33e1541ea7fd38bf1d29167b --- /dev/null +++ b/gfpgan/README_CN.md @@ -0,0 +1,7 @@ +

+
+##
+
+English | 简体中文
+ +还未完工,欢迎贡献! diff --git a/gfpgan/VERSION b/gfpgan/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..d0149fef743a8035720ed161412709e87702dcab --- /dev/null +++ b/gfpgan/VERSION @@ -0,0 +1 @@ +1.3.4 diff --git a/gfpgan/assets/gfpgan_logo.png b/gfpgan/assets/gfpgan_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..f01937838faf7689869d3a4dfd50da006af8fd5d Binary files /dev/null and b/gfpgan/assets/gfpgan_logo.png differ diff --git a/gfpgan/experiments/pretrained_models/GFPGANv1.3.pth b/gfpgan/experiments/pretrained_models/GFPGANv1.3.pth new file mode 100644 index 0000000000000000000000000000000000000000..1da748a3ef84ff85dd2c77c836f222aae22b007e --- /dev/null +++ b/gfpgan/experiments/pretrained_models/GFPGANv1.3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70 +size 348632874 diff --git a/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth b/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth new file mode 100644 index 0000000000000000000000000000000000000000..afedb5c7e826056840c9cc183f2c6f0186fd17ba --- /dev/null +++ b/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2cd4703ab14f4d01fd1383a8a8b266f9a5833dacee8e6a79d3bf21a1b6be5ad +size 348632874 diff --git a/gfpgan/gfpgan/__init__.py b/gfpgan/gfpgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94daaeebce5604d61999f0b1b354b9a9e299b991 --- /dev/null +++ b/gfpgan/gfpgan/__init__.py @@ -0,0 +1,7 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .utils import * + +# from .version import * diff --git a/gfpgan/gfpgan/archs/__init__.py b/gfpgan/gfpgan/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bec5f17bfa38729b55f57cae8e40c27310db2b7b --- /dev/null +++ b/gfpgan/gfpgan/archs/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import arch modules for registry +# scan all the files that end with '_arch.py' under the archs folder +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames] diff --git a/gfpgan/gfpgan/archs/arcface_arch.py b/gfpgan/gfpgan/archs/arcface_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..e6d3bd97f83334450bd78ad2c3b9871102a56b70 --- /dev/null +++ b/gfpgan/gfpgan/archs/arcface_arch.py @@ -0,0 +1,245 @@ +import torch.nn as nn +from basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. 
+ """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. 
+ """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x diff --git a/gfpgan/gfpgan/archs/gfpgan_bilinear_arch.py b/gfpgan/gfpgan/archs/gfpgan_bilinear_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..52e0de88de8543cf4afdc3988c4cdfc7c7060687 --- /dev/null +++ b/gfpgan/gfpgan/archs/gfpgan_bilinear_arch.py @@ -0,0 +1,312 @@ +import math +import random +import torch +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn + +from .gfpganv1_arch import ResUpBlock +from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2GeneratorBilinear) + + +class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). 
+ + It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorBilinearSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorBilinearSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
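+
+        Returns:
+            tuple(Tensor, Tensor | None): The generated image and, if `return_latents` is True,
+                the style latents (otherwise None).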
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +@ARCH_REGISTRY.register() +class GFPGANBilinear(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: GFPGANv1Clean. + + + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+ """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANBilinear, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + 
ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True): + """Forward function for GFPGANBilinear. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = self.conv_body_first(x) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + + feat = self.final_conv(feat) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs diff --git a/gfpgan/gfpgan/archs/gfpganv1_arch.py b/gfpgan/gfpgan/archs/gfpganv1_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..e092b4f7633dece505e5cd3bac4a482df3746654 --- /dev/null +++ b/gfpgan/gfpgan/archs/gfpganv1_arch.py @@ -0,0 +1,439 @@ +import math +import random +import torch +from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2Generator) +from basicsr.ops.fused_act import FusedLeakyReLU +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class StyleGAN2GeneratorSFT(StyleGAN2Generator): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be + applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1). + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+ """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ConvUpLayer(nn.Module): + 
"""Convolutional upsampling layer. It uses bilinear upsampler + Conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + bias=True, + bias_init_val=0, + activate=True): + super(ConvUpLayer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + # self.scale is used to scale the convolution weights, which is related to the common initializations. + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + + if bias and not activate: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + # activation + if activate: + if bias: + self.activation = FusedLeakyReLU(out_channels) + else: + self.activation = ScaledLeakyReLU(0.2) + else: + self.activation = None + + def forward(self, x): + # bilinear upsample + out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + # conv + out = F.conv2d( + out, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + # activation + if self.activation is not None: + out = self.activation(out) + return out + + +class ResUpBlock(nn.Module): + """Residual block with upsampling. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + """ + + def __init__(self, in_channels, out_channels): + super(ResUpBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True) + self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out + + +@ARCH_REGISTRY.register() +class GFPGANv1(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be + applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1). + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + input_is_latent (bool): Whether input is latent style. Default: False. 
+ different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANv1, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + resample_kernel=resample_kernel, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, 
stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True): + """Forward function for GFPGANv1. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = self.conv_body_first(x) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + + feat = self.final_conv(feat) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs + + +@ARCH_REGISTRY.register() +class FacialComponentDiscriminator(nn.Module): + """Facial component (eyes, mouth, noise) discriminator used in GFPGAN. + """ + + def __init__(self): + super(FacialComponentDiscriminator, self).__init__() + # It now uses a VGG-style architectrue with fixed model size + self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) + self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False) + + def forward(self, x, return_feats=False): + """Forward function for FacialComponentDiscriminator. + + Args: + x (Tensor): Input images. + return_feats (bool): Whether to return intermediate features. Default: False. 
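+
+        Returns:
+            tuple(Tensor, list[Tensor] | None): The discriminator score map and, if
+                `return_feats` is True, a list of intermediate feature maps (otherwise None).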
+ """ + feat = self.conv1(x) + feat = self.conv3(self.conv2(feat)) + rlt_feats = [] + if return_feats: + rlt_feats.append(feat.clone()) + feat = self.conv5(self.conv4(feat)) + if return_feats: + rlt_feats.append(feat.clone()) + out = self.final_conv(feat) + + if return_feats: + return out, rlt_feats + else: + return out, None diff --git a/gfpgan/gfpgan/archs/gfpganv1_clean_arch.py b/gfpgan/gfpgan/archs/gfpganv1_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..eb2e15d288bf0ad641034ed58d5dab37b0baabb3 --- /dev/null +++ b/gfpgan/gfpgan/archs/gfpganv1_clean_arch.py @@ -0,0 +1,324 @@ +import math +import random +import torch +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + +from .stylegan2_clean_arch import StyleGAN2GeneratorClean + + +class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False): + super(StyleGAN2GeneratorCSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorCSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ResBlock(nn.Module): + """Residual block with bilinear upsampling/downsampling. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + mode (str): Upsampling/downsampling mode. Options: down | up. Default: down. + """ + + def __init__(self, in_channels, out_channels, mode='down'): + super(ResBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) + if mode == 'down': + self.scale_factor = 0.5 + elif mode == 'up': + self.scale_factor = 2 + + def forward(self, x): + out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) + # upsample/downsample + out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) + # skip + x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + skip = self.skip(x) + out = out + skip + return out + + +@ARCH_REGISTRY.register() +class GFPGANv1Clean(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + It is the clean version without custom compiled CUDA extensions used in StyleGAN2. 
+ + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANv1Clean, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down')) + in_channels = out_channels + + self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up')) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorCSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT 
modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) + self.condition_shift.append( + nn.Sequential( + nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), + nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True): + """Forward function for GFPGANv1Clean. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs diff --git a/gfpgan/gfpgan/archs/stylegan2_bilinear_arch.py b/gfpgan/gfpgan/archs/stylegan2_bilinear_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..1342ee3c9a6b8f742fb76ce7d5b907cd39fbc350 --- /dev/null +++ b/gfpgan/gfpgan/archs/stylegan2_bilinear_arch.py @@ -0,0 +1,613 @@ +import math +import random +import torch +from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class EqualLinear(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Size of each sample. + out_channels (int): Size of each output sample. + bias (bool): If set to ``False``, the layer will not learn an additive + bias. Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + lr_mul (float): Learning rate multiplier. Default: 1. + activation (None | str): The activation after ``linear`` operation. + Supported: 'fused_lrelu', None. Default: None. 
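+
+    The weight is stored divided by `lr_mul` and rescaled at forward time by
+    `scale = (1 / sqrt(in_channels)) * lr_mul` (the equalized learning-rate trick of StyleGAN2);
+    the bias is likewise multiplied by `lr_mul`.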
+ """ + + def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None): + super(EqualLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lr_mul = lr_mul + self.activation = activation + if self.activation not in ['fused_lrelu', None]: + raise ValueError(f'Wrong activation value in EqualLinear: {activation}' + "Supported ones are: ['fused_lrelu', None].") + self.scale = (1 / math.sqrt(in_channels)) * lr_mul + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + if self.bias is None: + bias = None + else: + bias = self.bias * self.lr_mul + if self.activation == 'fused_lrelu': + out = F.linear(x, self.weight * self.scale) + out = fused_leaky_relu(out, bias) + else: + out = F.linear(x, self.weight * self.scale, bias=bias) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, bias={self.bias is not None})') + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. + Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + eps (float): A value added to the denominator for numerical stability. + Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8, + interpolation_mode='bilinear'): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + # modulation inside each modulated conv + self.modulation = EqualLinear( + num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None) + + self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.scale * self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, ' + f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode='bilinear'): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=demodulate, + sample_mode=sample_mode, + interpolation_mode=interpolation_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.activate = FusedLeakyReLU(out_channels) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # activation (with bias) + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'): + super(ToRGB, self).__init__() + self.upsample = upsample + self.interpolation_mode = interpolation_mode + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + self.modulated_conv = ModulatedConv2d( + in_channels, + 3, + kernel_size=1, + num_style_feat=num_style_feat, + demodulate=False, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. 
+ + Returns: + Tensor: RGB images. + """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate( + skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2GeneratorBilinear(nn.Module): + """StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of + StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + interpolation_mode='bilinear'): + super(StyleGAN2GeneratorBilinear, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.append( + EqualLinear( + num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, + activation='fused_lrelu')) + self.style_mlp = nn.Sequential(*style_mlp_layers) + + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample', + interpolation_mode=interpolation_mode)) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None, + interpolation_mode=interpolation_mode)) + self.to_rgbs.append( + 
ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2Generator. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. + Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is + False. Default: True. + truncation (float): TODO. Default: 1. + truncation_latent (Tensor | None): TODO. Default: None. + inject_index (int | None): The injection index for mixing noise. + Default: None. + return_latents (bool): Whether to return style latents. + Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latent with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +class ScaledLeakyReLU(nn.Module): + """Scaled LeakyReLU. + + Args: + negative_slope (float): Negative slope. Default: 0.2. 
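+
+ Note: the output is multiplied by sqrt(2) to roughly restore the variance
+ reduced by the leaky ReLU, matching the gain of the fused activation;
+ ConvLayer falls back to this bias-free activation when no bias is used.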
+ """ + + def __init__(self, negative_slope=0.2): + super(ScaledLeakyReLU, self).__init__() + self.negative_slope = negative_slope + + def forward(self, x): + out = F.leaky_relu(x, negative_slope=self.negative_slope) + return out * math.sqrt(2) + + +class EqualConv2d(nn.Module): + """Equalized Linear as StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + stride (int): Stride of the convolution. Default: 1 + padding (int): Zero-padding added to both sides of the input. + Default: 0. + bias (bool): If ``True``, adds a learnable bias to the output. + Default: ``True``. + bias_init_val (float): Bias initialized value. Default: 0. + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0): + super(EqualConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) + else: + self.register_parameter('bias', None) + + def forward(self, x): + out = F.conv2d( + x, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size},' + f' stride={self.stride}, padding={self.padding}, ' + f'bias={self.bias is not None})') + + +class ConvLayer(nn.Sequential): + """Conv Layer used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Kernel size. + downsample (bool): Whether downsample by a factor of 2. + Default: False. + bias (bool): Whether with bias. Default: True. + activate (bool): Whether use activateion. Default: True. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + downsample=False, + bias=True, + activate=True, + interpolation_mode='bilinear'): + layers = [] + self.interpolation_mode = interpolation_mode + # downsample + if downsample: + if self.interpolation_mode == 'nearest': + self.align_corners = None + else: + self.align_corners = False + + layers.append( + torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners)) + stride = 1 + self.padding = kernel_size // 2 + # conv + layers.append( + EqualConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias + and not activate)) + # activation + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channels)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super(ConvLayer, self).__init__(*layers) + + +class ResBlock(nn.Module): + """Residual block used in StyleGAN2 Discriminator. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. 
+ """ + + def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'): + super(ResBlock, self).__init__() + + self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) + self.conv2 = ConvLayer( + in_channels, + out_channels, + 3, + downsample=True, + interpolation_mode=interpolation_mode, + bias=True, + activate=True) + self.skip = ConvLayer( + in_channels, + out_channels, + 1, + downsample=True, + interpolation_mode=interpolation_mode, + bias=False, + activate=False) + + def forward(self, x): + out = self.conv1(x) + out = self.conv2(out) + skip = self.skip(x) + out = (out + skip) / math.sqrt(2) + return out diff --git a/gfpgan/gfpgan/archs/stylegan2_clean_arch.py b/gfpgan/gfpgan/archs/stylegan2_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2ee94e50401b95e4c9997adef5581d521d725f --- /dev/null +++ b/gfpgan/gfpgan/archs/stylegan2_clean_arch.py @@ -0,0 +1,368 @@ +import math +import random +import torch +from basicsr.archs.arch_util import default_init_weights +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn +from torch.nn import functional as F + + +class NormStyleCode(nn.Module): + + def forward(self, x): + """Normalize the style codes. + + Args: + x (Tensor): Style codes with shape (b, c). + + Returns: + Tensor: Normalized tensor. + """ + return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) + + +class ModulatedConv2d(nn.Module): + """Modulated Conv2d used in StyleGAN2. + + There is no bias in ModulatedConv2d. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether to demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + eps (float): A value added to the denominator for numerical stability. Default: 1e-8. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + # modulation inside each modulated conv + self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) + # initialization + default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear') + + self.weight = nn.Parameter( + torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / + math.sqrt(in_channels * kernel_size**2)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + """Forward function. + + Args: + x (Tensor): Tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + + Returns: + Tensor: Modulated tensor after convolution. 
+ """ + b, c, h, w = x.shape # c = c_in + # weight modulation + style = self.modulation(style).view(b, 1, c, 1, 1) + # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) + weight = self.weight * style # (b, c_out, c_in, k, k) + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + # upsample or downsample if necessary + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + # weight: (b*c_out, c_in, k, k), groups=b + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + """Style conv used in StyleGAN2. + + Args: + in_channels (int): Channel number of the input. + out_channels (int): Channel number of the output. + kernel_size (int): Size of the convolving kernel. + num_style_feat (int): Channel number of style features. + demodulate (bool): Whether demodulate in the conv layer. Default: True. + sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None. + """ + + def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) * 2**0.5 # for conversion + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # add bias + out = out + self.bias + # activation + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + """To RGB (image space) from features. + + Args: + in_channels (int): Channel number of input. + num_style_feat (int): Channel number of style features. + upsample (bool): Whether to upsample. Default: True. + """ + + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d( + in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + """Forward function. + + Args: + x (Tensor): Feature tensor with shape (b, c, h, w). + style (Tensor): Tensor with shape (b, num_style_feat). + skip (Tensor): Base/skip tensor. Default: None. + + Returns: + Tensor: RGB images. 
+ """ + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) + out = out + skip + return out + + +class ConstantInput(nn.Module): + """Constant input. + + Args: + num_channel (int): Channel number of constant input. + size (int): Spatial size of constant input. + """ + + def __init__(self, num_channel, size): + super(ConstantInput, self).__init__() + self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) + + def forward(self, batch): + out = self.weight.repeat(batch, 1, 1, 1) + return out + + +@ARCH_REGISTRY.register() +class StyleGAN2GeneratorClean(nn.Module): + """Clean version of StyleGAN2 Generator. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + narrow (float): Narrow ratio for channels. Default: 1.0. + """ + + def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1): + super(StyleGAN2GeneratorClean, self).__init__() + # Style MLP layers + self.num_style_feat = num_style_feat + style_mlp_layers = [NormStyleCode()] + for i in range(num_mlp): + style_mlp_layers.extend( + [nn.Linear(num_style_feat, num_style_feat, bias=True), + nn.LeakyReLU(negative_slope=0.2, inplace=True)]) + self.style_mlp = nn.Sequential(*style_mlp_layers) + # initialization + default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu') + + # channel list + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + self.channels = channels + + self.constant_input = ConstantInput(channels['4'], size=4) + self.style_conv1 = StyleConv( + channels['4'], + channels['4'], + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None) + self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False) + + self.log_size = int(math.log(out_size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + self.num_latent = self.log_size * 2 - 2 + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channels = channels['4'] + # noise + for layer_idx in range(self.num_layers): + resolution = 2**((layer_idx + 5) // 2) + shape = [1, 1, resolution, resolution] + self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) + # style convs and to_rgbs + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample')) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None)) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def make_noise(self): + """Make noise for noise injection.""" + device = self.constant_input.weight.device + noises = [torch.randn(1, 1, 4, 4, 
device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def get_latent(self, x): + return self.style_mlp(x) + + def mean_latent(self, num_latent): + latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) + latent = self.style_mlp(latent_in).mean(0, keepdim=True) + return latent + + def forward(self, + styles, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorClean. + + Args: + styles (list[Tensor]): Sample codes of styles. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. + """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None diff --git a/gfpgan/gfpgan/data/__init__.py b/gfpgan/gfpgan/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69fd9f9026407c4d185f86b122000485b06fd986 --- /dev/null +++ b/gfpgan/gfpgan/data/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import dataset modules for registry +# scan all the files that end with '_dataset.py' under the data folder +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = 
[osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames] diff --git a/gfpgan/gfpgan/data/ffhq_degradation_dataset.py b/gfpgan/gfpgan/data/ffhq_degradation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..64e5755e1211f171cb2a883d47e8d253061f90aa --- /dev/null +++ b/gfpgan/gfpgan/data/ffhq_degradation_dataset.py @@ -0,0 +1,230 @@ +import cv2 +import math +import numpy as np +import os.path as osp +import torch +import torch.utils.data as data +from basicsr.data import degradations as degradations +from basicsr.data.data_util import paths_from_folder +from basicsr.data.transforms import augment +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, + normalize) + + +@DATASET_REGISTRY.register() +class FFHQDegradationDataset(data.Dataset): + """FFHQ dataset for GFPGAN. + + It reads high resolution images, and then generate low-quality (LQ) images on-the-fly. + + Args: + opt (dict): Config for train datasets. It contains the following keys: + dataroot_gt (str): Data root path for gt. + io_backend (dict): IO backend type and other kwarg. + mean (list | tuple): Image mean. + std (list | tuple): Image std. + use_hflip (bool): Whether to horizontally flip. + Please see more options in the codes. + """ + + def __init__(self, opt): + super(FFHQDegradationDataset, self).__init__() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + + self.gt_folder = opt['dataroot_gt'] + self.mean = opt['mean'] + self.std = opt['std'] + self.out_size = opt['out_size'] + + self.crop_components = opt.get('crop_components', False) # facial components + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions + + if self.crop_components: + # load component list from a pre-process pth files + self.components_list = torch.load(opt.get('component_path')) + + # file client (lmdb io backend) + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + # disk backend: scan file list from a folder + self.paths = paths_from_folder(self.gt_folder) + + # degradation configurations + self.blur_kernel_size = opt['blur_kernel_size'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.blur_sigma = opt['blur_sigma'] + self.downsample_range = opt['downsample_range'] + self.noise_range = opt['noise_range'] + self.jpeg_range = opt['jpeg_range'] + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob') + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob') + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + # to gray + self.gray_prob = opt.get('gray_prob') + + logger = get_root_logger() + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + 
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. + + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + def get_component_coordinates(self, index, status): + """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file""" + components_bbox = self.components_list[f'{index:08d}'] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0] + components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0] + + # get coordinates + locations = [] + for part in ['left_eye', 'right_eye', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations.append(loc) + return locations + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
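+ # Degradation pipeline overview (summary of the steps below): load GT ->
+ # random hflip -> blur with a random mixed kernel -> random downsample ->
+ # Gaussian noise -> JPEG compression -> resize back to the GT size ->
+ # optional color jitter / grayscale -> to tensor, quantize and normalize.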
+ gt_path = self.paths[index] + img_bytes = self.file_client.get(gt_path) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + h, w, _ = img_gt.shape + + # get facial component coordinates + if self.crop_components: + locations = self.get_component_coordinates(index, status) + loc_left_eye, loc_right_eye, loc_mouth = locations + + # ------------------------ generate lq image ------------------------ # + # blur + kernel = degradations.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + noise_range=None) + img_lq = cv2.filter2D(img_gt, -1, kernel) + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) + # noise + if self.noise_range is not None: + img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) + # jpeg compression + if self.jpeg_range is not None: + img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) + + # resize to original size + img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_lq = self.color_jitter(img_lq, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) + img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) + if self.opt.get('gt_gray'): # whether convert GT to gray images + img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) + img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels + + # BGR to RGB, HWC to CHW, numpy to tensor + img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue) + + # round and clip + img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. 
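+ # The round trip through the 0-255 range above quantizes the LQ image to
+ # 8-bit levels, mimicking an image that was stored as an ordinary image
+ # file before being fed to the restoration network.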
+ + # normalize + normalize(img_gt, self.mean, self.std, inplace=True) + normalize(img_lq, self.mean, self.std, inplace=True) + + if self.crop_components: + return_dict = { + 'lq': img_lq, + 'gt': img_gt, + 'gt_path': gt_path, + 'loc_left_eye': loc_left_eye, + 'loc_right_eye': loc_right_eye, + 'loc_mouth': loc_mouth + } + return return_dict + else: + return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} + + def __len__(self): + return len(self.paths) diff --git a/gfpgan/gfpgan/models/__init__.py b/gfpgan/gfpgan/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6afad57a3794b867dabbdb617a16355a24d6a8b3 --- /dev/null +++ b/gfpgan/gfpgan/models/__init__.py @@ -0,0 +1,10 @@ +import importlib +from basicsr.utils import scandir +from os import path as osp + +# automatically scan and import model modules for registry +# scan all the files that end with '_model.py' under the model folder +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames] diff --git a/gfpgan/gfpgan/models/gfpgan_model.py b/gfpgan/gfpgan/models/gfpgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b5fb8c953b1ef67b457f56492ad3291d6e5f126d --- /dev/null +++ b/gfpgan/gfpgan/models/gfpgan_model.py @@ -0,0 +1,579 @@ +import math +import os.path as osp +import torch +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.losses.gan_loss import r1_penalty +from basicsr.metrics import calculate_metric +from basicsr.models.base_model import BaseModel +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from collections import OrderedDict +from torch.nn import functional as F +from torchvision.ops import roi_align +from tqdm import tqdm + + +@MODEL_REGISTRY.register() +class GFPGANModel(BaseModel): + """The GFPGAN model for Towards real-world blind face restoratin with generative facial prior""" + + def __init__(self, opt): + super(GFPGANModel, self).__init__(opt) + self.idx = 0 # it is used for saving data for check + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + self.log_size = int(math.log(self.opt['network_g']['out_size'], 2)) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + train_opt = self.opt['train'] + + # ----------- define net_d ----------- # + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + # ----------- define net_g with Exponential Moving Average (EMA) ----------- # + # net_g_ema only used for testing on one GPU and saving. 
There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + + self.net_g.train() + self.net_d.train() + self.net_g_ema.eval() + + # ----------- facial component networks ----------- # + if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt): + self.use_facial_disc = True + else: + self.use_facial_disc = False + + if self.use_facial_disc: + # left eye + self.net_d_left_eye = build_network(self.opt['network_d_left_eye']) + self.net_d_left_eye = self.model_to_device(self.net_d_left_eye) + self.print_network(self.net_d_left_eye) + load_path = self.opt['path'].get('pretrain_network_d_left_eye') + if load_path is not None: + self.load_network(self.net_d_left_eye, load_path, True, 'params') + # right eye + self.net_d_right_eye = build_network(self.opt['network_d_right_eye']) + self.net_d_right_eye = self.model_to_device(self.net_d_right_eye) + self.print_network(self.net_d_right_eye) + load_path = self.opt['path'].get('pretrain_network_d_right_eye') + if load_path is not None: + self.load_network(self.net_d_right_eye, load_path, True, 'params') + # mouth + self.net_d_mouth = build_network(self.opt['network_d_mouth']) + self.net_d_mouth = self.model_to_device(self.net_d_mouth) + self.print_network(self.net_d_mouth) + load_path = self.opt['path'].get('pretrain_network_d_mouth') + if load_path is not None: + self.load_network(self.net_d_mouth, load_path, True, 'params') + + self.net_d_left_eye.train() + self.net_d_right_eye.train() + self.net_d_mouth.train() + + # ----------- define facial component gan loss ----------- # + self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device) + + # ----------- define losses ----------- # + # pixel loss + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + # perceptual loss + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + # L1 loss is used in pyramid loss, component style loss and identity loss + self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device) + + # gan loss (wgan) + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + # ----------- define identity loss ----------- # + if 'network_identity' in self.opt: + self.use_identity = True + else: + self.use_identity = False + + if self.use_identity: + # define identity network + self.network_identity = build_network(self.opt['network_identity']) + self.network_identity = self.model_to_device(self.network_identity) + self.print_network(self.network_identity) + load_path = self.opt['path'].get('pretrain_network_identity') + if load_path is not None: + self.load_network(self.network_identity, load_path, True, None) + self.network_identity.eval() + for param in self.network_identity.parameters(): + param.requires_grad = False + + # regularization weights + self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) + self.net_d_reg_every = train_opt['net_d_reg_every'] + + # set 
up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + + # ----------- optimizer g ----------- # + net_g_reg_ratio = 1 + normal_params = [] + for _, param in self.net_g.named_parameters(): + normal_params.append(param) + optim_params_g = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_g']['lr'] + }] + optim_type = train_opt['optim_g'].pop('type') + lr = train_opt['optim_g']['lr'] * net_g_reg_ratio + betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) + self.optimizers.append(self.optimizer_g) + + # ----------- optimizer d ----------- # + net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) + normal_params = [] + for _, param in self.net_d.named_parameters(): + normal_params.append(param) + optim_params_d = [{ # add normal params first + 'params': normal_params, + 'lr': train_opt['optim_d']['lr'] + }] + optim_type = train_opt['optim_d'].pop('type') + lr = train_opt['optim_d']['lr'] * net_d_reg_ratio + betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) + self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) + self.optimizers.append(self.optimizer_d) + + # ----------- optimizers for facial component networks ----------- # + if self.use_facial_disc: + # setup optimizers for facial component discriminators + optim_type = train_opt['optim_component'].pop('type') + lr = train_opt['optim_component']['lr'] + # left eye + self.optimizer_d_left_eye = self.get_optimizer( + optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_left_eye) + # right eye + self.optimizer_d_right_eye = self.get_optimizer( + optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_right_eye) + # mouth + self.optimizer_d_mouth = self.get_optimizer( + optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99)) + self.optimizers.append(self.optimizer_d_mouth) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + if 'loc_left_eye' in data: + # get facial component locations, shape (batch, 4) + self.loc_left_eyes = data['loc_left_eye'] + self.loc_right_eyes = data['loc_right_eye'] + self.loc_mouths = data['loc_mouth'] + + # uncomment to check data + # import torchvision + # if self.opt['rank'] == 0: + # import os + # os.makedirs('tmp/gt', exist_ok=True) + # os.makedirs('tmp/lq', exist_ok=True) + # print(self.idx) + # torchvision.utils.save_image( + # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # torchvision.utils.save_image( + # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # self.idx = self.idx + 1 + + def construct_img_pyramid(self): + """Construct image pyramid for intermediate restoration loss""" + pyramid_gt = [self.gt] + down_img = self.gt + for _ in range(0, self.log_size - 3): + down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False) + pyramid_gt.insert(0, down_img) + return pyramid_gt + + def get_roi_regions(self, eye_out_size=80, mouth_out_size=120): + face_ratio = int(self.opt['network_g']['out_size'] / 512) + eye_out_size *= face_ratio + mouth_out_size *= face_ratio + + rois_eyes = [] + rois_mouths = [] + for b in range(self.loc_left_eyes.size(0)): # 
loop for batch size + # left eye and right eye + img_inds = self.loc_left_eyes.new_full((2, 1), b) + bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4) + rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5) + rois_eyes.append(rois) + # mouse + img_inds = self.loc_left_eyes.new_full((1, 1), b) + rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5) + rois_mouths.append(rois) + + rois_eyes = torch.cat(rois_eyes, 0).to(self.device) + rois_mouths = torch.cat(rois_mouths, 0).to(self.device) + + # real images + all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes_gt = all_eyes[0::2, :, :, :] + self.right_eyes_gt = all_eyes[1::2, :, :, :] + self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + # output + all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio + self.left_eyes = all_eyes[0::2, :, :, :] + self.right_eyes = all_eyes[1::2, :, :, :] + self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + def gray_resize_for_identity(self, out, size=128): + out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) + out_gray = out_gray.unsqueeze(1) + out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) + return out_gray + + def optimize_parameters(self, current_iter): + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + self.optimizer_g.zero_grad() + + # do not update facial component net_d + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = False + for p in self.net_d_right_eye.parameters(): + p.requires_grad = False + for p in self.net_d_mouth.parameters(): + p.requires_grad = False + + # image pyramid loss weight + pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0) + if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')): + pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error + if pyramid_loss_weight > 0: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=True) + pyramid_gt = self.construct_img_pyramid() + else: + self.output, out_rgbs = self.net_g(self.lq, return_rgb=False) + + # get roi-align regions + if self.use_facial_disc: + self.get_roi_regions(eye_out_size=80, mouth_out_size=120) + + l_g_total = 0 + loss_dict = OrderedDict() + if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + + # image pyramid loss + if pyramid_loss_weight > 0: + for i in range(0, self.log_size - 2): + l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight + l_g_total += l_pyramid + loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid + + # perceptual loss + if self.cri_perceptual: + l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) + if l_g_percep is not None: + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + if 
l_g_style is not None: + l_g_total += l_g_style + loss_dict['l_g_style'] = l_g_style + + # gan loss + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan'] = l_g_gan + + # facial component loss + if self.use_facial_disc: + # left eye + fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_left_eye'] = l_g_gan + # right eye + fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True) + l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_right_eye'] = l_g_gan + # mouth + fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True) + l_g_gan = self.cri_component(fake_mouth, True, is_disc=False) + l_g_total += l_g_gan + loss_dict['l_g_gan_mouth'] = l_g_gan + + if self.opt['train'].get('comp_style_weight', 0) > 0: + # get gt feat + _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True) + _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True) + _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True) + + def _comp_style(feat, feat_gt, criterion): + return criterion(self._gram_mat(feat[0]), self._gram_mat( + feat_gt[0].detach())) * 0.5 + criterion( + self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach())) + + # facial component style loss + comp_style_loss = 0 + comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1) + comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1) + comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight'] + l_g_total += comp_style_loss + loss_dict['l_g_comp_style_loss'] = comp_style_loss + + # identity loss + if self.use_identity: + identity_weight = self.opt['train']['identity_weight'] + # get gray images and resize + out_gray = self.gray_resize_for_identity(self.output) + gt_gray = self.gray_resize_for_identity(self.gt) + + identity_gt = self.network_identity(gt_gray).detach() + identity_out = self.network_identity(out_gray) + l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight + l_g_total += l_identity + loss_dict['l_identity'] = l_identity + + l_g_total.backward() + self.optimizer_g.step() + + # EMA + self.model_ema(decay=0.5**(32 / (10 * 1000))) + + # ----------- optimize net_d ----------- # + for p in self.net_d.parameters(): + p.requires_grad = True + self.optimizer_d.zero_grad() + if self.use_facial_disc: + for p in self.net_d_left_eye.parameters(): + p.requires_grad = True + for p in self.net_d_right_eye.parameters(): + p.requires_grad = True + for p in self.net_d_mouth.parameters(): + p.requires_grad = True + self.optimizer_d_left_eye.zero_grad() + self.optimizer_d_right_eye.zero_grad() + self.optimizer_d_mouth.zero_grad() + + fake_d_pred = self.net_d(self.output.detach()) + real_d_pred = self.net_d(self.gt) + l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d'] = l_d + # In WGAN, real_score should be positive and fake_score should be negative + loss_dict['real_score'] = real_d_pred.detach().mean() + loss_dict['fake_score'] = fake_d_pred.detach().mean() + l_d.backward() + + # regularization 
loss + if current_iter % self.net_d_reg_every == 0: + self.gt.requires_grad = True + real_pred = self.net_d(self.gt) + l_d_r1 = r1_penalty(real_pred, self.gt) + l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]) + loss_dict['l_d_r1'] = l_d_r1.detach().mean() + l_d_r1.backward() + + self.optimizer_d.step() + + # optimize facial component discriminators + if self.use_facial_disc: + # left eye + fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach()) + real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt) + l_d_left_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_left_eye'] = l_d_left_eye + l_d_left_eye.backward() + # right eye + fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach()) + real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt) + l_d_right_eye = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_right_eye'] = l_d_right_eye + l_d_right_eye.backward() + # mouth + fake_d_pred, _ = self.net_d_mouth(self.mouths.detach()) + real_d_pred, _ = self.net_d_mouth(self.mouths_gt) + l_d_mouth = self.cri_component( + real_d_pred, True, is_disc=True) + self.cri_gan( + fake_d_pred, False, is_disc=True) + loss_dict['l_d_mouth'] = l_d_mouth + l_d_mouth.backward() + + self.optimizer_d_left_eye.step() + self.optimizer_d_right_eye.step() + self.optimizer_d_mouth.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _ = self.net_g_ema(self.lq) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _ = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + use_pbar = self.opt['val'].get('pbar', False) + + if with_metrics: + if not hasattr(self, 'metric_results'): # only execute in the first run + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + # initialize the best metric results for each dataset_name (supporting multiple validation datasets) + self._initialize_best_metric_results(dataset_name) + # zero self.metric_results + self.metric_results = {metric: 0 for metric in self.metric_results} + + metric_data = dict() + if use_pbar: + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1)) + metric_data['img'] = sr_img + if hasattr(self, 'gt'): + gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1)) + metric_data['img2'] = gt_img + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], 
dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + self.metric_results[name] += calculate_metric(metric_data, opt_) + if use_pbar: + pbar.update(1) + pbar.set_description(f'Test {img_name}') + if use_pbar: + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + # update the best metric result + self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}' + if hasattr(self, 'best_metric_results'): + log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' + f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') + log_str += '\n' + + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) + + def save(self, epoch, current_iter): + # save net_g and net_d + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + self.save_network(self.net_d, 'net_d', current_iter) + # save component discriminators + if self.use_facial_disc: + self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter) + self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter) + self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter) + # save training state + self.save_training_state(epoch, current_iter) diff --git a/gfpgan/gfpgan/train.py b/gfpgan/gfpgan/train.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5f1f909ae15a8d830ef65dcb43436d4f4ee7ae --- /dev/null +++ b/gfpgan/gfpgan/train.py @@ -0,0 +1,11 @@ +# flake8: noqa +import os.path as osp +from basicsr.train import train_pipeline + +import gfpgan.archs +import gfpgan.data +import gfpgan.models + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/gfpgan/gfpgan/utils.py b/gfpgan/gfpgan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..35658c5284330a6dba3e6589647c395c4348d17d --- /dev/null +++ b/gfpgan/gfpgan/utils.py @@ -0,0 +1,144 @@ +import cv2 +import os +import torch +from basicsr.utils import img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url +from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from torchvision.transforms.functional import normalize + +from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear +from gfpgan.archs.gfpganv1_arch import GFPGANv1 +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class GFPGANer(): + """Helper for restoration with GFPGAN. + + It will detect and crop faces, and then resize the faces to 512x512. + GFPGAN is used to restored the resized faces. + The background is upsampled with the bg_upsampler. 
+ Finally, the faces will be pasted back to the upsample background image. + + Args: + model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). + upscale (float): The upscale of the final output. Default: 2. + arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + bg_upsampler (nn.Module): The upsampler for the background. Default: None. + """ + + def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None): + self.upscale = upscale + self.bg_upsampler = bg_upsampler + + # initialize model + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + # initialize the GFP-GAN + if arch == 'clean': + self.gfpgan = GFPGANv1Clean( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'bilinear': + self.gfpgan = GFPGANBilinear( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'original': + self.gfpgan = GFPGANv1( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=True, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + # initialize face helper + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=self.device) + + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) + loadnet = torch.load(model_path) + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + self.gfpgan.load_state_dict(loadnet[keyname], strict=True) + self.gfpgan.eval() + self.gfpgan = self.gfpgan.to(self.device) + + @torch.no_grad() + def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True): + self.face_helper.clean_all() + + if has_aligned: # the inputs are already aligned + img = cv2.resize(img, (512, 512)) + self.face_helper.cropped_faces = [img] + else: + self.face_helper.read_image(img) + # get face landmarks for each face + self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) + # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels + # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations. 
+ # align and warp each face + self.face_helper.align_warp_face() + + # face restoration + for cropped_face in self.face_helper.cropped_faces: + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) + + try: + output = self.gfpgan(cropped_face_t, return_rgb=False)[0] + # convert to image + restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) + except RuntimeError as error: + print(f'\tFailed inference for GFPGAN: {error}.') + restored_face = cropped_face + + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) + + if not has_aligned and paste_back: + # upsample the background + if self.bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] + else: + bg_img = None + + self.face_helper.get_inverse_affine(None) + # paste each restored face to the input image + restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) + return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img + else: + return self.face_helper.cropped_faces, self.face_helper.restored_faces, None diff --git a/gfpgan/gfpgan/weights/README.md b/gfpgan/gfpgan/weights/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4d7b7e642591ef88575d9e6c360a4d29e0cc1a4f --- /dev/null +++ b/gfpgan/gfpgan/weights/README.md @@ -0,0 +1,3 @@ +# Weights + +Put the downloaded weights to this folder. diff --git a/gfpgan/inference_gfpgan.py b/gfpgan/inference_gfpgan.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6efad2e6f306b48343c91407d5162dba8ec9fe --- /dev/null +++ b/gfpgan/inference_gfpgan.py @@ -0,0 +1,158 @@ +import argparse +import cv2 +import glob +import numpy as np +import os +import torch +from basicsr.utils import imwrite +from tqdm import tqdm + +from gfpgan import GFPGANer + + +def main(): + """Inference demo for GFPGAN (for users). + """ + parser = argparse.ArgumentParser() + parser.add_argument( + '-i', + '--input', + type=str, + default='inputs/whole_imgs', + help='Input image or folder. Default: inputs/whole_imgs') + parser.add_argument('-o', '--output', type=str, default='results', help='Output folder. Default: results') + # we use version to select models, which is more user-friendly + parser.add_argument( + '-v', '--version', type=str, default='1.3', help='GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3') + parser.add_argument( + '-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2') + + parser.add_argument( + '--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan') + parser.add_argument( + '--bg_tile', + type=int, + default=400, + help='Tile size for background sampler, 0 for no tile during testing. Default: 400') + parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') + parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face') + parser.add_argument('--aligned', action='store_true', help='Input are aligned faces') + parser.add_argument('--save_faces', default=False, help='Save the restored faces') + parser.add_argument( + '--ext', + type=str, + default='auto', + help='Image extension. 
Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto') + args = parser.parse_args() + + # ------------------------ input & output ------------------------ + if args.input.endswith('/'): + args.input = args.input[:-1] + if os.path.isfile(args.input): + img_list = [args.input] + else: + img_list = sorted(glob.glob(os.path.join(args.input, '*'))) + + os.makedirs(args.output, exist_ok=True) + + # ------------------------ set up background upsampler ------------------------ + if args.bg_upsampler == 'realesrgan': + if not torch.cuda.is_available(): # CPU + import warnings + warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. ' + 'If you really want to use it, please modify the corresponding codes.') + bg_upsampler = None + else: + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) + bg_upsampler = RealESRGANer( + scale=2, + model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + model=model, + tile=args.bg_tile, + tile_pad=10, + pre_pad=0, + half=True) # need to set False in CPU mode + else: + bg_upsampler = None + + # ------------------------ set up GFPGAN restorer ------------------------ + if args.version == '1': + arch = 'original' + channel_multiplier = 1 + model_name = 'GFPGANv1' + elif args.version == '1.2': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANCleanv1-NoCE-C2' + elif args.version == '1.3': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANv1.3' + else: + raise ValueError(f'Wrong model version {args.version}.') + + # determine model paths + model_path = os.path.join('experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + model_path = os.path.join('realesrgan/weights', model_name + '.pth') + if not os.path.isfile(model_path): + raise ValueError(f'Model {model_name} does not exist.') + + restorer = GFPGANer( + model_path=model_path, + upscale=args.upscale, + arch=arch, + channel_multiplier=channel_multiplier, + bg_upsampler=bg_upsampler) + + # ------------------------ restore ------------------------ + for img_path in tqdm(img_list): + # read image + img_name = os.path.basename(img_path) + print(f'Processing {img_name} ...') + basename, ext = os.path.splitext(img_name) + input_img = cv2.imread(img_path, cv2.IMREAD_COLOR) + + # restore faces and background if necessary + cropped_faces, restored_faces, restored_img = restorer.enhance( + input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=True) + + # save faces + if args.save_faces: + for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)): + # save cropped face + save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png') + imwrite(cropped_face, save_crop_path) + # save restored face + if args.suffix is not None: + save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png' + else: + save_face_name = f'{basename}_{idx:02d}.png' + save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name) + imwrite(restored_face, save_restore_path) + # save comparison image + cmp_img = np.concatenate((cropped_face, restored_face), axis=1) + imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png')) + + # save restored img + if restored_img is not None: + if args.ext == 'auto': +
extension = ext[1:] + else: + extension = args.ext + + if args.suffix is not None: + save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}') + else: + save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}') + imwrite(restored_img, save_restore_path) + + print(f'Results are in the [{args.output}] folder.') + + +if __name__ == '__main__': + main() diff --git a/gfpgan/options/train_gfpgan_v1.yml b/gfpgan/options/train_gfpgan_v1.yml new file mode 100644 index 0000000000000000000000000000000000000000..aa5212a81de362daaef306e203f03cc665186d47 --- /dev/null +++ b/gfpgan/options/train_gfpgan_v1.yml @@ -0,0 +1,216 @@ +# general settings +name: train_GFPGANv1_512 +model_type: GFPGANModel +num_gpu: auto # officially, we use 4 GPUs +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: FFHQ + type: FFHQDegradationDataset + # dataroot_gt: datasets/ffhq/ffhq_512.lmdb + dataroot_gt: datasets/ffhq/ffhq_512 + io_backend: + # type: lmdb + type: disk + + use_hflip: true + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + out_size: 512 + + blur_kernel_size: 41 + kernel_list: ['iso', 'aniso'] + kernel_prob: [0.5, 0.5] + blur_sigma: [0.1, 10] + downsample_range: [0.8, 8] + noise_range: [0, 20] + jpeg_range: [60, 100] + + # color jitter and gray + color_jitter_prob: 0.3 + color_jitter_shift: 20 + color_jitter_pt_prob: 0.3 + gray_prob: 0.01 + + # If you do not want colorization, please set + # color_jitter_prob: ~ + # color_jitter_pt_prob: ~ + # gray_prob: 0.01 + # gt_gray: True + + crop_components: true + component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth + eye_enlarge_ratio: 1.4 + + # data loader + use_shuffle: true + num_worker_per_gpu: 6 + batch_size_per_gpu: 3 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + # Please modify accordingly to use your own validation + # Or comment the val block if do not need validation during training + name: validation + type: PairedImageDataset + dataroot_lq: datasets/faces/validation/input + dataroot_gt: datasets/faces/validation/reference + io_backend: + type: disk + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + scale: 1 + +# network structures +network_g: + type: GFPGANv1 + out_size: 512 + num_style_feat: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth + fix_decoder: true + num_mlp: 8 + lr_mlp: 0.01 + input_is_latent: true + different_w: true + narrow: 1 + sft_half: true + +network_d: + type: StyleGAN2Discriminator + out_size: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + +network_d_left_eye: + type: FacialComponentDiscriminator + +network_d_right_eye: + type: FacialComponentDiscriminator + +network_d_mouth: + type: FacialComponentDiscriminator + +network_identity: + type: ResNetArcFace + block: IRBlock + layers: [2, 2, 2, 2] + use_se: False + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: ~ + pretrain_network_d: ~ + pretrain_network_d_left_eye: ~ + pretrain_network_d_right_eye: ~ + pretrain_network_d_mouth: ~ + pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth + # resume + resume_state: ~ + ignore_resume_networks: ['network_identity'] + +# training settings +train: + optim_g: + type: Adam + lr: !!float 2e-3 + optim_d: + type: Adam + lr: !!float 2e-3 + optim_component: + type: Adam + lr: !!float 2e-3 + + scheduler: + type: MultiStepLR + 
milestones: [600000, 700000] + gamma: 0.5 + + total_iter: 800000 + warmup_iter: -1 # no warm up + + # losses + # pixel loss + pixel_opt: + type: L1Loss + loss_weight: !!float 1e-1 + reduction: mean + # L1 loss used in pyramid loss, component style loss and identity loss + L1_opt: + type: L1Loss + loss_weight: 1 + reduction: mean + + # image pyramid loss + pyramid_loss_weight: 1 + remove_pyramid_loss: 50000 + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1 + style_weight: 50 + range_norm: true + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: wgan_softplus + loss_weight: !!float 1e-1 + # r1 regularization for discriminator + r1_reg_weight: 10 + # facial component loss + gan_component_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1 + comp_style_weight: 200 + # identity loss + identity_weight: 10 + + net_d_iters: 1 + net_d_init_iters: 0 + net_d_reg_every: 16 + +# validation settings +val: + val_freq: !!float 5e3 + save_img: true + + metrics: + psnr: # metric name + type: calculate_psnr + crop_border: 0 + test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 + +find_unused_parameters: true diff --git a/gfpgan/options/train_gfpgan_v1_simple.yml b/gfpgan/options/train_gfpgan_v1_simple.yml new file mode 100644 index 0000000000000000000000000000000000000000..3807575826a5e7ed97335f607c091c8a4039a213 --- /dev/null +++ b/gfpgan/options/train_gfpgan_v1_simple.yml @@ -0,0 +1,182 @@ +# general settings +name: train_GFPGANv1_512_simple +model_type: GFPGANModel +num_gpu: auto # officially, we use 4 GPUs +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: FFHQ + type: FFHQDegradationDataset + # dataroot_gt: datasets/ffhq/ffhq_512.lmdb + dataroot_gt: datasets/ffhq/ffhq_512 + io_backend: + # type: lmdb + type: disk + + use_hflip: true + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + out_size: 512 + + blur_kernel_size: 41 + kernel_list: ['iso', 'aniso'] + kernel_prob: [0.5, 0.5] + blur_sigma: [0.1, 10] + downsample_range: [0.8, 8] + noise_range: [0, 20] + jpeg_range: [60, 100] + + # color jitter and gray + color_jitter_prob: 0.3 + color_jitter_shift: 20 + color_jitter_pt_prob: 0.3 + gray_prob: 0.01 + + # If you do not want colorization, please set + # color_jitter_prob: ~ + # color_jitter_pt_prob: ~ + # gray_prob: 0.01 + # gt_gray: True + + # data loader + use_shuffle: true + num_worker_per_gpu: 6 + batch_size_per_gpu: 3 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + val: + # Please modify accordingly to use your own validation + # Or comment the val block if do not need validation during training + name: validation + type: PairedImageDataset + dataroot_lq: datasets/faces/validation/input + dataroot_gt: datasets/faces/validation/reference + io_backend: + type: disk + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + scale: 1 + +# network structures +network_g: + type: GFPGANv1 + out_size: 512 + num_style_feat: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth + fix_decoder: true + num_mlp: 8 + 
lr_mlp: 0.01 + input_is_latent: true + different_w: true + narrow: 1 + sft_half: true + +network_d: + type: StyleGAN2Discriminator + out_size: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: ~ + pretrain_network_d: ~ + resume_state: ~ + +# training settings +train: + optim_g: + type: Adam + lr: !!float 2e-3 + optim_d: + type: Adam + lr: !!float 2e-3 + optim_component: + type: Adam + lr: !!float 2e-3 + + scheduler: + type: MultiStepLR + milestones: [600000, 700000] + gamma: 0.5 + + total_iter: 800000 + warmup_iter: -1 # no warm up + + # losses + # pixel loss + pixel_opt: + type: L1Loss + loss_weight: !!float 1e-1 + reduction: mean + # L1 loss used in pyramid loss, component style loss and identity loss + L1_opt: + type: L1Loss + loss_weight: 1 + reduction: mean + + # image pyramid loss + pyramid_loss_weight: 1 + remove_pyramid_loss: 50000 + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1 + style_weight: 50 + range_norm: true + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: wgan_softplus + loss_weight: !!float 1e-1 + # r1 regularization for discriminator + r1_reg_weight: 10 + + net_d_iters: 1 + net_d_init_iters: 0 + net_d_reg_every: 16 + +# validation settings +val: + val_freq: !!float 5e3 + save_img: true + + metrics: + psnr: # metric name + type: calculate_psnr + crop_border: 0 + test_y_channel: false + +# logging settings +logger: + print_freq: 100 + save_checkpoint_freq: !!float 5e3 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29500 + +find_unused_parameters: true diff --git a/gfpgan/requirements.txt b/gfpgan/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3edf9e0352b05a48a6ebeb7c85f10a934e672fe2 --- /dev/null +++ b/gfpgan/requirements.txt @@ -0,0 +1,12 @@ +basicsr>=1.3.4.0 +facexlib>=0.2.3 +lmdb +numpy<1.21 # numba requires numpy<1.21,>=1.17 +opencv-python +pyyaml +scipy +tb-nightly +torch>=1.7 +torchvision +tqdm +yapf diff --git a/gfpgan/scripts/convert_gfpganv_to_clean.py b/gfpgan/scripts/convert_gfpganv_to_clean.py new file mode 100644 index 0000000000000000000000000000000000000000..8fdccb6195c29e78cec2ac8dcc6f9ccb604e35ca --- /dev/null +++ b/gfpgan/scripts/convert_gfpganv_to_clean.py @@ -0,0 +1,164 @@ +import argparse +import math +import torch + +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + + +def modify_checkpoint(checkpoint_bilinear, checkpoint_clean): + for ori_k, ori_v in checkpoint_bilinear.items(): + if 'stylegan_decoder' in ori_k: + if 'style_mlp' in ori_k: # style_mlp_layers + lr_mul = 0.01 + prefix, name, idx, var = ori_k.split('.') + idx = (int(idx) * 2) - 1 + crt_k = f'{prefix}.{name}.{idx}.{var}' + if var == 'weight': + _, c_in = ori_v.size() + scale = (1 / math.sqrt(c_in)) * lr_mul + crt_v = ori_v * scale * 2**0.5 + else: + crt_v = ori_v * lr_mul * 2**0.5 + checkpoint_clean[crt_k] = crt_v + elif 'modulation' in ori_k: # modulation in StyleConv + lr_mul = 1 + crt_k = ori_k + var = ori_k.split('.')[-1] + if var == 'weight': + _, c_in = ori_v.size() + scale = (1 / math.sqrt(c_in)) * lr_mul + crt_v = ori_v * scale + else: + crt_v = ori_v * lr_mul + checkpoint_clean[crt_k] = crt_v + elif 'style_conv' in ori_k: 
+ # StyleConv in style_conv1 and style_convs + if 'activate' in ori_k: # FusedLeakyReLU + # eg. style_conv1.activate.bias + # eg. style_convs.13.activate.bias + split_rlt = ori_k.split('.') + if len(split_rlt) == 4: + prefix, name, _, var = split_rlt + crt_k = f'{prefix}.{name}.{var}' + elif len(split_rlt) == 5: + prefix, name, idx, _, var = split_rlt + crt_k = f'{prefix}.{name}.{idx}.{var}' + crt_v = ori_v * 2**0.5 # 2**0.5 used in FusedLeakyReLU + c = crt_v.size(0) + checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1) + elif 'modulated_conv' in ori_k: + # eg. style_conv1.modulated_conv.weight + # eg. style_convs.13.modulated_conv.weight + _, c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * scale + elif 'weight' in ori_k: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif 'to_rgb' in ori_k: # StyleConv in to_rgb1 and to_rgbs + if 'modulated_conv' in ori_k: + # eg. to_rgb1.modulated_conv.weight + # eg. to_rgbs.5.modulated_conv.weight + _, c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * scale + else: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v + else: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v + # end of 'stylegan_decoder' + elif 'conv_body_first' in ori_k or 'final_conv' in ori_k: + # key name + name, _, var = ori_k.split('.') + crt_k = f'{name}.{var}' + # weight and bias + if var == 'weight': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale * 2**0.5 + else: + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif 'conv_body' in ori_k: + if 'conv_body_up' in ori_k: + ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight') + ori_k = ori_k.replace('skip.weight', 'skip.1.weight') + name1, idx1, name2, _, var = ori_k.split('.') + crt_k = f'{name1}.{idx1}.{name2}.{var}' + if name2 == 'skip': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale / 2**0.5 + else: + if var == 'weight': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + if 'conv1' in ori_k: + checkpoint_clean[crt_k] *= 2**0.5 + elif 'toRGB' in ori_k: + crt_k = ori_k + if 'weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + elif 'final_linear' in ori_k: + crt_k = ori_k + if 'weight' in ori_k: + _, c_in = ori_v.size() + scale = 1 / math.sqrt(c_in) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + elif 'condition' in ori_k: + crt_k = ori_k + if '0.weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale * 2**0.5 + elif '0.bias' in ori_k: + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif '2.weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + elif '2.bias' in ori_k: + checkpoint_clean[crt_k] = ori_v + + return checkpoint_clean + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--ori_path', type=str, help='Path to the original model') + parser.add_argument('--narrow', type=float, default=1) + parser.add_argument('--channel_multiplier', type=float, default=2) 
+ parser.add_argument('--save_path', type=str) + args = parser.parse_args() + + ori_ckpt = torch.load(args.ori_path)['params_ema'] + + net = GFPGANv1Clean( + 512, + num_style_feat=512, + channel_multiplier=args.channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + # for stylegan decoder + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=args.narrow, + sft_half=True) + crt_ckpt = net.state_dict() + + crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt) + print(f'Save to {args.save_path}.') + torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False) diff --git a/gfpgan/scripts/parse_landmark.py b/gfpgan/scripts/parse_landmark.py new file mode 100644 index 0000000000000000000000000000000000000000..74e2ff9e130ad4f2395c9666dca3ba78526d7a8a --- /dev/null +++ b/gfpgan/scripts/parse_landmark.py @@ -0,0 +1,85 @@ +import cv2 +import json +import numpy as np +import os +import torch +from basicsr.utils import FileClient, imfrombytes +from collections import OrderedDict + +# ---------------------------- This script is used to parse facial landmarks ------------------------------------- # +# Configurations +save_img = False +scale = 0.5 # 0.5 for official FFHQ (512x512), 1 for others +enlarge_ratio = 1.4 # only for eyes +json_path = 'ffhq-dataset-v2.json' +face_path = 'datasets/ffhq/ffhq_512.lmdb' +save_path = './FFHQ_eye_mouth_landmarks_512.pth' + +print('Load JSON metadata...') +# use the official json file in FFHQ dataset +with open(json_path, 'rb') as f: + json_data = json.load(f, object_pairs_hook=OrderedDict) + +print('Open LMDB file...') +# read ffhq images +file_client = FileClient('lmdb', db_paths=face_path) +with open(os.path.join(face_path, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + +save_dict = {} + +for item_idx, item in enumerate(json_data.values()): + print(f'\r{item_idx} / {len(json_data)}, {item["image"]["file_path"]} ', end='', flush=True) + + # parse landmarks + lm = np.array(item['image']['face_landmarks']) + lm = lm * scale + + item_dict = {} + # get image + if save_img: + img_bytes = file_client.get(paths[item_idx]) + img = imfrombytes(img_bytes, float32=True) + + # get landmarks for each component + map_left_eye = list(range(36, 42)) + map_right_eye = list(range(42, 48)) + map_mouth = list(range(48, 68)) + + # eye_left + mean_left_eye = np.mean(lm[map_left_eye], 0) # (x, y) + half_len_left_eye = np.max((np.max(np.max(lm[map_left_eye], 0) - np.min(lm[map_left_eye], 0)) / 2, 16)) + item_dict['left_eye'] = [mean_left_eye[0], mean_left_eye[1], half_len_left_eye] + # mean_left_eye[0] = 512 - mean_left_eye[0] # for testing flip + half_len_left_eye *= enlarge_ratio + loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, mean_left_eye + half_len_left_eye)).astype(int) + if save_img: + eye_left_img = img[loc_left_eye[1]:loc_left_eye[3], loc_left_eye[0]:loc_left_eye[2], :] + cv2.imwrite(f'tmp/{item_idx:08d}_eye_left.png', eye_left_img * 255) + + # eye_right + mean_right_eye = np.mean(lm[map_right_eye], 0) + half_len_right_eye = np.max((np.max(np.max(lm[map_right_eye], 0) - np.min(lm[map_right_eye], 0)) / 2, 16)) + item_dict['right_eye'] = [mean_right_eye[0], mean_right_eye[1], half_len_right_eye] + # mean_right_eye[0] = 512 - mean_right_eye[0] # # for testing flip + half_len_right_eye *= enlarge_ratio + loc_right_eye = np.hstack( + (mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(int) + if save_img: + eye_right_img = 
img[loc_right_eye[1]:loc_right_eye[3], loc_right_eye[0]:loc_right_eye[2], :] + cv2.imwrite(f'tmp/{item_idx:08d}_eye_right.png', eye_right_img * 255) + + # mouth + mean_mouth = np.mean(lm[map_mouth], 0) + half_len_mouth = np.max((np.max(np.max(lm[map_mouth], 0) - np.min(lm[map_mouth], 0)) / 2, 16)) + item_dict['mouth'] = [mean_mouth[0], mean_mouth[1], half_len_mouth] + # mean_mouth[0] = 512 - mean_mouth[0] # for testing flip + loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, mean_mouth + half_len_mouth)).astype(int) + if save_img: + mouth_img = img[loc_mouth[1]:loc_mouth[3], loc_mouth[0]:loc_mouth[2], :] + cv2.imwrite(f'tmp/{item_idx:08d}_mouth.png', mouth_img * 255) + + save_dict[f'{item_idx:08d}'] = item_dict + +print('Save...') +torch.save(save_dict, save_path) diff --git a/gfpgan/setup.cfg b/gfpgan/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..3d90d600476f24315855b73c777bd7571f42f954 --- /dev/null +++ b/gfpgan/setup.cfg @@ -0,0 +1,33 @@ +[flake8] +ignore = + # line break before binary operator (W503) + W503, + # line break after binary operator (W504) + W504, +max-line-length=120 + +[yapf] +based_on_style = pep8 +column_limit = 120 +blank_line_before_nested_class_or_def = true +split_before_expression_after_opening_paren = true + +[isort] +line_length = 120 +multi_line_output = 0 +known_standard_library = pkg_resources,setuptools +known_first_party = gfpgan +known_third_party = basicsr,cv2,facexlib,numpy,pytest,torch,torchvision,tqdm,yaml +no_lines_before = STDLIB,LOCALFOLDER +default_section = THIRDPARTY + +[codespell] +skip = .git,./docs/build +count = +quiet-level = 3 + +[aliases] +test=pytest + +[tool:pytest] +addopts=tests/ diff --git a/gfpgan/setup.py b/gfpgan/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..474e9188aa2dc5c19614921760ce4ad99bd19c13 --- /dev/null +++ b/gfpgan/setup.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import time + +version_file = 'gfpgan/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def get_requirements(filename='requirements.txt'): + here = 
os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(here, filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + write_version_py() + setup( + name='gfpgan', + version=get_version(), + description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration', + long_description=readme(), + long_description_content_type='text/markdown', + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan', + url='https://github.com/TencentARC/GFPGAN', + include_package_data=True, + packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='Apache License Version 2.0', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + zip_safe=False) diff --git a/gfpgan/tests/data/ffhq_gt.lmdb/data.mdb b/gfpgan/tests/data/ffhq_gt.lmdb/data.mdb new file mode 100644 index 0000000000000000000000000000000000000000..823e0a9dae90d0699777770760ff012155974290 Binary files /dev/null and b/gfpgan/tests/data/ffhq_gt.lmdb/data.mdb differ diff --git a/gfpgan/tests/data/ffhq_gt.lmdb/lock.mdb b/gfpgan/tests/data/ffhq_gt.lmdb/lock.mdb new file mode 100644 index 0000000000000000000000000000000000000000..c53d2e56457060392f18d1dc7ab6574b15f42794 Binary files /dev/null and b/gfpgan/tests/data/ffhq_gt.lmdb/lock.mdb differ diff --git a/gfpgan/tests/data/ffhq_gt.lmdb/meta_info.txt b/gfpgan/tests/data/ffhq_gt.lmdb/meta_info.txt new file mode 100644 index 0000000000000000000000000000000000000000..8f18d95c03214990dbfd7e6ab520eb7b337038f2 --- /dev/null +++ b/gfpgan/tests/data/ffhq_gt.lmdb/meta_info.txt @@ -0,0 +1 @@ +00000000.png (512,512,3) 1 diff --git a/gfpgan/tests/data/gt/00000000.png b/gfpgan/tests/data/gt/00000000.png new file mode 100644 index 0000000000000000000000000000000000000000..33425aad207003300a8df43a3fe78dde492c552e Binary files /dev/null and b/gfpgan/tests/data/gt/00000000.png differ diff --git a/gfpgan/tests/data/test_eye_mouth_landmarks.pth b/gfpgan/tests/data/test_eye_mouth_landmarks.pth new file mode 100644 index 0000000000000000000000000000000000000000..a27f35286fecf9bf098033a57698690c0d3e8f8d --- /dev/null +++ b/gfpgan/tests/data/test_eye_mouth_landmarks.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:131583fca2cc346652f8754eb3c5a0bdeda808686039ff10ead7a26254b72358 +size 943 diff --git a/gfpgan/tests/data/test_ffhq_degradation_dataset.yml b/gfpgan/tests/data/test_ffhq_degradation_dataset.yml new file mode 100644 index 0000000000000000000000000000000000000000..df50c4bc5ca7f019cc8c47e1e39cd5709137fbee --- /dev/null +++ b/gfpgan/tests/data/test_ffhq_degradation_dataset.yml @@ -0,0 +1,24 @@ +name: UnitTest +type: FFHQDegradationDataset +dataroot_gt: tests/data/gt +io_backend: + type: disk + +use_hflip: true +mean: [0.5, 0.5, 0.5] +std: [0.5, 0.5, 0.5] +out_size: 512 + +blur_kernel_size: 41 +kernel_list: ['iso', 'aniso'] +kernel_prob: [0.5, 0.5] +blur_sigma: [0.1, 10] +downsample_range: [0.8, 8] +noise_range: [0, 20] +jpeg_range: [60, 100] + +# color jitter and gray +color_jitter_prob: 1 +color_jitter_shift: 20 
+color_jitter_pt_prob: 1 +gray_prob: 1 diff --git a/gfpgan/tests/data/test_gfpgan_model.yml b/gfpgan/tests/data/test_gfpgan_model.yml new file mode 100644 index 0000000000000000000000000000000000000000..bac650ef201a383b3ae8c7f3ca3c76b2dbf9cbf1 --- /dev/null +++ b/gfpgan/tests/data/test_gfpgan_model.yml @@ -0,0 +1,140 @@ +num_gpu: 1 +manual_seed: 0 +is_train: True +dist: False + +# network structures +network_g: + type: GFPGANv1 + out_size: 512 + num_style_feat: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + decoder_load_path: ~ + fix_decoder: true + num_mlp: 8 + lr_mlp: 0.01 + input_is_latent: true + different_w: true + narrow: 0.5 + sft_half: true + +network_d: + type: StyleGAN2Discriminator + out_size: 512 + channel_multiplier: 1 + resample_kernel: [1, 3, 3, 1] + +network_d_left_eye: + type: FacialComponentDiscriminator + +network_d_right_eye: + type: FacialComponentDiscriminator + +network_d_mouth: + type: FacialComponentDiscriminator + +network_identity: + type: ResNetArcFace + block: IRBlock + layers: [2, 2, 2, 2] + use_se: False + +# path +path: + pretrain_network_g: ~ + param_key_g: params_ema + strict_load_g: ~ + pretrain_network_d: ~ + pretrain_network_d_left_eye: ~ + pretrain_network_d_right_eye: ~ + pretrain_network_d_mouth: ~ + pretrain_network_identity: ~ + # resume + resume_state: ~ + ignore_resume_networks: ['network_identity'] + +# training settings +train: + optim_g: + type: Adam + lr: !!float 2e-3 + optim_d: + type: Adam + lr: !!float 2e-3 + optim_component: + type: Adam + lr: !!float 2e-3 + + scheduler: + type: MultiStepLR + milestones: [600000, 700000] + gamma: 0.5 + + total_iter: 800000 + warmup_iter: -1 # no warm up + + # losses + # pixel loss + pixel_opt: + type: L1Loss + loss_weight: !!float 1e-1 + reduction: mean + # L1 loss used in pyramid loss, component style loss and identity loss + L1_opt: + type: L1Loss + loss_weight: 1 + reduction: mean + + # image pyramid loss + pyramid_loss_weight: 1 + remove_pyramid_loss: 50000 + # perceptual loss (content and style losses) + perceptual_opt: + type: PerceptualLoss + layer_weights: + # before relu + 'conv1_2': 0.1 + 'conv2_2': 0.1 + 'conv3_4': 1 + 'conv4_4': 1 + 'conv5_4': 1 + vgg_type: vgg19 + use_input_norm: true + perceptual_weight: !!float 1 + style_weight: 50 + range_norm: true + criterion: l1 + # gan loss + gan_opt: + type: GANLoss + gan_type: wgan_softplus + loss_weight: !!float 1e-1 + # r1 regularization for discriminator + r1_reg_weight: 10 + # facial component loss + gan_component_opt: + type: GANLoss + gan_type: vanilla + real_label_val: 1.0 + fake_label_val: 0.0 + loss_weight: !!float 1 + comp_style_weight: 200 + # identity loss + identity_weight: 10 + + net_d_iters: 1 + net_d_init_iters: 0 + net_d_reg_every: 1 + +# validation settings +val: + val_freq: !!float 5e3 + save_img: True + use_pbar: True + + metrics: + psnr: # metric name + type: calculate_psnr + crop_border: 0 + test_y_channel: false diff --git a/gfpgan/tests/test_arcface_arch.py b/gfpgan/tests/test_arcface_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b28d33800ae78a354e078e14373d2ee159dc7b --- /dev/null +++ b/gfpgan/tests/test_arcface_arch.py @@ -0,0 +1,49 @@ +import torch + +from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace + + +def test_resnetarcface(): + """Test arch: ResNetArcFace.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval() + img = torch.rand((1, 1, 128, 128), 
dtype=torch.float32).cuda() + output = net(img) + assert output.shape == (1, 512) + + # -------------------- without SE block ----------------------- # + net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval() + output = net(img) + assert output.shape == (1, 512) + + +def test_basicblock(): + """Test the BasicBlock in arcface_arch""" + block = BasicBlock(1, 3, stride=1, downsample=None).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 3, 12, 12) + + # ----------------- use the downsmaple module--------------- # + downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda() + block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 3, 6, 6) + + +def test_bottleneck(): + """Test the Bottleneck in arcface_arch""" + block = Bottleneck(1, 1, stride=1, downsample=None).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 4, 12, 12) + + # ----------------- use the downsmaple module--------------- # + downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda() + block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda() + img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda() + output = block(img) + assert output.shape == (1, 4, 6, 6) diff --git a/gfpgan/tests/test_ffhq_degradation_dataset.py b/gfpgan/tests/test_ffhq_degradation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fa56c03fb8e23df26aa6ed8442a86b3c676eec78 --- /dev/null +++ b/gfpgan/tests/test_ffhq_degradation_dataset.py @@ -0,0 +1,96 @@ +import pytest +import yaml + +from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset + + +def test_ffhq_degradation_dataset(): + + with open('tests/data/test_ffhq_degradation_dataset.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + dataset = FFHQDegradationDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 1 # whether to read correct meta info + assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations + assert dataset.color_jitter_prob == 1 + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == 'tests/data/gt/00000000.png' + + # ------------------ test with probability = 0 -------------------- # + opt['color_jitter_prob'] = 0 + opt['color_jitter_pt_prob'] = 0 + opt['gray_prob'] = 0 + opt['io_backend'] = dict(type='disk') + dataset = FFHQDegradationDataset(opt) + assert dataset.io_backend_opt['type'] == 'disk' # io backend + assert len(dataset) == 1 # whether to read correct meta info + assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations + assert dataset.color_jitter_prob == 0 + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == 
'tests/data/gt/00000000.png' + + # ------------------ test lmdb backend -------------------- # + opt['dataroot_gt'] = 'tests/data/ffhq_gt.lmdb' + opt['io_backend'] = dict(type='lmdb') + + dataset = FFHQDegradationDataset(opt) + assert dataset.io_backend_opt['type'] == 'lmdb' # io backend + assert len(dataset) == 1 # whether to read correct meta info + assert dataset.kernel_list == ['iso', 'aniso'] # correct initialization the degradation configurations + assert dataset.color_jitter_prob == 0 + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == '00000000' + + # ------------------ test with crop_components -------------------- # + opt['crop_components'] = True + opt['component_path'] = 'tests/data/test_eye_mouth_landmarks.pth' + opt['eye_enlarge_ratio'] = 1.4 + opt['gt_gray'] = True + opt['io_backend'] = dict(type='lmdb') + + dataset = FFHQDegradationDataset(opt) + assert dataset.crop_components is True + + # test __getitem__ + result = dataset.__getitem__(0) + # check returned keys + expected_keys = ['gt', 'lq', 'gt_path', 'loc_left_eye', 'loc_right_eye', 'loc_mouth'] + assert set(expected_keys).issubset(set(result.keys())) + # check shape and contents + assert result['gt'].shape == (3, 512, 512) + assert result['lq'].shape == (3, 512, 512) + assert result['gt_path'] == '00000000' + assert result['loc_left_eye'].shape == (4, ) + assert result['loc_right_eye'].shape == (4, ) + assert result['loc_mouth'].shape == (4, ) + + # ------------------ lmdb backend should have paths ends with lmdb -------------------- # + with pytest.raises(ValueError): + opt['dataroot_gt'] = 'tests/data/gt' + opt['io_backend'] = dict(type='lmdb') + dataset = FFHQDegradationDataset(opt) diff --git a/gfpgan/tests/test_gfpgan_arch.py b/gfpgan/tests/test_gfpgan_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..cef14a435aa824a1b7c4baaf2d1fe0a2f6cc4441 --- /dev/null +++ b/gfpgan/tests/test_gfpgan_arch.py @@ -0,0 +1,203 @@ +import torch + +from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1, StyleGAN2GeneratorSFT +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean, StyleGAN2GeneratorCSFT + + +def test_stylegan2generatorsft(): + """Test arch: StyleGAN2GeneratorSFT.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = StyleGAN2GeneratorSFT( + out_size=32, + num_style_feat=512, + num_mlp=8, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + lr_mlp=0.01, + narrow=1, + sft_half=False).cuda().eval() + style = torch.rand((1, 512), dtype=torch.float32).cuda() + condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda() + condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda() + condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda() + conditions = [condition1, condition1, condition2, condition2, condition3, condition3] + output = net([style], conditions) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with return_latents ----------------------- # + output = net([style], conditions, return_latents=True) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 1 + # check latent + assert output[1][0].shape == (8, 512) + + # -------------------- with randomize_noise = False 
----------------------- # + output = net([style], conditions, randomize_noise=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with truncation = 0.5 and mixing----------------------- # + output = net([style, style], conditions, truncation=0.5, truncation_latent=style) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + +def test_gfpganv1(): + """Test arch: GFPGANv1.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = GFPGANv1( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=True).cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) + + # -------------------- with different_w = True ----------------------- # + net = GFPGANv1( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + resample_kernel=(1, 3, 3, 1), + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=True, + narrow=1, + sft_half=True).cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) + + +def test_facialcomponentdiscriminator(): + """Test arch: FacialComponentDiscriminator.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = FacialComponentDiscriminator().cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert len(output) == 2 + assert output[0].shape == (1, 1, 8, 8) + assert output[1] is None + + # -------------------- return intermediate features ----------------------- # + output = net(img, return_feats=True) + assert len(output) == 2 + assert output[0].shape == (1, 1, 8, 8) + assert len(output[1]) == 2 + assert output[1][0].shape == (1, 128, 16, 16) + assert output[1][1].shape == (1, 256, 8, 8) + + +def test_stylegan2generatorcsft(): + """Test arch: StyleGAN2GeneratorCSFT.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = StyleGAN2GeneratorCSFT( + out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=1, sft_half=False).cuda().eval() + style = torch.rand((1, 512), dtype=torch.float32).cuda() + condition1 = torch.rand((1, 512, 8, 8), dtype=torch.float32).cuda() + condition2 = torch.rand((1, 512, 16, 16), dtype=torch.float32).cuda() + condition3 = torch.rand((1, 512, 32, 32), dtype=torch.float32).cuda() + conditions = [condition1, condition1, condition2, condition2, condition3, condition3] + output = net([style], conditions) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with return_latents ----------------------- # + output = net([style], conditions, return_latents=True) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 1 + # check latent + assert 
output[1][0].shape == (8, 512) + + # -------------------- with randomize_noise = False ----------------------- # + output = net([style], conditions, randomize_noise=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with truncation = 0.5 and mixing----------------------- # + output = net([style, style], conditions, truncation=0.5, truncation_latent=style) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + +def test_gfpganv1clean(): + """Test arch: GFPGANv1Clean.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = GFPGANv1Clean( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=True).cuda().eval() + + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) + + # -------------------- with different_w = True ----------------------- # + net = GFPGANv1Clean( + out_size=32, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + input_is_latent=False, + different_w=True, + narrow=1, + sft_half=True).cuda().eval() + img = torch.rand((1, 3, 32, 32), dtype=torch.float32).cuda() + output = net(img) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 3 + # check out_rgbs for intermediate loss + assert output[1][0].shape == (1, 3, 8, 8) + assert output[1][1].shape == (1, 3, 16, 16) + assert output[1][2].shape == (1, 3, 32, 32) diff --git a/gfpgan/tests/test_gfpgan_model.py b/gfpgan/tests/test_gfpgan_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1408ddd7c909c7257fbcea79f8576231a40f9211 --- /dev/null +++ b/gfpgan/tests/test_gfpgan_model.py @@ -0,0 +1,132 @@ +import tempfile +import torch +import yaml +from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator +from basicsr.data.paired_image_dataset import PairedImageDataset +from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss + +from gfpgan.archs.arcface_arch import ResNetArcFace +from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1 +from gfpgan.models.gfpgan_model import GFPGANModel + + +def test_gfpgan_model(): + with open('tests/data/test_gfpgan_model.yml', mode='r') as f: + opt = yaml.load(f, Loader=yaml.FullLoader) + + # build model + model = GFPGANModel(opt) + # test attributes + assert model.__class__.__name__ == 'GFPGANModel' + assert isinstance(model.net_g, GFPGANv1) # generator + assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator + # facial component discriminators + assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator) + assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator) + assert isinstance(model.net_d_mouth, FacialComponentDiscriminator) + # identity network + assert isinstance(model.network_identity, ResNetArcFace) + # losses + assert isinstance(model.cri_pix, L1Loss) + assert isinstance(model.cri_perceptual, PerceptualLoss) + assert isinstance(model.cri_gan, GANLoss) + assert isinstance(model.cri_l1, L1Loss) + # optimizer + assert isinstance(model.optimizers[0], torch.optim.Adam) + assert 
isinstance(model.optimizers[1], torch.optim.Adam) + + # prepare data + gt = torch.rand((1, 3, 512, 512), dtype=torch.float32) + lq = torch.rand((1, 3, 512, 512), dtype=torch.float32) + loc_left_eye = torch.rand((1, 4), dtype=torch.float32) + loc_right_eye = torch.rand((1, 4), dtype=torch.float32) + loc_mouth = torch.rand((1, 4), dtype=torch.float32) + data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth) + model.feed_data(data) + # check data shape + assert model.lq.shape == (1, 3, 512, 512) + assert model.gt.shape == (1, 3, 512, 512) + assert model.loc_left_eyes.shape == (1, 4) + assert model.loc_right_eyes.shape == (1, 4) + assert model.loc_mouths.shape == (1, 4) + + # ----------------- test optimize_parameters -------------------- # + model.feed_data(data) + model.optimize_parameters(1) + assert model.output.shape == (1, 3, 512, 512) + assert isinstance(model.log_dict, dict) + # check returned keys + expected_keys = [ + 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth', + 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye', + 'l_d_right_eye', 'l_d_mouth' + ] + assert set(expected_keys).issubset(set(model.log_dict.keys())) + + # ----------------- remove pyramid_loss_weight-------------------- # + model.feed_data(data) + model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000 + assert model.output.shape == (1, 3, 512, 512) + assert isinstance(model.log_dict, dict) + # check returned keys + expected_keys = [ + 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth', + 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye', + 'l_d_right_eye', 'l_d_mouth' + ] + assert set(expected_keys).issubset(set(model.log_dict.keys())) + + # ----------------- test save -------------------- # + with tempfile.TemporaryDirectory() as tmpdir: + model.opt['path']['models'] = tmpdir + model.opt['path']['training_states'] = tmpdir + model.save(0, 1) + + # ----------------- test the test function -------------------- # + model.test() + assert model.output.shape == (1, 3, 512, 512) + # delete net_g_ema + model.__delattr__('net_g_ema') + model.test() + assert model.output.shape == (1, 3, 512, 512) + assert model.net_g.training is True # should back to training mode after testing + + # ----------------- test nondist_validation -------------------- # + # construct dataloader + dataset_opt = dict( + name='Demo', + dataroot_gt='tests/data/gt', + dataroot_lq='tests/data/gt', + io_backend=dict(type='disk'), + scale=4, + phase='val') + dataset = PairedImageDataset(dataset_opt) + dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + assert model.is_train is True + with tempfile.TemporaryDirectory() as tmpdir: + model.opt['path']['visualization'] = tmpdir + model.nondist_validation(dataloader, 1, None, save_img=True) + assert model.is_train is True + # check metric_results + assert 'psnr' in model.metric_results + assert isinstance(model.metric_results['psnr'], float) + + # validation + with tempfile.TemporaryDirectory() as tmpdir: + model.opt['is_train'] = False + model.opt['val']['suffix'] = 'test' + model.opt['path']['visualization'] = tmpdir + model.opt['val']['pbar'] = True + model.nondist_validation(dataloader, 1, None, save_img=True) + # check metric_results + assert 'psnr' in model.metric_results + assert 
isinstance(model.metric_results['psnr'], float) + + # if opt['val']['suffix'] is None + model.opt['val']['suffix'] = None + model.opt['name'] = 'demo' + model.opt['path']['visualization'] = tmpdir + model.nondist_validation(dataloader, 1, None, save_img=True) + # check metric_results + assert 'psnr' in model.metric_results + assert isinstance(model.metric_results['psnr'], float) diff --git a/gfpgan/tests/test_stylegan2_clean_arch.py b/gfpgan/tests/test_stylegan2_clean_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..78bb920e73ce28cfec9ea89a4339cc5b87981b47 --- /dev/null +++ b/gfpgan/tests/test_stylegan2_clean_arch.py @@ -0,0 +1,52 @@ +import torch + +from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean + + +def test_stylegan2generatorclean(): + """Test arch: StyleGAN2GeneratorClean.""" + + # model init and forward (gpu) + if torch.cuda.is_available(): + net = StyleGAN2GeneratorClean( + out_size=32, num_style_feat=512, num_mlp=8, channel_multiplier=1, narrow=0.5).cuda().eval() + style = torch.rand((1, 512), dtype=torch.float32).cuda() + output = net([style], input_is_latent=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with return_latents ----------------------- # + output = net([style], input_is_latent=True, return_latents=True) + assert output[0].shape == (1, 3, 32, 32) + assert len(output[1]) == 1 + # check latent + assert output[1][0].shape == (8, 512) + + # -------------------- with randomize_noise = False ----------------------- # + output = net([style], randomize_noise=False) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # -------------------- with truncation = 0.5 and mixing----------------------- # + output = net([style, style], truncation=0.5, truncation_latent=style) + assert output[0].shape == (1, 3, 32, 32) + assert output[1] is None + + # ------------------ test make_noise ----------------------- # + out = net.make_noise() + assert len(out) == 7 + assert out[0].shape == (1, 1, 4, 4) + assert out[1].shape == (1, 1, 8, 8) + assert out[2].shape == (1, 1, 8, 8) + assert out[3].shape == (1, 1, 16, 16) + assert out[4].shape == (1, 1, 16, 16) + assert out[5].shape == (1, 1, 32, 32) + assert out[6].shape == (1, 1, 32, 32) + + # ------------------ test get_latent ----------------------- # + out = net.get_latent(style) + assert out.shape == (1, 512) + + # ------------------ test mean_latent ----------------------- # + out = net.mean_latent(2) + assert out.shape == (1, 512) diff --git a/gfpgan/tests/test_utils.py b/gfpgan/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a963b3269dea05f9b7ec6c3db016e9a579c92fc8 --- /dev/null +++ b/gfpgan/tests/test_utils.py @@ -0,0 +1,43 @@ +import cv2 +from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +from gfpgan.archs.gfpganv1_arch import GFPGANv1 +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean +from gfpgan.utils import GFPGANer + + +def test_gfpganer(): + # initialize with the clean model + restorer = GFPGANer( + model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth', + upscale=2, + arch='clean', + channel_multiplier=2, + bg_upsampler=None) + # test attribute + assert isinstance(restorer.gfpgan, GFPGANv1Clean) + assert isinstance(restorer.face_helper, FaceRestoreHelper) + + # initialize with the original model + restorer = GFPGANer( + model_path='experiments/pretrained_models/GFPGANv1.pth', + upscale=2, + arch='original', + 
channel_multiplier=1, + bg_upsampler=None) + # test attribute + assert isinstance(restorer.gfpgan, GFPGANv1) + assert isinstance(restorer.face_helper, FaceRestoreHelper) + + # ------------------ test enhance ---------------- # + img = cv2.imread('tests/data/gt/00000000.png', cv2.IMREAD_COLOR) + result = restorer.enhance(img, has_aligned=False, paste_back=True) + assert result[0][0].shape == (512, 512, 3) + assert result[1][0].shape == (512, 512, 3) + assert result[2].shape == (1024, 1024, 3) + + # with has_aligned=True + result = restorer.enhance(img, has_aligned=True, paste_back=False) + assert result[0][0].shape == (512, 512, 3) + assert result[1][0].shape == (512, 512, 3) + assert result[2] is None diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..b9429c145e2af585d085c1f7e0ff064866aadc01 --- /dev/null +++ b/main.py @@ -0,0 +1,26 @@ +import os +from argparse import Namespace +basePath = os.path.dirname(os.path.realpath(__file__)) +wav2lipFolderName = 'wav2lip' +gfpganFolderName = 'gfpgan' +wav2lipPath = os.path.join(basePath, wav2lipFolderName) +gfpganPath = os.path.join(basePath, gfpganFolderName) + +outputPath = basePath + '/outputs' +inputAudioPath = basePath + '/inputs/kim_audio.mp3' +inputVideoPath = basePath + '/inputs/kimk_7s_raw.mp4' +lipSyncedOutputPath = basePath + '/outputs/result.mp4' +if not os.path.exists(outputPath): + os.makedirs(outputPath) + +if __name__ == '__main__': + args = Namespace(checkpoint_path='checkpoints/wav2lip.pth', + face=inputVideoPath, + audio=inputAudioPath, + outfile='results/result_voice.mp4', + static=False, + fps=25.0, pads=[0, 10, 0, 0], face_det_batch_size=16, wav2lip_batch_size=128, resize_factor=1, + crop=[0, -1, 0, -1], box=[-1, -1, -1, -1], rotate=False, nosmooth=False, img_size=96) + + print(inputAudioPath, os.path.isfile(inputAudioPath)) + print(inputVideoPath, os.path.isfile(inputVideoPath)) diff --git a/references.txt b/references.txt new file mode 100644 index 0000000000000000000000000000000000000000..2aae30fb07c34e4c5c02fd1d672e15ef9779612c --- /dev/null +++ b/references.txt @@ -0,0 +1 @@ +https://www.youtube.com/watch?v=LQCQym6hVMo \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f96372e96f8264ea36bd18b3885f9a88b006e56a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +librosa==0.10.0 +numpy<1.24.1 +opencv-contrib-python>=4.2.0.34 +opencv-python==4.7.0.72 +torch>=1.7 +torchvision>=0.8.2 +tqdm==4.48 +numba==0.56.4 + +basicsr>=1.3.4.0 +facexlib>=0.2.3 +lmdb +pyyaml +scipy +tb-nightly +yapf +realesrgan + +ffmpeg +gradio \ No newline at end of file diff --git a/wav2lip/.gitignore b/wav2lip/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..56f02807a55be5c6a9c2c9cf325963dd71ea535f --- /dev/null +++ b/wav2lip/.gitignore @@ -0,0 +1,16 @@ +*.pkl +*.jpg +*.mp4 +*.pth +*.pyc +__pycache__ +*.h5 +*.avi +*.wav +filelists/*.txt +evaluation/test_filelists/lr*.txt +*.pyc +*.mkv +*.gif +*.webm +*.mp3 diff --git a/wav2lip/README.md b/wav2lip/README.md new file mode 100644 index 0000000000000000000000000000000000000000..35f69daff0ee31e44c7b251f360365b2710585fb --- /dev/null +++ b/wav2lip/README.md @@ -0,0 +1,152 @@ +# **Wav2Lip**: *Accurately Lip-syncing Videos In The Wild* + +For commercial requests, please contact us at radrabha.m@research.iiit.ac.in or prajwal.k@research.iiit.ac.in. We have an HD model ready that can be used commercially.
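Stepping back to `main.py` above: as committed it only builds the `args` Namespace and prints whether the input files exist; nothing actually invokes Wav2Lip. Below is a minimal sketch of how those arguments could drive the documented CLI. It assumes `wav2lip/inference.py` accepts the flags listed in the README that follows (`--checkpoint_path`, `--face`, `--audio`, `--outfile`) and that the checkpoint already sits under `wav2lip/checkpoints/`; the `run_wav2lip` helper is illustrative, not part of the repo.

```python
import subprocess

def run_wav2lip(args, wav2lip_path):
    """Shell out to wav2lip/inference.py with the paths collected in main.py's Namespace."""
    cmd = [
        "python", "inference.py",
        "--checkpoint_path", args.checkpoint_path,
        "--face", args.face,
        "--audio", args.audio,
        "--outfile", args.outfile,
    ]
    # run from the wav2lip folder so that relative paths such as checkpoints/ resolve there
    subprocess.run(cmd, cwd=wav2lip_path, check=True)
```

A per-frame restoration pass could then reuse the `GFPGANer.enhance` API exercised by `gfpgan/tests/test_utils.py` above on the frames of the lip-synced output.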
+ +This code is part of the paper: _A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild_ published at ACM Multimedia 2020. + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs2)](https://paperswithcode.com/sota/lip-sync-on-lrs2?p=a-lip-sync-expert-is-all-you-need-for-speech) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrs3)](https://paperswithcode.com/sota/lip-sync-on-lrs3?p=a-lip-sync-expert-is-all-you-need-for-speech) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-lip-sync-expert-is-all-you-need-for-speech/lip-sync-on-lrw)](https://paperswithcode.com/sota/lip-sync-on-lrw?p=a-lip-sync-expert-is-all-you-need-for-speech) + +|📑 Original Paper|📰 Project Page|🌀 Demo|⚡ Live Testing|📔 Colab Notebook +|:-:|:-:|:-:|:-:|:-:| +[Paper](http://arxiv.org/abs/2008.10010) | [Project Page](http://cvit.iiit.ac.in/research/projects/cvit-projects/a-lip-sync-expert-is-all-you-need-for-speech-to-lip-generation-in-the-wild/) | [Demo Video](https://youtu.be/0fXaDCZNOJc) | [Interactive Demo](https://bhaasha.iiit.ac.in/lipsync) | [Colab Notebook](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing) / [Updated Colab Notebook](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH) + + + +---------- +**Highlights** +---------- + - Weights of the visual quality disc have been updated in the README! + - Lip-sync videos to any target speech with high accuracy :100:. Try our [interactive demo](https://bhaasha.iiit.ac.in/lipsync). + - :sparkles: Works for any identity, voice, and language. Also works for CGI faces and synthetic voices. + - Complete training code, inference code, and pretrained models are available :boom: + - Or, quick-start with the Google Colab Notebook: [Link](https://colab.research.google.com/drive/1tZpDWXz49W6wDcTprANRGLo2D_EbD5J8?usp=sharing). Checkpoints and samples are available in a Google Drive [folder](https://drive.google.com/drive/folders/1I-0dNLfFOSFwrfqjNa-SXuwaURHE5K4k?usp=sharing) as well. There is also a [tutorial video](https://www.youtube.com/watch?v=Ic0TBhfuOrA) on this, courtesy of [What Make Art](https://www.youtube.com/channel/UCmGXH-jy0o2CuhqtpxbaQgA). Also, thanks to [Eyal Gruss](https://eyalgruss.com), there is a more accessible [Google Colab notebook](https://j.mp/wav2lip) with more useful features. A tutorial Colab notebook is present at this [link](https://colab.research.google.com/drive/1IjFW1cLevs6Ouyu4Yht4mnR4yeuMqO7Y#scrollTo=MH1m608OymLH). + - :fire: :fire: Several new, reliable evaluation benchmarks and metrics [[`evaluation/` folder of this repo]](https://github.com/Rudrabha/Wav2Lip/tree/master/evaluation) released. Instructions to calculate the metrics reported in the paper are also present. + +-------- +**Disclaimer** +-------- +All results from this open-source code or our [demo website](https://bhaasha.iiit.ac.in/lipsync) should be used for research/academic/personal purposes only. As the models are trained on the LRS2 dataset, any form of commercial use is strictly prohibited. For commercial requests, please contact us directly! + +Prerequisites +------------- +- `Python 3.6` +- ffmpeg: `sudo apt-get install ffmpeg` +- Install necessary packages using `pip install -r requirements.txt`.
Alternatively, instructions for using a Docker image are provided [here](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668). Have a look at [this comment](https://github.com/Rudrabha/Wav2Lip/issues/131#issuecomment-725478562) and comment on [the gist](https://gist.github.com/xenogenesi/e62d3d13dadbc164124c830e9c453668) if you encounter any issues. +- Face detection [pre-trained model](https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth) should be downloaded to `face_detection/detection/sfd/s3fd.pth`. Alternative [link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/prajwal_k_research_iiit_ac_in/EZsy6qWuivtDnANIG73iHjIBjMSoojcIV0NULXV-yiuiIg?e=qTasa8) if the above does not work. + +Getting the weights +---------- +| Model | Description | Link to the model | +| :-------------: | :---------------: | :---------------: | +| Wav2Lip | Highly accurate lip-sync | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/Eb3LEzbfuKlJiR600lQWRxgBIY27JZg80f7V9jtMfbNDaQ?e=TBFBVW) | +| Wav2Lip + GAN | Slightly inferior lip-sync, but better visual quality | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EdjI7bZlgApMqsVoEUUXpLsBxqXbn5z8VTmoxp55YNDcIA?e=n9ljGW) | +| Expert Discriminator | Weights of the expert discriminator | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EQRvmiZg-HRAjvI6zqN9eTEBP74KefynCwPWVmF57l-AYA?e=ZRPHKP) | +| Visual Quality Discriminator | Weights of the visual disc trained in a GAN setup | [Link](https://iiitaphyd-my.sharepoint.com/:u:/g/personal/radrabha_m_research_iiit_ac_in/EQVqH88dTm1HjlK11eNba5gBbn15WMS0B0EZbDBttqrqkg?e=ic0ljo) | + +Lip-syncing videos using the pre-trained models (Inference) +------- +You can lip-sync any video to any audio: +```bash +python inference.py --checkpoint_path <ckpt> --face <video.mp4> --audio <an-audio-source> +``` +The result is saved (by default) in `results/result_voice.mp4`. You can specify a different output path as an argument, similar to several other available options. The audio source can be any file supported by `FFMPEG` containing audio data: `*.wav`, `*.mp3` or even a video file, from which the code will automatically extract the audio. + +##### Tips for better results: +- Experiment with the `--pads` argument to adjust the detected face bounding box. This often leads to improved results. You might need to increase the bottom padding to include the chin region. E.g. `--pads 0 20 0 0`. +- If you see the mouth position dislocated or some weird artifacts such as two mouths, it can be because of over-smoothing the face detections. Use the `--nosmooth` argument and give it another try. +- Experiment with the `--resize_factor` argument to get a lower-resolution video. Why? The models are trained on faces that were at a lower resolution. You might get better, visually pleasing results for 720p videos than for 1080p videos (in many cases, the latter works well too). +- The Wav2Lip model without GAN usually needs more experimenting with the above two arguments to get the best results, and sometimes it can give you a better result as well. + +Preparing LRS2 for training +---------- +Our models are trained on LRS2. See [here](#training-on-datasets-other-than-lrs2) for a few suggestions regarding training on other datasets.
+ +##### LRS2 dataset folder structure + +``` +data_root (mvlrs_v1) +├── main, pretrain (we use only main folder in this work) +| ├── list of folders +| │ ├── five-digit numbered video IDs ending with (.mp4) +``` + +Place the LRS2 filelists (train, val, test) `.txt` files in the `filelists/` folder. + +##### Preprocess the dataset for fast training + +```bash +python preprocess.py --data_root data_root/main --preprocessed_root lrs2_preprocessed/ +``` +Additional options like `batch_size` and the number of GPUs to use in parallel can also be set. + +##### Preprocessed LRS2 folder structure +``` +preprocessed_root (lrs2_preprocessed) +├── list of folders +| ├── Folders with five-digit numbered video IDs +| │ ├── *.jpg +| │ ├── audio.wav +``` + +Train! +---------- +There are two major steps: (i) Train the expert lip-sync discriminator, (ii) Train the Wav2Lip model(s). + +##### Training the expert discriminator +You can download [the pre-trained weights](#getting-the-weights) if you want to skip this step. To train it: +```bash +python color_syncnet_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints> +``` +##### Training the Wav2Lip models +You can either train the model without the additional visual quality discriminator (< 1 day of training) or use the discriminator (~2 days). For the former, run: +```bash +python wav2lip_train.py --data_root lrs2_preprocessed/ --checkpoint_dir <folder_to_save_checkpoints> --syncnet_checkpoint_path <path_to_expert_disc_checkpoint> +``` + +To train with the visual quality discriminator, you should run `hq_wav2lip_train.py` instead. The arguments for both files are similar. In both cases, you can resume training as well. Look at `python wav2lip_train.py --help` for more details. You can also set additional less commonly-used hyper-parameters at the bottom of the `hparams.py` file. + +Training on datasets other than LRS2 +------------------------------------ +Training on other datasets might require modifications to the code. Please read the following before you raise an issue: + +- You might not get good results by training/fine-tuning on a few minutes of a single speaker. This is a separate research problem, to which we do not have a solution yet. Thus, we would most likely not be able to resolve your issue. +- You must train the expert discriminator for your own dataset before training Wav2Lip. +- If it is your own dataset downloaded from the web, it will in most cases need to be sync-corrected. +- Be mindful of the FPS of the videos in your dataset. Changes to FPS would need significant code changes. +- The expert discriminator's eval loss should go down to ~0.25 and the Wav2Lip eval sync loss should go down to ~0.2 to get good results. + +When raising an issue on this topic, please let us know that you are aware of all these points. + +We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model. + +Evaluation +---------- +Please check the `evaluation/` folder for the instructions. + +License and Citation +---------- +This repository can only be used for personal/research/non-commercial purposes. However, for commercial requests, please contact us directly at radrabha.m@research.iiit.ac.in or prajwal.k@research.iiit.ac.in. We have an HD model trained on a dataset allowing commercial usage. The size of the generated face will be 192 x 288 in our new model. Please cite the following paper if you use this repository: +``` +@inproceedings{10.1145/3394171.3413532, +author = {Prajwal, K R and Mukhopadhyay, Rudrabha and Namboodiri, Vinay P.
and Jawahar, C.V.}, +title = {A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild}, +year = {2020}, +isbn = {9781450379885}, +publisher = {Association for Computing Machinery}, +address = {New York, NY, USA}, +url = {https://doi.org/10.1145/3394171.3413532}, +doi = {10.1145/3394171.3413532}, +booktitle = {Proceedings of the 28th ACM International Conference on Multimedia}, +pages = {484–492}, +numpages = {9}, +keywords = {lip sync, talking face generation, video generation}, +location = {Seattle, WA, USA}, +series = {MM '20} +} +``` + + +Acknowledgements +---------- +Parts of the code structure are inspired by this [TTS repository](https://github.com/r9y9/deepvoice3_pytorch). We thank the author for this wonderful code. The code for Face Detection has been taken from the [face_alignment](https://github.com/1adrianb/face-alignment) repository. We thank the authors for releasing their code and models. We thank [zabique](https://github.com/zabique) for the tutorial Colab notebook. diff --git a/wav2lip/audio.py b/wav2lip/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..32ab5fabe8e505baa7431f7fb81e367aed1d0ac3 --- /dev/null +++ b/wav2lip/audio.py @@ -0,0 +1,137 @@ +import librosa +import librosa.filters +import numpy as np +import soundfile as sf +# import tensorflow as tf +from scipy import signal +from scipy.io import wavfile +from hparams import hparams as hp + +def load_wav(path, sr): + return librosa.core.load(path, sr=sr)[0] + +def save_wav(wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + +def save_wavenet_wav(wav, path, sr): + sf.write(path, wav, sr)  # librosa.output.write_wav was removed in librosa 0.8+ (requirements pin librosa==0.10.0) + +def preemphasis(wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + +def inv_preemphasis(wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + +def get_hop_size(): + hop_size = hp.hop_size + if hop_size is None: + assert hp.frame_shift_ms is not None + hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) + return hop_size + +def linearspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(np.abs(D)) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def melspectrogram(wav): + D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) + S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db + + if hp.signal_normalization: + return _normalize(S) + return S + +def _lws_processor(): + import lws + return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") + +def _stft(y): + if hp.use_lws: + return _lws_processor().stft(y).T + else: + return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) + +########################################################## +#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+def num_frames(length, fsize, fshift): + """Compute number of time frames of spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + +def pad_lr(x, fsize, fshift): + """Compute left and right padding + """ + M = num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r +########################################################## +#Librosa correct padding +def librosa_pad_lr(x, fsize, fshift): + return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] + +# Conversions +_mel_basis = None + +def _linear_to_mel(spectogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectogram) + +def _build_mel_basis(): + assert hp.fmax <= hp.sample_rate // 2 + return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, + fmin=hp.fmin, fmax=hp.fmax) + +def _amp_to_db(x): + min_level = np.exp(hp.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + +def _db_to_amp(x): + return np.power(10.0, (x) * 0.05) + +def _normalize(S): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, + -hp.max_abs_value, hp.max_abs_value) + else: + return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) + + assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 + if hp.symmetric_mels: + return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value + else: + return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) + +def _denormalize(D): + if hp.allow_clipping_in_normalization: + if hp.symmetric_mels: + return (((np.clip(D, -hp.max_abs_value, + hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + + hp.min_level_db) + else: + return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) + + if hp.symmetric_mels: + return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) + else: + return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) diff --git a/wav2lip/checkpoints/README.md b/wav2lip/checkpoints/README.md new file mode 100644 index 0000000000000000000000000000000000000000..80258ec8fb8e6fdce46f3d420bad25b58cd2ee12 --- /dev/null +++ b/wav2lip/checkpoints/README.md @@ -0,0 +1 @@ +Place all your checkpoints (.pth files) here. 
\ No newline at end of file diff --git a/wav2lip/color_syncnet_train.py b/wav2lip/color_syncnet_train.py new file mode 100644 index 0000000000000000000000000000000000000000..afa00544386cb9627f0d899476abbc82b37958ed --- /dev/null +++ b/wav2lip/color_syncnet_train.py @@ -0,0 +1,279 @@ +from os.path import dirname, join, basename, isfile +from tqdm import tqdm + +from models import SyncNet_color as SyncNet +import audio + +import torch +from torch import nn +from torch import optim +import torch.backends.cudnn as cudnn +from torch.utils import data as data_utils +import numpy as np + +from glob import glob + +import os, random, cv2, argparse +from hparams import hparams, get_image_list + +parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator') + +parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True) + +parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str) +parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str) + +args = parser.parse_args() + + +global_step = 0 +global_epoch = 0 +use_cuda = torch.cuda.is_available() +print('use_cuda: {}'.format(use_cuda)) + +syncnet_T = 5 +syncnet_mel_step_size = 16 + +class Dataset(object): + def __init__(self, split): + self.all_videos = get_image_list(args.data_root, split) + + def get_frame_id(self, frame): + return int(basename(frame).split('.')[0]) + + def get_window(self, start_frame): + start_id = self.get_frame_id(start_frame) + vidname = dirname(start_frame) + + window_fnames = [] + for frame_id in range(start_id, start_id + syncnet_T): + frame = join(vidname, '{}.jpg'.format(frame_id)) + if not isfile(frame): + return None + window_fnames.append(frame) + return window_fnames + + def crop_audio_window(self, spec, start_frame): + # num_frames = (T x hop_size * fps) / sample_rate + start_frame_num = self.get_frame_id(start_frame) + start_idx = int(80. * (start_frame_num / float(hparams.fps))) + + end_idx = start_idx + syncnet_mel_step_size + + return spec[start_idx : end_idx, :] + + + def __len__(self): + return len(self.all_videos) + + def __getitem__(self, idx): + while 1: + idx = random.randint(0, len(self.all_videos) - 1) + vidname = self.all_videos[idx] + + img_names = list(glob(join(vidname, '*.jpg'))) + if len(img_names) <= 3 * syncnet_T: + continue + img_name = random.choice(img_names) + wrong_img_name = random.choice(img_names) + while wrong_img_name == img_name: + wrong_img_name = random.choice(img_names) + + if random.choice([True, False]): + y = torch.ones(1).float() + chosen = img_name + else: + y = torch.zeros(1).float() + chosen = wrong_img_name + + window_fnames = self.get_window(chosen) + if window_fnames is None: + continue + + window = [] + all_read = True + for fname in window_fnames: + img = cv2.imread(fname) + if img is None: + all_read = False + break + try: + img = cv2.resize(img, (hparams.img_size, hparams.img_size)) + except Exception as e: + all_read = False + break + + window.append(img) + + if not all_read: continue + + try: + wavpath = join(vidname, "audio.wav") + wav = audio.load_wav(wavpath, hparams.sample_rate) + + orig_mel = audio.melspectrogram(wav).T + except Exception as e: + continue + + mel = self.crop_audio_window(orig_mel.copy(), img_name) + + if (mel.shape[0] != syncnet_mel_step_size): + continue + + # H x W x 3 * T + x = np.concatenate(window, axis=2) / 255. 
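+            # (H, W, 3*T) -> (3*T, H, W); the slice below keeps only the lower half of each face crop (the mouth region) for the sync discriminator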
+ x = x.transpose(2, 0, 1) + x = x[:, x.shape[1]//2:] + + x = torch.FloatTensor(x) + mel = torch.FloatTensor(mel.T).unsqueeze(0) + + return x, mel, y + +logloss = nn.BCELoss() +def cosine_loss(a, v, y): + d = nn.functional.cosine_similarity(a, v) + loss = logloss(d.unsqueeze(1), y) + + return loss + +def train(device, model, train_data_loader, test_data_loader, optimizer, + checkpoint_dir=None, checkpoint_interval=None, nepochs=None): + + global global_step, global_epoch + resumed_step = global_step + + while global_epoch < nepochs: + running_loss = 0. + prog_bar = tqdm(enumerate(train_data_loader)) + for step, (x, mel, y) in prog_bar: + model.train() + optimizer.zero_grad() + + # Transform data to CUDA device + x = x.to(device) + + mel = mel.to(device) + + a, v = model(mel, x) + y = y.to(device) + + loss = cosine_loss(a, v, y) + loss.backward() + optimizer.step() + + global_step += 1 + cur_session_steps = global_step - resumed_step + running_loss += loss.item() + + if global_step == 1 or global_step % checkpoint_interval == 0: + save_checkpoint( + model, optimizer, global_step, checkpoint_dir, global_epoch) + + if global_step % hparams.syncnet_eval_interval == 0: + with torch.no_grad(): + eval_model(test_data_loader, global_step, device, model, checkpoint_dir) + + prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1))) + + global_epoch += 1 + +def eval_model(test_data_loader, global_step, device, model, checkpoint_dir): + eval_steps = 1400 + print('Evaluating for {} steps'.format(eval_steps)) + losses = [] + while 1: + for step, (x, mel, y) in enumerate(test_data_loader): + + model.eval() + + # Transform data to CUDA device + x = x.to(device) + + mel = mel.to(device) + + a, v = model(mel, x) + y = y.to(device) + + loss = cosine_loss(a, v, y) + losses.append(loss.item()) + + if step > eval_steps: break + + averaged_loss = sum(losses) / len(losses) + print(averaged_loss) + + return + +def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch): + + checkpoint_path = join( + checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step)) + optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None + torch.save({ + "state_dict": model.state_dict(), + "optimizer": optimizer_state, + "global_step": step, + "global_epoch": epoch, + }, checkpoint_path) + print("Saved checkpoint:", checkpoint_path) + +def _load(checkpoint_path): + if use_cuda: + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + +def load_checkpoint(path, model, optimizer, reset_optimizer=False): + global global_step + global global_epoch + + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + model.load_state_dict(checkpoint["state_dict"]) + if not reset_optimizer: + optimizer_state = checkpoint["optimizer"] + if optimizer_state is not None: + print("Load optimizer state from {}".format(path)) + optimizer.load_state_dict(checkpoint["optimizer"]) + global_step = checkpoint["global_step"] + global_epoch = checkpoint["global_epoch"] + + return model + +if __name__ == "__main__": + checkpoint_dir = args.checkpoint_dir + checkpoint_path = args.checkpoint_path + + if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) + + # Dataset and Dataloader setup + train_dataset = Dataset('train') + test_dataset = Dataset('val') + + train_data_loader = data_utils.DataLoader( + train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True, + 
num_workers=hparams.num_workers) + + test_data_loader = data_utils.DataLoader( + test_dataset, batch_size=hparams.syncnet_batch_size, + num_workers=8) + + device = torch.device("cuda" if use_cuda else "cpu") + + # Model + model = SyncNet().to(device) + print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) + + optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], + lr=hparams.syncnet_lr) + + if checkpoint_path is not None: + load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False) + + train(device, model, train_data_loader, test_data_loader, optimizer, + checkpoint_dir=checkpoint_dir, + checkpoint_interval=hparams.syncnet_checkpoint_interval, + nepochs=hparams.nepochs) diff --git a/wav2lip/evaluation/README.md b/wav2lip/evaluation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..affebbc063571f576a05af4367c6ab6023299c02 --- /dev/null +++ b/wav2lip/evaluation/README.md @@ -0,0 +1,63 @@ +# Novel Evaluation Framework, new filelists, and using the LSE-D and LSE-C metric. + +Our paper also proposes a novel evaluation framework (Section 4). To evaluate on LRS2, LRS3, and LRW, the filelists are present in the `test_filelists` folder. Please use `gen_videos_from_filelist.py` script to generate the videos. After that, you can calculate the LSE-D and LSE-C scores using the instructions below. Please see [this thread](https://github.com/Rudrabha/Wav2Lip/issues/22#issuecomment-712825380) on how to calculate the FID scores. + +The videos of the ReSyncED benchmark for real-world evaluation will be released soon. + +### Steps to set-up the evaluation repository for LSE-D and LSE-C metric: +We use the pre-trained syncnet model available in this [repository](https://github.com/joonson/syncnet_python). + +* Clone the SyncNet repository. +``` +git clone https://github.com/joonson/syncnet_python.git +``` +* Follow the procedure given in the above linked [repository](https://github.com/joonson/syncnet_python) to download the pretrained models and set up the dependencies. + * **Note: Please install a separate virtual environment for the evaluation scripts. The versions used by Wav2Lip and the publicly released code of SyncNet is different and can cause version mis-match issues. To avoid this, we suggest the users to install a separate virtual environment for the evaluation scripts** +``` +cd syncnet_python +pip install -r requirements.txt +sh download_model.sh +``` +* The above step should ensure that all the dependencies required by the repository is installed and the pre-trained models are downloaded. + +### Running the evaluation scripts: +* Copy our evaluation scripts given in this folder to the cloned repository. +``` + cd Wav2Lip/evaluation/scores_LSE/ + cp *.py syncnet_python/ + cp *.sh syncnet_python/ +``` +**Note: We will release the test filelists for LRW, LRS2 and LRS3 shortly once we receive permission from the dataset creators. We will also release the Real World Dataset we have collected shortly.** + +* Our evaluation technique does not require ground-truth of any sorts. Given lip-synced videos we can directly calculate the scores from only the generated videos. Please store the generated videos (from our test sets or your own generated videos) in the following folder structure. +``` +video data root (Folder containing all videos) +├── All .mp4 files +``` +* Change the folder back to the cloned repository. 
+``` +cd syncnet_python +``` +* To run evaluation on the LRW, LRS2 and LRS3 test files, please run the following command: +``` +python calculate_scores_LRS.py --data_root /path/to/video/data/root --tmp_dir tmp_dir/ +``` + +* To run evaluation on the ReSynced dataset or your own generated videos, please run the following command: +``` +sh calculate_scores_real_videos.sh /path/to/video/data/root +``` +* The generated scores will be present in the all_scores.txt generated in the ```syncnet_python/``` folder + +# Evaluation of image quality using FID metric. +We use the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository for calculating the FID metrics. We dump all the frames in both ground-truth and generated videos and calculate the FID score. + + +# Opening issues related to evaluation scripts +* Please open the issues with the "Evaluation" label if you face any issues in the evaluation scripts. + +# Acknowledgements +Our evaluation pipeline in based on two existing repositories. LSE metrics are based on the [syncnet_python](https://github.com/joonson/syncnet_python) repository and the FID score is based on [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository. We thank the authors of both the repositories for releasing their wonderful code. + + + diff --git a/wav2lip/evaluation/gen_videos_from_filelist.py b/wav2lip/evaluation/gen_videos_from_filelist.py new file mode 100644 index 0000000000000000000000000000000000000000..bd666b93258d3da3143a63da742265ebeac2a8a3 --- /dev/null +++ b/wav2lip/evaluation/gen_videos_from_filelist.py @@ -0,0 +1,238 @@ +from os import listdir, path +import numpy as np +import scipy, cv2, os, sys, argparse +import dlib, json, subprocess +from tqdm import tqdm +from glob import glob +import torch + +sys.path.append('../') +import audio +import face_detection +from models import Wav2Lip + +parser = argparse.ArgumentParser(description='Code to generate results for test filelists') + +parser.add_argument('--filelist', type=str, + help='Filepath of filelist file to read', required=True) +parser.add_argument('--results_dir', type=str, help='Folder to save all results into', + required=True) +parser.add_argument('--data_root', type=str, required=True) +parser.add_argument('--checkpoint_path', type=str, + help='Name of saved checkpoint to load weights from', required=True) + +parser.add_argument('--pads', nargs='+', type=int, default=[0, 0, 0, 0], + help='Padding (top, bottom, left, right)') +parser.add_argument('--face_det_batch_size', type=int, + help='Single GPU batch size for face detection', default=64) +parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128) + +# parser.add_argument('--resize_factor', default=1, type=int) + +args = parser.parse_args() +args.img_size = 96 + +def get_smoothened_boxes(boxes, T): + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i : i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + +def face_detect(images): + batch_size = args.face_det_batch_size + + while 1: + predictions = [] + try: + for i in range(0, len(images), batch_size): + predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) + except RuntimeError: + if batch_size == 1: + raise RuntimeError('Image too big to run face detection on GPU') + batch_size //= 2 + args.face_det_batch_size = batch_size + print('Recovering from OOM error; New batch size: {}'.format(batch_size)) + continue + break + + results = [] 
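+	# expand each detected box by the user-specified --pads (top, bottom, left, right), clamp to the frame, then smooth the boxes over a 5-frame window to reduce jitter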
+ pady1, pady2, padx1, padx2 = args.pads + for rect, image in zip(predictions, images): + if rect is None: + raise ValueError('Face not detected!') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + boxes = get_smoothened_boxes(np.array(results), T=5) + results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + return results + +def datagen(frames, face_det_results, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + for i, m in enumerate(mels): + if i >= len(frames): raise ValueError('Equal or less lengths only') + + frame_to_save = frames[i].copy() + face, coords, valid_frame = face_det_results[i].copy() + if not valid_frame: + continue + + face = cv2.resize(face, (args.img_size, args.img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= args.wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + +fps = 25 +mel_step_size = 16 +mel_idx_multiplier = 80./fps +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Using {} for inference.'.format(device)) + +detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, + flip_input=False, device=device) + +def _load(checkpoint_path): + if device == 'cuda': + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + +def load_model(path): + model = Wav2Lip() + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + + model = model.to(device) + return model.eval() + +model = load_model(args.checkpoint_path) + +def main(): + assert args.data_root is not None + data_root = args.data_root + + if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir) + + with open(args.filelist, 'r') as filelist: + lines = filelist.readlines() + + for idx, line in enumerate(tqdm(lines)): + audio_src, video = line.strip().split() + + audio_src = os.path.join(data_root, audio_src) + '.mp4' + video = os.path.join(data_root, video) + '.mp4' + + command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav') + subprocess.call(command, shell=True) + temp_audio = '../temp/temp.wav' + + wav = audio.load_wav(temp_audio, 16000) + mel = audio.melspectrogram(wav) + if np.isnan(mel.reshape(-1)).sum() > 0: + continue + + mel_chunks = [] + i = 0 + while 1: + 
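+			# step through the mel spectrogram in windows of mel_step_size=16 frames; with mel_idx_multiplier = 80./fps (80 mel frames per second of audio), window i lines up with video frame i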
start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + break + mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) + i += 1 + + video_stream = cv2.VideoCapture(video) + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading or len(full_frames) > len(mel_chunks): + video_stream.release() + break + full_frames.append(frame) + + if len(full_frames) < len(mel_chunks): + continue + + full_frames = full_frames[:len(mel_chunks)] + + try: + face_det_results = face_detect(full_frames.copy()) + except ValueError as e: + continue + + batch_size = args.wav2lip_batch_size + gen = datagen(full_frames.copy(), face_det_results, mel_chunks) + + for i, (img_batch, mel_batch, frames, coords) in enumerate(gen): + if i == 0: + frame_h, frame_w = full_frames[0].shape[:-1] + out = cv2.VideoWriter('../temp/result.avi', + cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) + + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) + + with torch.no_grad(): + pred = model(mel_batch, img_batch) + + + pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. + + for pl, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1)) + f[y1:y2, x1:x2] = pl + out.write(f) + + out.release() + + vid = os.path.join(args.results_dir, '{}.mp4'.format(idx)) + + command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio, + '../temp/result.avi', vid) + subprocess.call(command, shell=True) + +if __name__ == '__main__': + main() diff --git a/wav2lip/evaluation/real_videos_inference.py b/wav2lip/evaluation/real_videos_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9fb15ef342bf03caf77802ddf5b887bab3fb34 --- /dev/null +++ b/wav2lip/evaluation/real_videos_inference.py @@ -0,0 +1,305 @@ +from os import listdir, path +import numpy as np +import scipy, cv2, os, sys, argparse +import dlib, json, subprocess +from tqdm import tqdm +from glob import glob +import torch + +sys.path.append('../') +import audio +import face_detection +from models import Wav2Lip + +parser = argparse.ArgumentParser(description='Code to generate results on ReSyncED evaluation set') + +parser.add_argument('--mode', type=str, + help='random | dubbed | tts', required=True) + +parser.add_argument('--filelist', type=str, + help='Filepath of filelist file to read', default=None) + +parser.add_argument('--results_dir', type=str, help='Folder to save all results into', + required=True) +parser.add_argument('--data_root', type=str, required=True) +parser.add_argument('--checkpoint_path', type=str, + help='Name of saved checkpoint to load weights from', required=True) +parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], + help='Padding (top, bottom, left, right)') + +parser.add_argument('--face_det_batch_size', type=int, + help='Single GPU batch size for face detection', default=16) + +parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128) +parser.add_argument('--face_res', help='Approximate resolution of the face at which to test', default=180) +parser.add_argument('--min_frame_res', help='Do not downsample further below this frame resolution', default=480) +parser.add_argument('--max_frame_res', help='Downsample to at least this frame resolution', default=720) +# 
parser.add_argument('--resize_factor', default=1, type=int) + +args = parser.parse_args() +args.img_size = 96 + +def get_smoothened_boxes(boxes, T): + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i : i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + +def rescale_frames(images): + rect = detector.get_detections_for_batch(np.array([images[0]]))[0] + if rect is None: + raise ValueError('Face not detected!') + h, w = images[0].shape[:-1] + + x1, y1, x2, y2 = rect + + face_size = max(np.abs(y1 - y2), np.abs(x1 - x2)) + + diff = np.abs(face_size - args.face_res) + for factor in range(2, 16): + downsampled_res = face_size // factor + if min(h//factor, w//factor) < args.min_frame_res: break + if np.abs(downsampled_res - args.face_res) >= diff: break + + factor -= 1 + if factor == 1: return images + + return [cv2.resize(im, (im.shape[1]//(factor), im.shape[0]//(factor))) for im in images] + + +def face_detect(images): + batch_size = args.face_det_batch_size + images = rescale_frames(images) + + while 1: + predictions = [] + try: + for i in range(0, len(images), batch_size): + predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) + except RuntimeError: + if batch_size == 1: + raise RuntimeError('Image too big to run face detection on GPU') + batch_size //= 2 + print('Recovering from OOM error; New batch size: {}'.format(batch_size)) + continue + break + + results = [] + pady1, pady2, padx1, padx2 = args.pads + for rect, image in zip(predictions, images): + if rect is None: + raise ValueError('Face not detected!') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + boxes = get_smoothened_boxes(np.array(results), T=5) + results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + return results, images + +def datagen(frames, face_det_results, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + for i, m in enumerate(mels): + if i >= len(frames): raise ValueError('Equal or less lengths only') + + frame_to_save = frames[i].copy() + face, coords, valid_frame = face_det_results[i].copy() + if not valid_frame: + continue + + face = cv2.resize(face, (args.img_size, args.img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= args.wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
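+		# the half-masked faces (lower half zeroed) are stacked with the reference frames into a 6-channel input; the mel window below is reshaped to (N, 80, 16, 1) as the audio condition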
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + +def increase_frames(frames, l): + ## evenly duplicating frames to increase length of video + while len(frames) < l: + dup_every = float(l) / len(frames) + + final_frames = [] + next_duplicate = 0. + + for i, f in enumerate(frames): + final_frames.append(f) + + if int(np.ceil(next_duplicate)) == i: + final_frames.append(f) + + next_duplicate += dup_every + + frames = final_frames + + return frames[:l] + +mel_step_size = 16 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Using {} for inference.'.format(device)) + +detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, + flip_input=False, device=device) + +def _load(checkpoint_path): + if device == 'cuda': + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + +def load_model(path): + model = Wav2Lip() + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + + model = model.to(device) + return model.eval() + +model = load_model(args.checkpoint_path) + +def main(): + if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir) + + if args.mode == 'dubbed': + files = listdir(args.data_root) + lines = ['{} {}'.format(f, f) for f in files] + + else: + assert args.filelist is not None + with open(args.filelist, 'r') as filelist: + lines = filelist.readlines() + + for idx, line in enumerate(tqdm(lines)): + video, audio_src = line.strip().split() + + audio_src = os.path.join(args.data_root, audio_src) + video = os.path.join(args.data_root, video) + + command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav') + subprocess.call(command, shell=True) + temp_audio = '../temp/temp.wav' + + wav = audio.load_wav(temp_audio, 16000) + mel = audio.melspectrogram(wav) + + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError('Mel contains nan!') + + video_stream = cv2.VideoCapture(video) + + fps = video_stream.get(cv2.CAP_PROP_FPS) + mel_idx_multiplier = 80./fps + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + + if min(frame.shape[:-1]) > args.max_frame_res: + h, w = frame.shape[:-1] + scale_factor = min(h, w) / float(args.max_frame_res) + h = int(h/scale_factor) + w = int(w/scale_factor) + + frame = cv2.resize(frame, (w, h)) + full_frames.append(frame) + + mel_chunks = [] + i = 0 + while 1: + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + break + mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) + i += 1 + + if len(full_frames) < len(mel_chunks): + if args.mode == 'tts': + full_frames = increase_frames(full_frames, len(mel_chunks)) + else: + raise ValueError('#Frames, audio length mismatch') + + else: + full_frames = full_frames[:len(mel_chunks)] + + try: + face_det_results, full_frames = face_detect(full_frames.copy()) + except ValueError as e: + continue + + batch_size = args.wav2lip_batch_size + gen = datagen(full_frames.copy(), face_det_results, mel_chunks) + + for i, (img_batch, mel_batch, frames, coords) in enumerate(gen): + if i == 0: + frame_h, frame_w = full_frames[0].shape[:-1] + + out = 
cv2.VideoWriter('../temp/result.avi', + cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) + + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) + + with torch.no_grad(): + pred = model(mel_batch, img_batch) + + + pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. + + for pl, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1)) + f[y1:y2, x1:x2] = pl + out.write(f) + + out.release() + + vid = os.path.join(args.results_dir, '{}.mp4'.format(idx)) + command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format('../temp/temp.wav', + '../temp/result.avi', vid) + subprocess.call(command, shell=True) + + +if __name__ == '__main__': + main() diff --git a/wav2lip/evaluation/scores_LSE/SyncNetInstance_calc_scores.py b/wav2lip/evaluation/scores_LSE/SyncNetInstance_calc_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..64906e257bd1f521d8fadb93e877ba83da7764ce --- /dev/null +++ b/wav2lip/evaluation/scores_LSE/SyncNetInstance_calc_scores.py @@ -0,0 +1,210 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- +# Video 25 FPS, Audio 16000HZ + +import torch +import numpy +import time, pdb, argparse, subprocess, os, math, glob +import cv2 +import python_speech_features + +from scipy import signal +from scipy.io import wavfile +from SyncNetModel import * +from shutil import rmtree + + +# ==================== Get OFFSET ==================== + +def calc_pdist(feat1, feat2, vshift=10): + + win_size = vshift*2+1 + + feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift)) + + dists = [] + + for i in range(0,len(feat1)): + + dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:])) + + return dists + +# ==================== MAIN DEF ==================== + +class SyncNetInstance(torch.nn.Module): + + def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024): + super(SyncNetInstance, self).__init__(); + + self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda(); + + def evaluate(self, opt, videofile): + + self.__S__.eval(); + + # ========== ========== + # Convert files + # ========== ========== + + if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)): + rmtree(os.path.join(opt.tmp_dir,opt.reference)) + + os.makedirs(os.path.join(opt.tmp_dir,opt.reference)) + + command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg'))) + output = subprocess.call(command, shell=True, stdout=None) + + command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))) + output = subprocess.call(command, shell=True, stdout=None) + + # ========== ========== + # Load video + # ========== ========== + + images = [] + + flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg')) + flist.sort() + + for fname in flist: + img_input = cv2.imread(fname) + img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE + images.append(img_input) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Load audio + # ========== ========== + + sample_rate, audio = 
wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav')) + mfcc = zip(*python_speech_features.mfcc(audio,sample_rate)) + mfcc = numpy.stack([numpy.array(i) for i in mfcc]) + + cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0) + cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float()) + + # ========== ========== + # Check audio and video input length + # ========== ========== + + #if (float(len(audio))/16000) != (float(len(images))/25) : + # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25)) + + min_length = min(len(images),math.floor(len(audio)/640)) + + # ========== ========== + # Generate video and audio feats + # ========== ========== + + lastframe = min_length-5 + im_feat = [] + cc_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lip(im_in.cuda()); + im_feat.append(im_out.data.cpu()) + + cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + cc_in = torch.cat(cc_batch,0) + cc_out = self.__S__.forward_aud(cc_in.cuda()) + cc_feat.append(cc_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + cc_feat = torch.cat(cc_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + #print('Compute time %.3f sec.' % (time.time()-tS)) + + dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift) + mdist = torch.mean(torch.stack(dists,1),1) + + minval, minidx = torch.min(mdist,0) + + offset = opt.vshift-minidx + conf = torch.median(mdist) - minval + + fdist = numpy.stack([dist[minidx].numpy() for dist in dists]) + # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15) + fconf = torch.median(mdist).numpy() - fdist + fconfm = signal.medfilt(fconf,kernel_size=9) + + numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format}) + #print('Framewise conf: ') + #print(fconfm) + #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf)) + + dists_npy = numpy.array([ dist.numpy() for dist in dists ]) + return offset.numpy(), conf.numpy(), minval.numpy() + + def extract_feature(self, opt, videofile): + + self.__S__.eval(); + + # ========== ========== + # Load video + # ========== ========== + cap = cv2.VideoCapture(videofile) + + frame_num = 1; + images = [] + while frame_num: + frame_num += 1 + ret, image = cap.read() + if ret == 0: + break + + images.append(image) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Generate video feats + # ========== ========== + + lastframe = len(images)-4 + im_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lipfeat(im_in.cuda()); + im_feat.append(im_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + print('Compute time %.3f sec.' 
% (time.time()-tS)) + + return im_feat + + + def loadParameters(self, path): + loaded_state = torch.load(path, map_location=lambda storage, loc: storage); + + self_state = self.__S__.state_dict(); + + for name, param in loaded_state.items(): + + self_state[name].copy_(param); diff --git a/wav2lip/evaluation/scores_LSE/calculate_scores_LRS.py b/wav2lip/evaluation/scores_LSE/calculate_scores_LRS.py new file mode 100644 index 0000000000000000000000000000000000000000..eda02b8fbb7ac2f07d238b92d0879fb26c979394 --- /dev/null +++ b/wav2lip/evaluation/scores_LSE/calculate_scores_LRS.py @@ -0,0 +1,53 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess +import glob +import os +from tqdm import tqdm + +from SyncNetInstance_calc_scores import * + +# ==================== LOAD PARAMS ==================== + + +parser = argparse.ArgumentParser(description = "SyncNet"); + +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); +parser.add_argument('--batch_size', type=int, default='20', help=''); +parser.add_argument('--vshift', type=int, default='15', help=''); +parser.add_argument('--data_root', type=str, required=True, help=''); +parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help=''); +parser.add_argument('--reference', type=str, default="demo", help=''); + +opt = parser.parse_args(); + + +# ==================== RUN EVALUATION ==================== + +s = SyncNetInstance(); + +s.loadParameters(opt.initial_model); +#print("Model %s loaded."%opt.initial_model); +path = os.path.join(opt.data_root, "*.mp4") + +all_videos = glob.glob(path) + +prog_bar = tqdm(range(len(all_videos))) +avg_confidence = 0. +avg_min_distance = 0. + + +for videofile_idx in prog_bar: + videofile = all_videos[videofile_idx] + offset, confidence, min_distance = s.evaluate(opt, videofile=videofile) + avg_confidence += confidence + avg_min_distance += min_distance + prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3))) + prog_bar.refresh() + +print ('Average Confidence: {}'.format(avg_confidence/len(all_videos))) +print ('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos))) + + + diff --git a/wav2lip/evaluation/scores_LSE/calculate_scores_real_videos.py b/wav2lip/evaluation/scores_LSE/calculate_scores_real_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..09622584653dd40ce610afc8aef5765cdea16e68 --- /dev/null +++ b/wav2lip/evaluation/scores_LSE/calculate_scores_real_videos.py @@ -0,0 +1,45 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess, pickle, os, gzip, glob + +from SyncNetInstance_calc_scores import * + +# ==================== PARSE ARGUMENT ==================== + +parser = argparse.ArgumentParser(description = "SyncNet"); +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); +parser.add_argument('--batch_size', type=int, default='20', help=''); +parser.add_argument('--vshift', type=int, default='15', help=''); +parser.add_argument('--data_dir', type=str, default='data/work', help=''); +parser.add_argument('--videofile', type=str, default='', help=''); +parser.add_argument('--reference', type=str, default='', help=''); +opt = parser.parse_args(); + +setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi')) +setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp')) 
+setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork')) +setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop')) + + +# ==================== LOAD MODEL AND FILE LIST ==================== + +s = SyncNetInstance(); + +s.loadParameters(opt.initial_model); +#print("Model %s loaded."%opt.initial_model); + +flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi')) +flist.sort() + +# ==================== GET OFFSETS ==================== + +dists = [] +for idx, fname in enumerate(flist): + offset, conf, dist = s.evaluate(opt,videofile=fname) + print (str(dist)+" "+str(conf)) + +# ==================== PRINT RESULTS TO FILE ==================== + +#with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil: +# pickle.dump(dists, fil) diff --git a/wav2lip/evaluation/scores_LSE/calculate_scores_real_videos.sh b/wav2lip/evaluation/scores_LSE/calculate_scores_real_videos.sh new file mode 100644 index 0000000000000000000000000000000000000000..4a45cd568d10bfeea9fc31255fcdf121d3f4e0e9 --- /dev/null +++ b/wav2lip/evaluation/scores_LSE/calculate_scores_real_videos.sh @@ -0,0 +1,8 @@ +rm all_scores.txt +yourfilenames=`ls $1` + +for eachfile in $yourfilenames +do + python run_pipeline.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir + python calculate_scores_real_videos.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir >> all_scores.txt +done diff --git a/wav2lip/evaluation/test_filelists/README.md b/wav2lip/evaluation/test_filelists/README.md new file mode 100644 index 0000000000000000000000000000000000000000..84c9acaedff77e229ce05bf92892b432cc302c35 --- /dev/null +++ b/wav2lip/evaluation/test_filelists/README.md @@ -0,0 +1,13 @@ +This folder contains the filelists for the new evaluation framework proposed in the paper. + +## Test filelists for LRS2, LRS3, and LRW. + +This folder contains three filelists, each containing a list of names of audio-video pairs from the test sets of LRS2, LRS3, and LRW. The LRS2 and LRW filelists are strictly "Copyright BBC" and can only be used for “non-commercial research by applicants who have an agreement with the BBC to access the Lip Reading in the Wild and/or Lip Reading Sentences in the Wild datasets”. Please follow this link for more details: [https://www.bbc.co.uk/rd/projects/lip-reading-datasets](https://www.bbc.co.uk/rd/projects/lip-reading-datasets). + + +## ReSynCED benchmark + +The sub-folder `ReSynCED` contains filelists for our own Real-world lip-Sync Evaluation Dataset (ReSyncED). + + +#### Instructions on how to use the above two filelists are available in the README of the parent folder. 
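Note: the two numbers these evaluation scripts print per clip come straight out of evaluate() above: the smallest shift-averaged feature distance (the "Min dist") and the margin between the median and that minimum (the "Confidence"), which calculate_scores_LRS.py then averages over all videos. Below is a minimal sketch of that final reduction, assuming dists is the list returned by calc_pdist with one per-frame distance vector over the 2*vshift+1 candidate shifts; this is an illustration only, not part of the scripts themselves.

import torch

def av_offset_and_confidence(dists, vshift):
    # Mirrors the tail of SyncNetInstance.evaluate()
    mdist = torch.mean(torch.stack(dists, 1), 1)   # mean distance per candidate AV shift
    minval, minidx = torch.min(mdist, 0)           # smallest mean distance and its shift index
    offset = vshift - minidx                       # AV offset in frames within the +/- vshift window
    conf = torch.median(mdist) - minval            # confidence: margin between median and minimum
    return int(offset), float(conf), float(minval)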
diff --git a/wav2lip/evaluation/test_filelists/ReSyncED/random_pairs.txt b/wav2lip/evaluation/test_filelists/ReSyncED/random_pairs.txt new file mode 100644 index 0000000000000000000000000000000000000000..ffe2c40e117a109a97f8215438464b60e976dc73 --- /dev/null +++ b/wav2lip/evaluation/test_filelists/ReSyncED/random_pairs.txt @@ -0,0 +1,160 @@ +sachin.mp4 emma_cropped.mp4 +sachin.mp4 mourinho.mp4 +sachin.mp4 elon.mp4 +sachin.mp4 messi2.mp4 +sachin.mp4 cr1.mp4 +sachin.mp4 sachin.mp4 +sachin.mp4 sg.mp4 +sachin.mp4 fergi.mp4 +sachin.mp4 spanish_lec1.mp4 +sachin.mp4 bush_small.mp4 +sachin.mp4 macca_cut.mp4 +sachin.mp4 ca_cropped.mp4 +sachin.mp4 lecun.mp4 +sachin.mp4 spanish_lec0.mp4 +srk.mp4 emma_cropped.mp4 +srk.mp4 mourinho.mp4 +srk.mp4 elon.mp4 +srk.mp4 messi2.mp4 +srk.mp4 cr1.mp4 +srk.mp4 srk.mp4 +srk.mp4 sachin.mp4 +srk.mp4 sg.mp4 +srk.mp4 fergi.mp4 +srk.mp4 spanish_lec1.mp4 +srk.mp4 bush_small.mp4 +srk.mp4 macca_cut.mp4 +srk.mp4 ca_cropped.mp4 +srk.mp4 guardiola.mp4 +srk.mp4 lecun.mp4 +srk.mp4 spanish_lec0.mp4 +cr1.mp4 emma_cropped.mp4 +cr1.mp4 elon.mp4 +cr1.mp4 messi2.mp4 +cr1.mp4 cr1.mp4 +cr1.mp4 spanish_lec1.mp4 +cr1.mp4 bush_small.mp4 +cr1.mp4 macca_cut.mp4 +cr1.mp4 ca_cropped.mp4 +cr1.mp4 lecun.mp4 +cr1.mp4 spanish_lec0.mp4 +macca_cut.mp4 emma_cropped.mp4 +macca_cut.mp4 elon.mp4 +macca_cut.mp4 messi2.mp4 +macca_cut.mp4 spanish_lec1.mp4 +macca_cut.mp4 macca_cut.mp4 +macca_cut.mp4 ca_cropped.mp4 +macca_cut.mp4 spanish_lec0.mp4 +lecun.mp4 emma_cropped.mp4 +lecun.mp4 elon.mp4 +lecun.mp4 messi2.mp4 +lecun.mp4 spanish_lec1.mp4 +lecun.mp4 macca_cut.mp4 +lecun.mp4 ca_cropped.mp4 +lecun.mp4 lecun.mp4 +lecun.mp4 spanish_lec0.mp4 +messi2.mp4 emma_cropped.mp4 +messi2.mp4 elon.mp4 +messi2.mp4 messi2.mp4 +messi2.mp4 spanish_lec1.mp4 +messi2.mp4 macca_cut.mp4 +messi2.mp4 ca_cropped.mp4 +messi2.mp4 spanish_lec0.mp4 +ca_cropped.mp4 emma_cropped.mp4 +ca_cropped.mp4 elon.mp4 +ca_cropped.mp4 spanish_lec1.mp4 +ca_cropped.mp4 ca_cropped.mp4 +ca_cropped.mp4 spanish_lec0.mp4 +spanish_lec1.mp4 spanish_lec1.mp4 +spanish_lec1.mp4 spanish_lec0.mp4 +elon.mp4 elon.mp4 +elon.mp4 spanish_lec1.mp4 +elon.mp4 spanish_lec0.mp4 +guardiola.mp4 emma_cropped.mp4 +guardiola.mp4 mourinho.mp4 +guardiola.mp4 elon.mp4 +guardiola.mp4 messi2.mp4 +guardiola.mp4 cr1.mp4 +guardiola.mp4 sachin.mp4 +guardiola.mp4 sg.mp4 +guardiola.mp4 fergi.mp4 +guardiola.mp4 spanish_lec1.mp4 +guardiola.mp4 bush_small.mp4 +guardiola.mp4 macca_cut.mp4 +guardiola.mp4 ca_cropped.mp4 +guardiola.mp4 guardiola.mp4 +guardiola.mp4 lecun.mp4 +guardiola.mp4 spanish_lec0.mp4 +fergi.mp4 emma_cropped.mp4 +fergi.mp4 mourinho.mp4 +fergi.mp4 elon.mp4 +fergi.mp4 messi2.mp4 +fergi.mp4 cr1.mp4 +fergi.mp4 sachin.mp4 +fergi.mp4 sg.mp4 +fergi.mp4 fergi.mp4 +fergi.mp4 spanish_lec1.mp4 +fergi.mp4 bush_small.mp4 +fergi.mp4 macca_cut.mp4 +fergi.mp4 ca_cropped.mp4 +fergi.mp4 lecun.mp4 +fergi.mp4 spanish_lec0.mp4 +spanish.mp4 emma_cropped.mp4 +spanish.mp4 spanish.mp4 +spanish.mp4 mourinho.mp4 +spanish.mp4 elon.mp4 +spanish.mp4 messi2.mp4 +spanish.mp4 cr1.mp4 +spanish.mp4 srk.mp4 +spanish.mp4 sachin.mp4 +spanish.mp4 sg.mp4 +spanish.mp4 fergi.mp4 +spanish.mp4 spanish_lec1.mp4 +spanish.mp4 bush_small.mp4 +spanish.mp4 macca_cut.mp4 +spanish.mp4 ca_cropped.mp4 +spanish.mp4 guardiola.mp4 +spanish.mp4 lecun.mp4 +spanish.mp4 spanish_lec0.mp4 +bush_small.mp4 emma_cropped.mp4 +bush_small.mp4 elon.mp4 +bush_small.mp4 messi2.mp4 +bush_small.mp4 spanish_lec1.mp4 +bush_small.mp4 bush_small.mp4 +bush_small.mp4 macca_cut.mp4 +bush_small.mp4 ca_cropped.mp4 +bush_small.mp4 lecun.mp4 +bush_small.mp4 
spanish_lec0.mp4 +emma_cropped.mp4 emma_cropped.mp4 +emma_cropped.mp4 elon.mp4 +emma_cropped.mp4 spanish_lec1.mp4 +emma_cropped.mp4 spanish_lec0.mp4 +sg.mp4 emma_cropped.mp4 +sg.mp4 mourinho.mp4 +sg.mp4 elon.mp4 +sg.mp4 messi2.mp4 +sg.mp4 cr1.mp4 +sg.mp4 sachin.mp4 +sg.mp4 sg.mp4 +sg.mp4 fergi.mp4 +sg.mp4 spanish_lec1.mp4 +sg.mp4 bush_small.mp4 +sg.mp4 macca_cut.mp4 +sg.mp4 ca_cropped.mp4 +sg.mp4 lecun.mp4 +sg.mp4 spanish_lec0.mp4 +spanish_lec0.mp4 spanish_lec0.mp4 +mourinho.mp4 emma_cropped.mp4 +mourinho.mp4 mourinho.mp4 +mourinho.mp4 elon.mp4 +mourinho.mp4 messi2.mp4 +mourinho.mp4 cr1.mp4 +mourinho.mp4 sachin.mp4 +mourinho.mp4 sg.mp4 +mourinho.mp4 fergi.mp4 +mourinho.mp4 spanish_lec1.mp4 +mourinho.mp4 bush_small.mp4 +mourinho.mp4 macca_cut.mp4 +mourinho.mp4 ca_cropped.mp4 +mourinho.mp4 lecun.mp4 +mourinho.mp4 spanish_lec0.mp4 diff --git a/wav2lip/evaluation/test_filelists/ReSyncED/tts_pairs.txt b/wav2lip/evaluation/test_filelists/ReSyncED/tts_pairs.txt new file mode 100644 index 0000000000000000000000000000000000000000..b7dc1a8c0b50ebb33ba0edf269ea9329933a10dc --- /dev/null +++ b/wav2lip/evaluation/test_filelists/ReSyncED/tts_pairs.txt @@ -0,0 +1,18 @@ +adam_1.mp4 andreng_optimization.wav +agad_2.mp4 agad_2.wav +agad_1.mp4 agad_1.wav +agad_3.mp4 agad_3.wav +rms_prop_1.mp4 rms_prop_tts.wav +tf_1.mp4 tf_1.wav +tf_2.mp4 tf_2.wav +andrew_ng_ai_business.mp4 andrewng_business_tts.wav +covid_autopsy_1.mp4 autopsy_tts.wav +news_1.mp4 news_tts.wav +andrew_ng_fund_1.mp4 andrewng_ai_fund.wav +covid_treatments_1.mp4 covid_tts.wav +pytorch_v_tf.mp4 pytorch_vs_tf_eng.wav +pytorch_1.mp4 pytorch.wav +pkb_1.mp4 pkb_1.wav +ss_1.mp4 ss_1.wav +carlsen_1.mp4 carlsen_eng.wav +french.mp4 french.wav \ No newline at end of file diff --git a/wav2lip/face_detection/README.md b/wav2lip/face_detection/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c073376e4eeda6d4b29cc31c50cb7e88ab42bb73 --- /dev/null +++ b/wav2lip/face_detection/README.md @@ -0,0 +1 @@ +The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. \ No newline at end of file diff --git a/wav2lip/face_detection/__init__.py b/wav2lip/face_detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4bae29fd5f85b41e4669302bd2603bc6924eddc7 --- /dev/null +++ b/wav2lip/face_detection/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +__author__ = """Adrian Bulat""" +__email__ = 'adrian.bulat@nottingham.ac.uk' +__version__ = '1.0.1' + +from .api import FaceAlignment, LandmarksType, NetworkSize diff --git a/wav2lip/face_detection/api.py b/wav2lip/face_detection/api.py new file mode 100644 index 0000000000000000000000000000000000000000..cb02d5252db5362b9985687a992e128a522e5b63 --- /dev/null +++ b/wav2lip/face_detection/api.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import os +import torch +from torch.utils.model_zoo import load_url +from enum import Enum +import numpy as np +import cv2 +try: + import urllib.request as request_file +except BaseException: + import urllib as request_file + +from .models import FAN, ResNetDepth +from .utils import * + + +class LandmarksType(Enum): + """Enum class defining the type of landmarks to detect. 
+ + ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face + ``_2halfD`` - this points represent the projection of the 3D points into 3D + ``_3D`` - detect the points ``(x,y,z)``` in a 3D space + + """ + _2D = 1 + _2halfD = 2 + _3D = 3 + + +class NetworkSize(Enum): + # TINY = 1 + # SMALL = 2 + # MEDIUM = 3 + LARGE = 4 + + def __new__(cls, value): + member = object.__new__(cls) + member._value_ = value + return member + + def __int__(self): + return self.value + +ROOT = os.path.dirname(os.path.abspath(__file__)) + +class FaceAlignment: + def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, + device='cuda', flip_input=False, face_detector='sfd', verbose=False): + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + # Get the face detector + face_detector_module = __import__('face_detection.detection.' + face_detector, + globals(), locals(), [face_detector], 0) + self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) + + def get_detections_for_batch(self, images): + images = images[..., ::-1] + detected_faces = self.face_detector.detect_from_batch(images.copy()) + results = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + results.append(None) + continue + d = d[0] + d = np.clip(d, 0, None) + + x1, y1, x2, y2 = map(int, d[:-1]) + results.append((x1, y1, x2, y2)) + + return results \ No newline at end of file diff --git a/wav2lip/face_detection/detection/__init__.py b/wav2lip/face_detection/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a6b0402dae864a3cc5dc2a90a412fd842a0efc7 --- /dev/null +++ b/wav2lip/face_detection/detection/__init__.py @@ -0,0 +1 @@ +from .core import FaceDetector \ No newline at end of file diff --git a/wav2lip/face_detection/detection/core.py b/wav2lip/face_detection/detection/core.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8275e8e53143f66298f75f0517c234a68778cd --- /dev/null +++ b/wav2lip/face_detection/detection/core.py @@ -0,0 +1,130 @@ +import logging +import glob +from tqdm import tqdm +import numpy as np +import torch +import cv2 + + +class FaceDetector(object): + """An abstract class representing a face detector. + + Any other face detection implementation must subclass it. All subclasses + must implement ``detect_from_image``, that return a list of detected + bounding boxes. Optionally, for speed considerations detect from path is + recommended. + """ + + def __init__(self, device, verbose): + self.device = device + self.verbose = verbose + + if verbose: + if 'cpu' in device: + logger = logging.getLogger(__name__) + logger.warning("Detection running on CPU, this may be potentially slow.") + + if 'cpu' not in device and 'cuda' not in device: + if verbose: + logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) + raise ValueError + + def detect_from_image(self, tensor_or_path): + """Detects faces in a given image. + + This function detects the faces present in a provided BGR(usually) + image. The input can be either the image itself or the path to it. + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path + to an image or the image itself. + + Example:: + + >>> path_to_image = 'data/image_01.jpg' + ... 
detected_faces = detect_from_image(path_to_image) + [A list of bounding boxes (x1, y1, x2, y2)] + >>> image = cv2.imread(path_to_image) + ... detected_faces = detect_from_image(image) + [A list of bounding boxes (x1, y1, x2, y2)] + + """ + raise NotImplementedError + + def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): + """Detects faces from all the images present in a given directory. + + Arguments: + path {string} -- a string containing a path that points to the folder containing the images + + Keyword Arguments: + extensions {list} -- list of string containing the extensions to be + consider in the following format: ``.extension_name`` (default: + {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the + folder recursively (default: {False}) show_progress_bar {bool} -- + display a progressbar (default: {True}) + + Example: + >>> directory = 'data' + ... detected_faces = detect_from_directory(directory) + {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} + + """ + if self.verbose: + logger = logging.getLogger(__name__) + + if len(extensions) == 0: + if self.verbose: + logger.error("Expected at list one extension, but none was received.") + raise ValueError + + if self.verbose: + logger.info("Constructing the list of images.") + additional_pattern = '/**/*' if recursive else '/*' + files = [] + for extension in extensions: + files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) + + if self.verbose: + logger.info("Finished searching for images. %s images found", len(files)) + logger.info("Preparing to run the detection.") + + predictions = {} + for image_path in tqdm(files, disable=not show_progress_bar): + if self.verbose: + logger.info("Running the face detector on image: %s", image_path) + predictions[image_path] = self.detect_from_image(image_path) + + if self.verbose: + logger.info("The detector was successfully run on all %s images", len(files)) + + return predictions + + @property + def reference_scale(self): + raise NotImplementedError + + @property + def reference_x_shift(self): + raise NotImplementedError + + @property + def reference_y_shift(self): + raise NotImplementedError + + @staticmethod + def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): + """Convert path (represented as a string) or torch.tensor to a numpy.ndarray + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself + """ + if isinstance(tensor_or_path, str): + return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1] + elif torch.is_tensor(tensor_or_path): + # Call cpu in case its coming from cuda + return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() + elif isinstance(tensor_or_path, np.ndarray): + return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path + else: + raise TypeError diff --git a/wav2lip/face_detection/detection/sfd/__init__.py b/wav2lip/face_detection/detection/sfd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a63ecd45658f22e66c171ada751fb33764d4559 --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/__init__.py @@ -0,0 +1 @@ +from .sfd_detector import SFDDetector as FaceDetector \ No newline at end of file diff --git a/wav2lip/face_detection/detection/sfd/bbox.py b/wav2lip/face_detection/detection/sfd/bbox.py new file mode 100644 index 
0000000000000000000000000000000000000000..4bd7222e5e5f78a51944cbeed3cccbacddc46bed --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/bbox.py @@ -0,0 +1,129 @@ +from __future__ import print_function +import os +import sys +import cv2 +import random +import datetime +import time +import math +import argparse +import numpy as np +import torch + +try: + from iou import IOU +except BaseException: + # IOU cython speedup 10x + def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): + sa = abs((ax2 - ax1) * (ay2 - ay1)) + sb = abs((bx2 - bx1) * (by2 - by1)) + x1, y1 = max(ax1, bx1), max(ay1, by1) + x2, y2 = min(ax2, bx2), min(ay2, by2) + w = x2 - x1 + h = y2 - y1 + if w < 0 or h < 0: + return 0.0 + else: + return 1.0 * w * h / (sa + sb - w * h) + + +def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): + xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 + dx, dy = (xc - axc) / aww, (yc - ayc) / ahh + dw, dh = math.log(ww / aww), math.log(hh / ahh) + return dx, dy, dw, dh + + +def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): + xc, yc = dx * aww + axc, dy * ahh + ayc + ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh + x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 + return x1, y1, x2, y2 + + +def nms(dets, thresh): + if 0 == len(dets): + return [] + x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) + xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) + + w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) + ovr = w * h / (areas[i] + areas[order[1:]] - w * h) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + +def batch_decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2) + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes diff --git a/wav2lip/face_detection/detection/sfd/detect.py b/wav2lip/face_detection/detection/sfd/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..efef6273adf317bc17f3dd0f02423c0701ca218e --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/detect.py @@ -0,0 +1,112 @@ +import torch +import torch.nn.functional as F + +import os +import sys +import cv2 +import random +import datetime +import math +import argparse +import numpy as np + +import scipy.io as sio +import zipfile +from .net_s3fd import s3fd +from .bbox import * + + +def detect(net, img, device): + img = img - np.array([104, 117, 123]) + img = img.transpose(2, 0, 1) + img = img.reshape((1,) + img.shape) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + img = torch.from_numpy(img).float().to(device) + BB, CC, HH, WW = img.size() + with torch.no_grad(): + olist = net(img) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[0, 1, hindex, windex] + loc = oreg[0, :, hindex, windex].contiguous().view(1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) + variances = [0.1, 0.2] + box = decode(loc, priors, variances) + x1, y1, x2, y2 = box[0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append([x1, y1, x2, y2, score]) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, 5)) + + return bboxlist + +def batch_detect(net, imgs, device): + imgs = imgs - np.array([104, 117, 123]) + imgs = imgs.transpose(0, 3, 1, 2) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + imgs = torch.from_numpy(imgs).float().to(device) + BB, CC, HH, WW = imgs.size() + with torch.no_grad(): + olist = net(imgs) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = 
stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[:, 1, hindex, windex] + loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4) + variances = [0.1, 0.2] + box = batch_decode(loc, priors, variances) + box = box[:, 0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy()) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, BB, 5)) + + return bboxlist + +def flip_detect(net, img, device): + img = cv2.flip(img, 1) + b = detect(net, img, device) + + bboxlist = np.zeros(b.shape) + bboxlist[:, 0] = img.shape[1] - b[:, 2] + bboxlist[:, 1] = b[:, 1] + bboxlist[:, 2] = img.shape[1] - b[:, 0] + bboxlist[:, 3] = b[:, 3] + bboxlist[:, 4] = b[:, 4] + return bboxlist + + +def pts_to_bb(pts): + min_x, min_y = np.min(pts, axis=0) + max_x, max_y = np.max(pts, axis=0) + return np.array([min_x, min_y, max_x, max_y]) diff --git a/wav2lip/face_detection/detection/sfd/net_s3fd.py b/wav2lip/face_detection/detection/sfd/net_s3fd.py new file mode 100644 index 0000000000000000000000000000000000000000..fc64313c277ab594d0257585c70f147606693452 --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/net_s3fd.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class L2Norm(nn.Module): + def __init__(self, n_channels, scale=1.0): + super(L2Norm, self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.weight.data *= 0.0 + self.weight.data += self.scale + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps + x = x / norm * self.weight.view(1, -1, 1, 1) + return x + + +class s3fd(nn.Module): + def __init__(self): + super(s3fd, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) + self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = L2Norm(256, scale=10) + self.conv4_3_norm = L2Norm(512, 
scale=8) + self.conv5_3_norm = L2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + f5_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + ffc7 = h + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)) + f6_2 = h + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = self.conv4_3_norm_mbox_conf(f4_3) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = self.conv5_3_norm_mbox_conf(f5_3) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + cls4 = self.fc7_mbox_conf(ffc7) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = self.conv6_2_mbox_conf(f6_2) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = self.conv7_2_mbox_conf(f7_2) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + chunk = torch.chunk(cls1, 4, 1) + bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) + cls1 = torch.cat([bmax, chunk[3]], dim=1) + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] diff --git a/wav2lip/face_detection/detection/sfd/sfd_detector.py b/wav2lip/face_detection/detection/sfd/sfd_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbce15253251d403754ab4348f93ae85a6ba2fb --- /dev/null +++ b/wav2lip/face_detection/detection/sfd/sfd_detector.py @@ -0,0 +1,59 @@ +import os +import cv2 +from torch.utils.model_zoo import load_url + +from ..core import FaceDetector + +from .net_s3fd import s3fd +from .bbox import * +from .detect import * + +models_urls = { + 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', +} + + +class SFDDetector(FaceDetector): + def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False): + super(SFDDetector, self).__init__(device, verbose) + + # Initialise the face detector + if not 
os.path.isfile(path_to_detector): + model_weights = load_url(models_urls['s3fd']) + else: + model_weights = torch.load(path_to_detector) + + self.face_detector = s3fd() + self.face_detector.load_state_dict(model_weights) + self.face_detector.to(device) + self.face_detector.eval() + + def detect_from_image(self, tensor_or_path): + image = self.tensor_or_path_to_ndarray(tensor_or_path) + + bboxlist = detect(self.face_detector, image, device=self.device) + keep = nms(bboxlist, 0.3) + bboxlist = bboxlist[keep, :] + bboxlist = [x for x in bboxlist if x[-1] > 0.5] + + return bboxlist + + def detect_from_batch(self, images): + bboxlists = batch_detect(self.face_detector, images, device=self.device) + keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])] + bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)] + bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists] + + return bboxlists + + @property + def reference_scale(self): + return 195 + + @property + def reference_x_shift(self): + return 0 + + @property + def reference_y_shift(self): + return 0 diff --git a/wav2lip/face_detection/models.py b/wav2lip/face_detection/models.py new file mode 100644 index 0000000000000000000000000000000000000000..ee2dde32bdf72c25a4600e48efa73ffc0d4a3893 --- /dev/null +++ b/wav2lip/face_detection/models.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=strd, padding=padding, bias=bias) + + +class ConvBlock(nn.Module): + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.BatchNorm2d(in_planes), + nn.ReLU(True), + nn.Conv2d(in_planes, out_planes, + kernel_size=1, stride=1, bias=False), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + +class Bottleneck(nn.Module): + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = 
self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HourGlass(nn.Module): + def __init__(self, num_modules, depth, num_features): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) + + if level > 1: + self._generate_network(level - 1) + else: + self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) + + def _forward(self, level, inp): + # Upper branch + up1 = inp + up1 = self._modules['b1_' + str(level)](up1) + + # Lower branch + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._modules['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._modules['b2_plus_' + str(level)](low2) + + low3 = low2 + low3 = self._modules['b3_' + str(level)](low3) + + up2 = F.interpolate(low3, scale_factor=2, mode='nearest') + + return up1 + up2 + + def forward(self, x): + return self._forward(self.depth, x) + + +class FAN(nn.Module): + + def __init__(self, num_modules=1): + super(FAN, self).__init__() + self.num_modules = num_modules + + # Base part + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + # Stacking part + for hg_module in range(self.num_modules): + self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) + self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) + self.add_module('conv_last' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + self.add_module('l' + str(hg_module), nn.Conv2d(256, + 68, kernel_size=1, stride=1, padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module( + 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('al' + str(hg_module), nn.Conv2d(68, + 256, kernel_size=1, stride=1, padding=0)) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x)), True) + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + for i in range(self.num_modules): + hg = self._modules['m' + str(i)](previous) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu(self._modules['bn_end' + str(i)] + (self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + outputs.append(tmp_out) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + return outputs + + +class ResNetDepth(nn.Module): + + def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): + self.inplanes = 64 + super(ResNetDepth, self).__init__() + self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, 
padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/wav2lip/face_detection/utils.py b/wav2lip/face_detection/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc4cf3e328efaa227cbcfdd969e1056688adad5 --- /dev/null +++ b/wav2lip/face_detection/utils.py @@ -0,0 +1,313 @@ +from __future__ import print_function +import os +import sys +import time +import torch +import math +import numpy as np +import cv2 + + +def _gaussian( + size=3, sigma=0.25, amplitude=1, normalize=False, width=None, + height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, + mean_vert=0.5): + # handle some defaults + if width is None: + width = size + if height is None: + height = size + if sigma_horz is None: + sigma_horz = sigma + if sigma_vert is None: + sigma_vert = sigma + center_x = mean_horz * width + 0.5 + center_y = mean_vert * height + 0.5 + gauss = np.empty((height, width), dtype=np.float32) + # generate kernel + for i in range(height): + for j in range(width): + gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( + sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) + if normalize: + gauss = gauss / np.sum(gauss) + return gauss + + +def draw_gaussian(image, point, sigma): + # Check if the gaussian is inside + ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] + br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] + if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): + return image + size = 6 * sigma + 1 + g = _gaussian(size) + g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] + g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] + img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] + img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] + assert (g_x[0] > 0 and g_y[1] > 0) + image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] + ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] + image[image > 1] = 1 + return image 
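+# Note: _gaussian() above builds a small Gaussian kernel and draw_gaussian() pastes it
+# onto `image` centred at `point`, clipping the kernel at the image borders and capping
+# pixel values at 1; they are heatmap-splatting helpers. transform() below maps a point
+# between the original image and the resolution x resolution crop defined by a face
+# centre and scale (invert=True maps crop coordinates back to image coordinates).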
+ + +def transform(point, center, scale, resolution, invert=False): + """Generate and affine transformation matrix. + + Given a set of points, a center, a scale and a targer resolution, the + function generates and affine transformation matrix. If invert is ``True`` + it will produce the inverse transformation. + + Arguments: + point {torch.tensor} -- the input 2D point + center {torch.tensor or numpy.array} -- the center around which to perform the transformations + scale {float} -- the scale of the face/object + resolution {float} -- the output resolution + + Keyword Arguments: + invert {bool} -- define wherever the function should produce the direct or the + inverse transformation matrix (default: {False}) + """ + _pt = torch.ones(3) + _pt[0] = point[0] + _pt[1] = point[1] + + h = 200.0 * scale + t = torch.eye(3) + t[0, 0] = resolution / h + t[1, 1] = resolution / h + t[0, 2] = resolution * (-center[0] / h + 0.5) + t[1, 2] = resolution * (-center[1] / h + 0.5) + + if invert: + t = torch.inverse(t) + + new_point = (torch.matmul(t, _pt))[0:2] + + return new_point.int() + + +def crop(image, center, scale, resolution=256.0): + """Center crops an image or set of heatmaps + + Arguments: + image {numpy.array} -- an rgb image + center {numpy.array} -- the center of the object, usually the same as of the bounding box + scale {float} -- scale of the face + + Keyword Arguments: + resolution {float} -- the size of the output cropped image (default: {256.0}) + + Returns: + [type] -- [description] + """ # Crop around the center point + """ Crops the image around the center. Input is expected to be an np.ndarray """ + ul = transform([1, 1], center, scale, resolution, True) + br = transform([resolution, resolution], center, scale, resolution, True) + # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0) + if image.ndim > 2: + newDim = np.array([br[1] - ul[1], br[0] - ul[0], + image.shape[2]], dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + else: + newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) + newImg = np.zeros(newDim, dtype=np.uint8) + ht = image.shape[0] + wd = image.shape[1] + newX = np.array( + [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) + newY = np.array( + [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) + oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) + oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) + newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] + ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] + newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), + interpolation=cv2.INTER_LINEAR) + return newImg + + +def get_preds_fromhm(hm, center=None, scale=None): + """Obtain (x,y) coordinates given a set of N heatmaps. If the center + and the scale is provided the function will return the points also in + the original coordinate frame. 
+ + Arguments: + hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] + + Keyword Arguments: + center {torch.tensor} -- the center of the bounding box (default: {None}) + scale {float} -- face scale (default: {None}) + """ + max, idx = torch.max( + hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) + idx += 1 + preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() + preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) + preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) + + for i in range(preds.size(0)): + for j in range(preds.size(1)): + hm_ = hm[i, j, :] + pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = torch.FloatTensor( + [hm_[pY, pX + 1] - hm_[pY, pX - 1], + hm_[pY + 1, pX] - hm_[pY - 1, pX]]) + preds[i, j].add_(diff.sign_().mul_(.25)) + + preds.add_(-.5) + + preds_orig = torch.zeros(preds.size()) + if center is not None and scale is not None: + for i in range(hm.size(0)): + for j in range(hm.size(1)): + preds_orig[i, j] = transform( + preds[i, j], center, scale, hm.size(2), True) + + return preds, preds_orig + +def get_preds_fromhm_batch(hm, centers=None, scales=None): + """Obtain (x,y) coordinates given a set of N heatmaps. If the centers + and the scales is provided the function will return the points also in + the original coordinate frame. + + Arguments: + hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] + + Keyword Arguments: + centers {torch.tensor} -- the centers of the bounding box (default: {None}) + scales {float} -- face scales (default: {None}) + """ + max, idx = torch.max( + hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) + idx += 1 + preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() + preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) + preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) + + for i in range(preds.size(0)): + for j in range(preds.size(1)): + hm_ = hm[i, j, :] + pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = torch.FloatTensor( + [hm_[pY, pX + 1] - hm_[pY, pX - 1], + hm_[pY + 1, pX] - hm_[pY - 1, pX]]) + preds[i, j].add_(diff.sign_().mul_(.25)) + + preds.add_(-.5) + + preds_orig = torch.zeros(preds.size()) + if centers is not None and scales is not None: + for i in range(hm.size(0)): + for j in range(hm.size(1)): + preds_orig[i, j] = transform( + preds[i, j], centers[i], scales[i], hm.size(2), True) + + return preds, preds_orig + +def shuffle_lr(parts, pairs=None): + """Shuffle the points left-right according to the axis of symmetry + of the object. + + Arguments: + parts {torch.tensor} -- a 3D or 4D object containing the + heatmaps. + + Keyword Arguments: + pairs {list of integers} -- [order of the flipped points] (default: {None}) + """ + if pairs is None: + pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, + 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, + 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, + 62, 61, 60, 67, 66, 65] + if parts.ndimension() == 3: + parts = parts[pairs, ...] + else: + parts = parts[:, pairs, ...] 
+ + return parts + + +def flip(tensor, is_label=False): + """Flip an image or a set of heatmaps left-right + + Arguments: + tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] + + Keyword Arguments: + is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) + """ + if not torch.is_tensor(tensor): + tensor = torch.from_numpy(tensor) + + if is_label: + tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) + else: + tensor = tensor.flip(tensor.ndimension() - 1) + + return tensor + +# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) + + +def appdata_dir(appname=None, roaming=False): + """ appdata_dir(appname=None, roaming=False) + + Get the path to the application directory, where applications are allowed + to write user specific files (e.g. configurations). For non-user specific + data, consider using common_appdata_dir(). + If appname is given, a subdir is appended (and created if necessary). + If roaming is True, will prefer a roaming directory (Windows Vista/7). + """ + + # Define default user directory + userDir = os.getenv('FACEALIGNMENT_USERDIR', None) + if userDir is None: + userDir = os.path.expanduser('~') + if not os.path.isdir(userDir): # pragma: no cover + userDir = '/var/tmp' # issue #54 + + # Get system app data dir + path = None + if sys.platform.startswith('win'): + path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') + path = (path2 or path1) if roaming else (path1 or path2) + elif sys.platform.startswith('darwin'): + path = os.path.join(userDir, 'Library', 'Application Support') + # On Linux and as fallback + if not (path and os.path.isdir(path)): + path = userDir + + # Maybe we should store things local to the executable (in case of a + # portable distro or a frozen application that wants to be portable) + prefix = sys.prefix + if getattr(sys, 'frozen', None): + prefix = os.path.abspath(os.path.dirname(sys.executable)) + for reldir in ('settings', '../settings'): + localpath = os.path.abspath(os.path.join(prefix, reldir)) + if os.path.isdir(localpath): # pragma: no cover + try: + open(os.path.join(localpath, 'test.write'), 'wb').close() + os.remove(os.path.join(localpath, 'test.write')) + except IOError: + pass # We cannot write in this directory + else: + path = localpath + break + + # Get path specific for this app + if appname: + if path == userDir: + appname = '.' + appname.lstrip('.') # Make it a hidden directory + path = os.path.join(path, appname) + if not os.path.isdir(path): # pragma: no cover + os.mkdir(path) + + # Done + return path diff --git a/wav2lip/filelists/README.md b/wav2lip/filelists/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e7d7e7bb3b5adefc9fee84168693e978f129c6e6 --- /dev/null +++ b/wav2lip/filelists/README.md @@ -0,0 +1 @@ +Place LRS2 (and any other) filelists here for training. 
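Before moving on to the training code, here is a minimal usage sketch of the batched face_detection API added above. The frame paths are hypothetical; it assumes the code is run from the wav2lip/ directory so the package is importable, and that s3fd.pth is in place (otherwise SFDDetector falls back to downloading the weights via load_url).

import cv2
import numpy as np
import face_detection

# LandmarksType._2D is required by the constructor; only the SFD bounding boxes are used here.
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                        flip_input=False, device='cuda')

# Frames must be same-sized BGR images so np.array() can stack them into one batch.
frames = [cv2.imread(p) for p in ('frame_0000.jpg', 'frame_0001.jpg')]
boxes = detector.get_detections_for_batch(np.array(frames))

for frame, box in zip(frames, boxes):
    if box is None:          # no face detected in this frame
        continue
    x1, y1, x2, y2 = box
    face_crop = frame[y1:y2, x1:x2]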
\ No newline at end of file diff --git a/wav2lip/hparams.py b/wav2lip/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..1c019046279f497e4eae3f839f683bc0b1193c6b --- /dev/null +++ b/wav2lip/hparams.py @@ -0,0 +1,101 @@ +from glob import glob +import os + +def get_image_list(data_root, split): + filelist = [] + + with open('filelists/{}.txt'.format(split)) as f: + for line in f: + line = line.strip() + if ' ' in line: line = line.split()[0] + filelist.append(os.path.join(data_root, line)) + + return filelist + +class HParams: + def __init__(self, **kwargs): + self.data = {} + + for key, value in kwargs.items(): + self.data[key] = value + + def __getattr__(self, key): + if key not in self.data: + raise AttributeError("'HParams' object has no attribute %s" % key) + return self.data[key] + + def set_hparam(self, key, value): + self.data[key] = value + + +# Default hyperparameters +hparams = HParams( + num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality + # network + rescale=True, # Whether to rescale audio prior to preprocessing + rescaling_max=0.9, # Rescaling value + + # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction + # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder + # Does not work if n_ffit is not multiple of hop_size!! + use_lws=False, + + n_fft=800, # Extra window size is filled with 0 paddings to match this parameter + hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) + win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) + + frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) + + # Mel and Linear spectrograms normalization/scaling and clipping + signal_normalization=True, + # Whether to normalize mel spectrograms to some predefined range (following below parameters) + allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True + symmetric_mels=True, + # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, + # faster and cleaner convergence) + max_abs_value=4., + # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not + # be too big to avoid gradient explosion, + # not too small for fast convergence) + # Contribution by @begeekmyfriend + # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude + # levels. Also allows for better G&L phase reconstruction) + preemphasize=True, # whether to apply filter + preemphasis=0.97, # filter coefficient. + + # Limits + min_level_db=-100, + ref_level_db=20, + fmin=55, + # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To + # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmax=7600, # To be increased/reduced depending on data. + + ###################### Our training parameters ################################# + img_size=96, + fps=25, + + batch_size=16, + initial_learning_rate=1e-4, + nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs + num_workers=16, + checkpoint_interval=3000, + eval_interval=3000, + save_optimizer_state=True, + + syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
+ syncnet_batch_size=64, + syncnet_lr=1e-4, + syncnet_eval_interval=10000, + syncnet_checkpoint_interval=10000, + + disc_wt=0.07, + disc_initial_learning_rate=1e-4, +) + + +def hparams_debug_string(): + values = hparams.values() + hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] + return "Hyperparameters:\n" + "\n".join(hp) diff --git a/wav2lip/hq_wav2lip_train.py b/wav2lip/hq_wav2lip_train.py new file mode 100644 index 0000000000000000000000000000000000000000..c384ad9d80c82fd6bff51ea395eeefe93e1e0997 --- /dev/null +++ b/wav2lip/hq_wav2lip_train.py @@ -0,0 +1,443 @@ +from os.path import dirname, join, basename, isfile +from tqdm import tqdm + +from models import SyncNet_color as SyncNet +from models import Wav2Lip, Wav2Lip_disc_qual +import audio + +import torch +from torch import nn +from torch.nn import functional as F +from torch import optim +import torch.backends.cudnn as cudnn +from torch.utils import data as data_utils +import numpy as np + +from glob import glob + +import os, random, cv2, argparse +from hparams import hparams, get_image_list + +parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator') + +parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str) + +parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str) +parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str) + +parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str) +parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str) + +args = parser.parse_args() + + +global_step = 0 +global_epoch = 0 +use_cuda = torch.cuda.is_available() +print('use_cuda: {}'.format(use_cuda)) + +syncnet_T = 5 +syncnet_mel_step_size = 16 + +class Dataset(object): + def __init__(self, split): + self.all_videos = get_image_list(args.data_root, split) + + def get_frame_id(self, frame): + return int(basename(frame).split('.')[0]) + + def get_window(self, start_frame): + start_id = self.get_frame_id(start_frame) + vidname = dirname(start_frame) + + window_fnames = [] + for frame_id in range(start_id, start_id + syncnet_T): + frame = join(vidname, '{}.jpg'.format(frame_id)) + if not isfile(frame): + return None + window_fnames.append(frame) + return window_fnames + + def read_window(self, window_fnames): + if window_fnames is None: return None + window = [] + for fname in window_fnames: + img = cv2.imread(fname) + if img is None: + return None + try: + img = cv2.resize(img, (hparams.img_size, hparams.img_size)) + except Exception as e: + return None + + window.append(img) + + return window + + def crop_audio_window(self, spec, start_frame): + if type(start_frame) == int: + start_frame_num = start_frame + else: + start_frame_num = self.get_frame_id(start_frame) + start_idx = int(80. 
* (start_frame_num / float(hparams.fps))) + + end_idx = start_idx + syncnet_mel_step_size + + return spec[start_idx : end_idx, :] + + def get_segmented_mels(self, spec, start_frame): + mels = [] + assert syncnet_T == 5 + start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing + if start_frame_num - 2 < 0: return None + for i in range(start_frame_num, start_frame_num + syncnet_T): + m = self.crop_audio_window(spec, i - 2) + if m.shape[0] != syncnet_mel_step_size: + return None + mels.append(m.T) + + mels = np.asarray(mels) + + return mels + + def prepare_window(self, window): + # 3 x T x H x W + x = np.asarray(window) / 255. + x = np.transpose(x, (3, 0, 1, 2)) + + return x + + def __len__(self): + return len(self.all_videos) + + def __getitem__(self, idx): + while 1: + idx = random.randint(0, len(self.all_videos) - 1) + vidname = self.all_videos[idx] + img_names = list(glob(join(vidname, '*.jpg'))) + if len(img_names) <= 3 * syncnet_T: + continue + + img_name = random.choice(img_names) + wrong_img_name = random.choice(img_names) + while wrong_img_name == img_name: + wrong_img_name = random.choice(img_names) + + window_fnames = self.get_window(img_name) + wrong_window_fnames = self.get_window(wrong_img_name) + if window_fnames is None or wrong_window_fnames is None: + continue + + window = self.read_window(window_fnames) + if window is None: + continue + + wrong_window = self.read_window(wrong_window_fnames) + if wrong_window is None: + continue + + try: + wavpath = join(vidname, "audio.wav") + wav = audio.load_wav(wavpath, hparams.sample_rate) + + orig_mel = audio.melspectrogram(wav).T + except Exception as e: + continue + + mel = self.crop_audio_window(orig_mel.copy(), img_name) + + if (mel.shape[0] != syncnet_mel_step_size): + continue + + indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name) + if indiv_mels is None: continue + + window = self.prepare_window(window) + y = window.copy() + window[:, :, window.shape[2]//2:] = 0. 
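+            # The lower half of every target frame has just been zeroed out, so the
+            # generator has to inpaint the mouth region from the audio mels alone;
+            # the "wrong" window below (a different segment of the same video) is
+            # concatenated channel-wise as the unmasked pose/identity reference.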
+ + wrong_window = self.prepare_window(wrong_window) + x = np.concatenate([window, wrong_window], axis=0) + + x = torch.FloatTensor(x) + mel = torch.FloatTensor(mel.T).unsqueeze(0) + indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1) + y = torch.FloatTensor(y) + return x, indiv_mels, mel, y + +def save_sample_images(x, g, gt, global_step, checkpoint_dir): + x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) + g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) + gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) + + refs, inps = x[..., 3:], x[..., :3] + folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step)) + if not os.path.exists(folder): os.mkdir(folder) + collage = np.concatenate((refs, inps, g, gt), axis=-2) + for batch_idx, c in enumerate(collage): + for t in range(len(c)): + cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t]) + +logloss = nn.BCELoss() +def cosine_loss(a, v, y): + d = nn.functional.cosine_similarity(a, v) + loss = logloss(d.unsqueeze(1), y) + + return loss + +device = torch.device("cuda" if use_cuda else "cpu") +syncnet = SyncNet().to(device) +for p in syncnet.parameters(): + p.requires_grad = False + +recon_loss = nn.L1Loss() +def get_sync_loss(mel, g): + g = g[:, :, :, g.size(3)//2:] + g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1) + # B, 3 * T, H//2, W + a, v = syncnet(mel, g) + y = torch.ones(g.size(0), 1).float().to(device) + return cosine_loss(a, v, y) + +def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer, + checkpoint_dir=None, checkpoint_interval=None, nepochs=None): + global global_step, global_epoch + resumed_step = global_step + + while global_epoch < nepochs: + print('Starting Epoch: {}'.format(global_epoch)) + running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0. + running_disc_real_loss, running_disc_fake_loss = 0., 0. + prog_bar = tqdm(enumerate(train_data_loader)) + for step, (x, indiv_mels, mel, gt) in prog_bar: + disc.train() + model.train() + + x = x.to(device) + mel = mel.to(device) + indiv_mels = indiv_mels.to(device) + gt = gt.to(device) + + ### Train generator now. Remove ALL grads. + optimizer.zero_grad() + disc_optimizer.zero_grad() + + g = model(indiv_mels, x) + + if hparams.syncnet_wt > 0.: + sync_loss = get_sync_loss(mel, g) + else: + sync_loss = 0. + + if hparams.disc_wt > 0.: + perceptual_loss = disc.perceptual_forward(g) + else: + perceptual_loss = 0. + + l1loss = recon_loss(g, gt) + + loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \ + (1. 
- hparams.syncnet_wt - hparams.disc_wt) * l1loss + + loss.backward() + optimizer.step() + + ### Remove all gradients before Training disc + disc_optimizer.zero_grad() + + pred = disc(gt) + disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device)) + disc_real_loss.backward() + + pred = disc(g.detach()) + disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device)) + disc_fake_loss.backward() + + disc_optimizer.step() + + running_disc_real_loss += disc_real_loss.item() + running_disc_fake_loss += disc_fake_loss.item() + + if global_step % checkpoint_interval == 0: + save_sample_images(x, g, gt, global_step, checkpoint_dir) + + # Logs + global_step += 1 + cur_session_steps = global_step - resumed_step + + running_l1_loss += l1loss.item() + if hparams.syncnet_wt > 0.: + running_sync_loss += sync_loss.item() + else: + running_sync_loss += 0. + + if hparams.disc_wt > 0.: + running_perceptual_loss += perceptual_loss.item() + else: + running_perceptual_loss += 0. + + if global_step == 1 or global_step % checkpoint_interval == 0: + save_checkpoint( + model, optimizer, global_step, checkpoint_dir, global_epoch) + save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_') + + + if global_step % hparams.eval_interval == 0: + with torch.no_grad(): + average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc) + + if average_sync_loss < .75: + hparams.set_hparam('syncnet_wt', 0.03) + + prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1), + running_sync_loss / (step + 1), + running_perceptual_loss / (step + 1), + running_disc_fake_loss / (step + 1), + running_disc_real_loss / (step + 1))) + + global_epoch += 1 + +def eval_model(test_data_loader, global_step, device, model, disc): + eval_steps = 300 + print('Evaluating for {} steps'.format(eval_steps)) + running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], [] + while 1: + for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)): + model.eval() + disc.eval() + + x = x.to(device) + mel = mel.to(device) + indiv_mels = indiv_mels.to(device) + gt = gt.to(device) + + pred = disc(gt) + disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device)) + + g = model(indiv_mels, x) + pred = disc(g) + disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device)) + + running_disc_real_loss.append(disc_real_loss.item()) + running_disc_fake_loss.append(disc_fake_loss.item()) + + sync_loss = get_sync_loss(mel, g) + + if hparams.disc_wt > 0.: + perceptual_loss = disc.perceptual_forward(g) + else: + perceptual_loss = 0. + + l1loss = recon_loss(g, gt) + + loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \ + (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss + + running_l1_loss.append(l1loss.item()) + running_sync_loss.append(sync_loss.item()) + + if hparams.disc_wt > 0.: + running_perceptual_loss.append(perceptual_loss.item()) + else: + running_perceptual_loss.append(0.) 
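+            # Only a fixed number of evaluation batches are run; the averaged sync loss
+            # returned below is what train() compares against 0.75 before raising syncnet_wt.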
+ + if step > eval_steps: break + + print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss), + sum(running_sync_loss) / len(running_sync_loss), + sum(running_perceptual_loss) / len(running_perceptual_loss), + sum(running_disc_fake_loss) / len(running_disc_fake_loss), + sum(running_disc_real_loss) / len(running_disc_real_loss))) + return sum(running_sync_loss) / len(running_sync_loss) + + +def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''): + checkpoint_path = join( + checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step)) + optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None + torch.save({ + "state_dict": model.state_dict(), + "optimizer": optimizer_state, + "global_step": step, + "global_epoch": epoch, + }, checkpoint_path) + print("Saved checkpoint:", checkpoint_path) + +def _load(checkpoint_path): + if use_cuda: + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + + +def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True): + global global_step + global global_epoch + + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + if not reset_optimizer: + optimizer_state = checkpoint["optimizer"] + if optimizer_state is not None: + print("Load optimizer state from {}".format(path)) + optimizer.load_state_dict(checkpoint["optimizer"]) + if overwrite_global_states: + global_step = checkpoint["global_step"] + global_epoch = checkpoint["global_epoch"] + + return model + +if __name__ == "__main__": + checkpoint_dir = args.checkpoint_dir + + # Dataset and Dataloader setup + train_dataset = Dataset('train') + test_dataset = Dataset('val') + + train_data_loader = data_utils.DataLoader( + train_dataset, batch_size=hparams.batch_size, shuffle=True, + num_workers=hparams.num_workers) + + test_data_loader = data_utils.DataLoader( + test_dataset, batch_size=hparams.batch_size, + num_workers=4) + + device = torch.device("cuda" if use_cuda else "cpu") + + # Model + model = Wav2Lip().to(device) + disc = Wav2Lip_disc_qual().to(device) + + print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) + print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad))) + + optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], + lr=hparams.initial_learning_rate, betas=(0.5, 0.999)) + disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad], + lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999)) + + if args.checkpoint_path is not None: + load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False) + + if args.disc_checkpoint_path is not None: + load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer, + reset_optimizer=False, overwrite_global_states=False) + + load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, + overwrite_global_states=False) + + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + + # Train! 
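+    # Example launch (hypothetical paths), assuming the dataset was preprocessed with preprocess.py:
+    #   python hq_wav2lip_train.py --data_root lrs2_preprocessed/ \
+    #       --checkpoint_dir checkpoints/ \
+    #       --syncnet_checkpoint_path checkpoints/lipsync_expert.pth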
+ train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer, + checkpoint_dir=checkpoint_dir, + checkpoint_interval=hparams.checkpoint_interval, + nepochs=hparams.nepochs) diff --git a/wav2lip/inference.py b/wav2lip/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..90692521e95d51c032c9dc0f84901aa4ff33fcfa --- /dev/null +++ b/wav2lip/inference.py @@ -0,0 +1,280 @@ +from os import listdir, path +import numpy as np +import scipy, cv2, os, sys, argparse, audio +import json, subprocess, random, string +from tqdm import tqdm +from glob import glob +import torch, face_detection +from models import Wav2Lip +import platform + +parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models') + +parser.add_argument('--checkpoint_path', type=str, + help='Name of saved checkpoint to load weights from', required=True) + +parser.add_argument('--face', type=str, + help='Filepath of video/image that contains faces to use', required=True) +parser.add_argument('--audio', type=str, + help='Filepath of video/audio file to use as raw audio source', required=True) +parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', + default='results/result_voice.mp4') + +parser.add_argument('--static', type=bool, + help='If True, then use only first video frame for inference', default=False) +parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', + default=25., required=False) + +parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], + help='Padding (top, bottom, left, right). Please adjust to include chin at least') + +parser.add_argument('--face_det_batch_size', type=int, + help='Batch size for face detection', default=16) +parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128) + +parser.add_argument('--resize_factor', default=1, type=int, + help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p') + +parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], + help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' + 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width') + +parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1], + help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' + 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).') + +parser.add_argument('--rotate', default=False, action='store_true', + help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.' 
+ 'Use if you get a flipped result, despite feeding a normal looking video') + +parser.add_argument('--nosmooth', default=False, action='store_true', + help='Prevent smoothing face detections over a short temporal window') + +args = parser.parse_args() +args.img_size = 96 + +if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: + args.static = True + +def get_smoothened_boxes(boxes, T): + for i in range(len(boxes)): + if i + T > len(boxes): + window = boxes[len(boxes) - T:] + else: + window = boxes[i : i + T] + boxes[i] = np.mean(window, axis=0) + return boxes + +def face_detect(images): + detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, + flip_input=False, device=device) + + batch_size = args.face_det_batch_size + + while 1: + predictions = [] + try: + for i in tqdm(range(0, len(images), batch_size)): + predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) + except RuntimeError: + if batch_size == 1: + raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument') + batch_size //= 2 + print('Recovering from OOM error; New batch size: {}'.format(batch_size)) + continue + break + + results = [] + pady1, pady2, padx1, padx2 = args.pads + for rect, image in zip(predictions, images): + if rect is None: + cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. + raise ValueError('Face not detected! Ensure the video contains a face in all the frames.') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + + boxes = np.array(results) + if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5) + results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + del detector + return results + +def datagen(frames, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if args.box[0] == -1: + if not args.static: + face_det_results = face_detect(frames) # BGR2RGB for CNN face detection + else: + face_det_results = face_detect([frames[0]]) + else: + print('Using the specified bounding box instead of face detection...') + y1, y2, x1, x2 = args.box + face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] + + for i, m in enumerate(mels): + idx = 0 if args.static else i%len(frames) + frame_to_save = frames[idx].copy() + face, coords = face_det_results[idx].copy() + + face = cv2.resize(face, (args.img_size, args.img_size)) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= args.wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, args.img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
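+        # Reshape the leftover mel chunks to (N, n_mels, mel_step_size, 1) so they are fed
+        # to the model as single-channel spectrogram "images", mirroring the full-batch path above.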
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + +mel_step_size = 16 +device = 'cuda' if torch.cuda.is_available() else 'cpu' +print('Using {} for inference.'.format(device)) + +def _load(checkpoint_path): + if device == 'cuda': + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + +def load_model(path): + model = Wav2Lip() + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + + model = model.to(device) + return model.eval() + +def main(): + if not os.path.isfile(args.face): + raise ValueError('--face argument must be a valid path to video/image file') + + elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: + full_frames = [cv2.imread(args.face)] + fps = args.fps + + else: + video_stream = cv2.VideoCapture(args.face) + fps = video_stream.get(cv2.CAP_PROP_FPS) + + print('Reading video frames...') + + full_frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + if args.resize_factor > 1: + frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor)) + + if args.rotate: + frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) + + y1, y2, x1, x2 = args.crop + if x2 == -1: x2 = frame.shape[1] + if y2 == -1: y2 = frame.shape[0] + + frame = frame[y1:y2, x1:x2] + + full_frames.append(frame) + + print ("Number of frames available for inference: "+str(len(full_frames))) + + if not args.audio.endswith('.wav'): + print('Extracting raw audio...') + command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav') + + subprocess.call(command, shell=True) + args.audio = 'temp/temp.wav' + + wav = audio.load_wav(args.audio, 16000) + mel = audio.melspectrogram(wav) + print(mel.shape) + + if np.isnan(mel.reshape(-1)).sum() > 0: + raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') + + mel_chunks = [] + mel_idx_multiplier = 80./fps + i = 0 + while 1: + start_idx = int(i * mel_idx_multiplier) + if start_idx + mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + break + mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) + i += 1 + + print("Length of mel chunks: {}".format(len(mel_chunks))) + + full_frames = full_frames[:len(mel_chunks)] + + batch_size = args.wav2lip_batch_size + gen = datagen(full_frames.copy(), mel_chunks) + + for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, + total=int(np.ceil(float(len(mel_chunks))/batch_size)))): + if i == 0: + model = load_model(args.checkpoint_path) + print ("Model loaded") + + frame_h, frame_w = full_frames[0].shape[:-1] + out = cv2.VideoWriter('temp/result.avi', + cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) + + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) + + with torch.no_grad(): + pred = model(mel_batch, img_batch) + + pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 
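+        # Resize each generated 96x96 face back to its detected box and paste it into the
+        # original frame before writing the frame to the temporary output video.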
+ + for p, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) + + f[y1:y2, x1:x2] = p + out.write(f) + + out.release() + + command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile) + subprocess.call(command, shell=platform.system() != 'Windows') + +if __name__ == '__main__': + main() diff --git a/wav2lip/models/__init__.py b/wav2lip/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4374370494b65f10b76c70a2d4f731c238cfa54c --- /dev/null +++ b/wav2lip/models/__init__.py @@ -0,0 +1,2 @@ +from .wav2lip import Wav2Lip, Wav2Lip_disc_qual +from .syncnet import SyncNet_color \ No newline at end of file diff --git a/wav2lip/models/conv.py b/wav2lip/models/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ed83da00cb199e027ef217fd360352d91a7891ff --- /dev/null +++ b/wav2lip/models/conv.py @@ -0,0 +1,44 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + +class nonorm_Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + ) + self.act = nn.LeakyReLU(0.01, inplace=True) + + def forward(self, x): + out = self.conv_block(x) + return self.act(out) + +class Conv2dTranspose(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + + def forward(self, x): + out = self.conv_block(x) + return self.act(out) diff --git a/wav2lip/models/syncnet.py b/wav2lip/models/syncnet.py new file mode 100644 index 0000000000000000000000000000000000000000..e773cdca675236745a379a776b7c07d7d353f590 --- /dev/null +++ b/wav2lip/models/syncnet.py @@ -0,0 +1,66 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .conv import Conv2d + +class SyncNet_color(nn.Module): + def __init__(self): + super(SyncNet_color, self).__init__() + + self.face_encoder = nn.Sequential( + Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), + + Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=2, padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=2, padding=1), + 
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) + face_embedding = self.face_encoder(face_sequences) + audio_embedding = self.audio_encoder(audio_sequences) + + audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) + face_embedding = face_embedding.view(face_embedding.size(0), -1) + + audio_embedding = F.normalize(audio_embedding, p=2, dim=1) + face_embedding = F.normalize(face_embedding, p=2, dim=1) + + + return audio_embedding, face_embedding diff --git a/wav2lip/models/wav2lip.py b/wav2lip/models/wav2lip.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5d6919169ec497f0f0815184f5db8ba9108fbd --- /dev/null +++ b/wav2lip/models/wav2lip.py @@ -0,0 +1,184 @@ +import torch +from torch import nn +from torch.nn import functional as F +import math + +from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d + +class Wav2Lip(nn.Module): + def __init__(self): + super(Wav2Lip, self).__init__() + + self.face_encoder_blocks = nn.ModuleList([ + nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96 + + nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48 + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24 + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12 + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6 + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)), + + nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3 + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), + + nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 + Conv2d(512, 512, 
kernel_size=1, stride=1, padding=0)),]) + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + self.face_decoder_blocks = nn.ModuleList([ + nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),), + + nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3 + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), + + nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6 + + nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12 + + nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24 + + nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48 + + nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96 + + self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), + nn.Sigmoid()) + + def forward(self, audio_sequences, face_sequences): + # audio_sequences = (B, T, 1, 80, 16) + B = audio_sequences.size(0) + + input_dim_size = len(face_sequences.size()) + if input_dim_size > 4: + audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) + face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) + + audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 + + feats = [] + x = face_sequences + for f in self.face_encoder_blocks: + x = f(x) + feats.append(x) + + x = audio_embedding + for f in self.face_decoder_blocks: + x = f(x) + try: + x = torch.cat((x, feats[-1]), dim=1) + except Exception as e: + print(x.size()) + print(feats[-1].size()) + raise e + + feats.pop() + + x = self.output_block(x) + + if input_dim_size > 4: + x = torch.split(x, B, dim=0) # [(B, C, H, W)] + outputs = torch.stack(x, dim=2) # (B, C, 
T, H, W) + + else: + outputs = x + + return outputs + +class Wav2Lip_disc_qual(nn.Module): + def __init__(self): + super(Wav2Lip_disc_qual, self).__init__() + + self.face_encoder_blocks = nn.ModuleList([ + nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96 + + nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48 + nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)), + + nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24 + nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)), + + nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12 + nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)), + + nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6 + nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)), + + nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3 + nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),), + + nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 + nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) + + self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) + self.label_noise = .0 + + def get_lower_half(self, face_sequences): + return face_sequences[:, :, face_sequences.size(2)//2:] + + def to_2d(self, face_sequences): + B = face_sequences.size(0) + face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) + return face_sequences + + def perceptual_forward(self, false_face_sequences): + false_face_sequences = self.to_2d(false_face_sequences) + false_face_sequences = self.get_lower_half(false_face_sequences) + + false_feats = false_face_sequences + for f in self.face_encoder_blocks: + false_feats = f(false_feats) + + false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1), + torch.ones((len(false_feats), 1)).cuda()) + + return false_pred_loss + + def forward(self, face_sequences): + face_sequences = self.to_2d(face_sequences) + face_sequences = self.get_lower_half(face_sequences) + + x = face_sequences + for f in self.face_encoder_blocks: + x = f(x) + + return self.binary_pred(x).view(len(x), -1) diff --git a/wav2lip/preprocess.py b/wav2lip/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5322012ac60e91fefa47338d0e253c3f912ab7f2 --- /dev/null +++ b/wav2lip/preprocess.py @@ -0,0 +1,113 @@ +import sys + +if sys.version_info[0] < 3 and sys.version_info[1] < 2: + raise Exception("Must be using >= Python 3.2") + +from os import listdir, path + +if not path.isfile('face_detection/detection/sfd/s3fd.pth'): + raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth \ + before running this script!') + +import multiprocessing as mp +from concurrent.futures import ThreadPoolExecutor, as_completed +import numpy as np +import argparse, os, cv2, traceback, subprocess +from tqdm import tqdm +from glob import glob +import audio +from hparams import hparams as hp + +import face_detection + +parser = argparse.ArgumentParser() + +parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int) +parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int) +parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True) 
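+# Expected layout: --data_root contains per-speaker folders of .mp4 files (main() globs '*/*.mp4');
+# cropped face frames and extracted audio.wav files are written under --preprocessed_root.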
+parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True) + +args = parser.parse_args() + +fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, + device='cuda:{}'.format(id)) for id in range(args.ngpu)] + +template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}' +# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}' + +def process_video_file(vfile, args, gpu_id): + video_stream = cv2.VideoCapture(vfile) + + frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + frames.append(frame) + + vidname = os.path.basename(vfile).split('.')[0] + dirname = vfile.split('/')[-2] + + fulldir = path.join(args.preprocessed_root, dirname, vidname) + os.makedirs(fulldir, exist_ok=True) + + batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)] + + i = -1 + for fb in batches: + preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb)) + + for j, f in enumerate(preds): + i += 1 + if f is None: + continue + + x1, y1, x2, y2 = f + cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2]) + +def process_audio_file(vfile, args): + vidname = os.path.basename(vfile).split('.')[0] + dirname = vfile.split('/')[-2] + + fulldir = path.join(args.preprocessed_root, dirname, vidname) + os.makedirs(fulldir, exist_ok=True) + + wavpath = path.join(fulldir, 'audio.wav') + + command = template.format(vfile, wavpath) + subprocess.call(command, shell=True) + + +def mp_handler(job): + vfile, args, gpu_id = job + try: + process_video_file(vfile, args, gpu_id) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + +def main(args): + print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu)) + + filelist = glob(path.join(args.data_root, '*/*.mp4')) + + jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)] + p = ThreadPoolExecutor(args.ngpu) + futures = [p.submit(mp_handler, j) for j in jobs] + _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))] + + print('Dumping audios...') + + for vfile in tqdm(filelist): + try: + process_audio_file(vfile, args) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + continue + +if __name__ == '__main__': + main(args) \ No newline at end of file diff --git a/wav2lip/requirements.txt b/wav2lip/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bfd428ab990656291fdd3b8b44070848675b7b07 --- /dev/null +++ b/wav2lip/requirements.txt @@ -0,0 +1,8 @@ +librosa==0.7.0 +numpy==1.17.1 +opencv-contrib-python>=4.2.0.34 +opencv-python==4.1.0.25 +torch==1.1.0 +torchvision==0.3.0 +tqdm==4.45.0 +numba==0.48 diff --git a/wav2lip/results/README.md b/wav2lip/results/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b1bbfd53fded37aefe0f4fc97adc8de343341b7a --- /dev/null +++ b/wav2lip/results/README.md @@ -0,0 +1 @@ +Generated results will be placed in this folder by default. \ No newline at end of file diff --git a/wav2lip/temp/README.md b/wav2lip/temp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..04c910499300fa8dc05c317d7d30cb29f31ff836 --- /dev/null +++ b/wav2lip/temp/README.md @@ -0,0 +1 @@ +Temporary files at the time of inference/testing will be saved here. You can ignore them. 
\ No newline at end of file diff --git a/wav2lip/wav2lip_train.py b/wav2lip/wav2lip_train.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0811808af55464a803be1e268be33f1b8a31a9 --- /dev/null +++ b/wav2lip/wav2lip_train.py @@ -0,0 +1,374 @@ +from os.path import dirname, join, basename, isfile +from tqdm import tqdm + +from models import SyncNet_color as SyncNet +from models import Wav2Lip as Wav2Lip +import audio + +import torch +from torch import nn +from torch import optim +import torch.backends.cudnn as cudnn +from torch.utils import data as data_utils +import numpy as np + +from glob import glob + +import os, random, cv2, argparse +from hparams import hparams, get_image_list + +parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator') + +parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str) + +parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str) +parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str) + +parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str) + +args = parser.parse_args() + + +global_step = 0 +global_epoch = 0 +use_cuda = torch.cuda.is_available() +print('use_cuda: {}'.format(use_cuda)) + +syncnet_T = 5 +syncnet_mel_step_size = 16 + +class Dataset(object): + def __init__(self, split): + self.all_videos = get_image_list(args.data_root, split) + + def get_frame_id(self, frame): + return int(basename(frame).split('.')[0]) + + def get_window(self, start_frame): + start_id = self.get_frame_id(start_frame) + vidname = dirname(start_frame) + + window_fnames = [] + for frame_id in range(start_id, start_id + syncnet_T): + frame = join(vidname, '{}.jpg'.format(frame_id)) + if not isfile(frame): + return None + window_fnames.append(frame) + return window_fnames + + def read_window(self, window_fnames): + if window_fnames is None: return None + window = [] + for fname in window_fnames: + img = cv2.imread(fname) + if img is None: + return None + try: + img = cv2.resize(img, (hparams.img_size, hparams.img_size)) + except Exception as e: + return None + + window.append(img) + + return window + + def crop_audio_window(self, spec, start_frame): + if type(start_frame) == int: + start_frame_num = start_frame + else: + start_frame_num = self.get_frame_id(start_frame) # 0-indexing ---> 1-indexing + start_idx = int(80. * (start_frame_num / float(hparams.fps))) + + end_idx = start_idx + syncnet_mel_step_size + + return spec[start_idx : end_idx, :] + + def get_segmented_mels(self, spec, start_frame): + mels = [] + assert syncnet_T == 5 + start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing + if start_frame_num - 2 < 0: return None + for i in range(start_frame_num, start_frame_num + syncnet_T): + m = self.crop_audio_window(spec, i - 2) + if m.shape[0] != syncnet_mel_step_size: + return None + mels.append(m.T) + + mels = np.asarray(mels) + + return mels + + def prepare_window(self, window): + # 3 x T x H x W + x = np.asarray(window) / 255. 
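+        # The frames arrive as a (T, H, W, 3) stack; the transpose below converts them to
+        # the channels-first (3, T, H, W) layout noted in the comment above.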
+ x = np.transpose(x, (3, 0, 1, 2)) + + return x + + def __len__(self): + return len(self.all_videos) + + def __getitem__(self, idx): + while 1: + idx = random.randint(0, len(self.all_videos) - 1) + vidname = self.all_videos[idx] + img_names = list(glob(join(vidname, '*.jpg'))) + if len(img_names) <= 3 * syncnet_T: + continue + + img_name = random.choice(img_names) + wrong_img_name = random.choice(img_names) + while wrong_img_name == img_name: + wrong_img_name = random.choice(img_names) + + window_fnames = self.get_window(img_name) + wrong_window_fnames = self.get_window(wrong_img_name) + if window_fnames is None or wrong_window_fnames is None: + continue + + window = self.read_window(window_fnames) + if window is None: + continue + + wrong_window = self.read_window(wrong_window_fnames) + if wrong_window is None: + continue + + try: + wavpath = join(vidname, "audio.wav") + wav = audio.load_wav(wavpath, hparams.sample_rate) + + orig_mel = audio.melspectrogram(wav).T + except Exception as e: + continue + + mel = self.crop_audio_window(orig_mel.copy(), img_name) + + if (mel.shape[0] != syncnet_mel_step_size): + continue + + indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name) + if indiv_mels is None: continue + + window = self.prepare_window(window) + y = window.copy() + window[:, :, window.shape[2]//2:] = 0. + + wrong_window = self.prepare_window(wrong_window) + x = np.concatenate([window, wrong_window], axis=0) + + x = torch.FloatTensor(x) + mel = torch.FloatTensor(mel.T).unsqueeze(0) + indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1) + y = torch.FloatTensor(y) + return x, indiv_mels, mel, y + +def save_sample_images(x, g, gt, global_step, checkpoint_dir): + x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) + g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) + gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) + + refs, inps = x[..., 3:], x[..., :3] + folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step)) + if not os.path.exists(folder): os.mkdir(folder) + collage = np.concatenate((refs, inps, g, gt), axis=-2) + for batch_idx, c in enumerate(collage): + for t in range(len(c)): + cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t]) + +logloss = nn.BCELoss() +def cosine_loss(a, v, y): + d = nn.functional.cosine_similarity(a, v) + loss = logloss(d.unsqueeze(1), y) + + return loss + +device = torch.device("cuda" if use_cuda else "cpu") +syncnet = SyncNet().to(device) +for p in syncnet.parameters(): + p.requires_grad = False + +recon_loss = nn.L1Loss() +def get_sync_loss(mel, g): + g = g[:, :, :, g.size(3)//2:] + g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1) + # B, 3 * T, H//2, W + a, v = syncnet(mel, g) + y = torch.ones(g.size(0), 1).float().to(device) + return cosine_loss(a, v, y) + +def train(device, model, train_data_loader, test_data_loader, optimizer, + checkpoint_dir=None, checkpoint_interval=None, nepochs=None): + + global global_step, global_epoch + resumed_step = global_step + + while global_epoch < nepochs: + print('Starting Epoch: {}'.format(global_epoch)) + running_sync_loss, running_l1_loss = 0., 0. 
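+        # When hparams.syncnet_wt is 0 the loss below reduces to pure L1 reconstruction;
+        # eval_model's gating further down raises it to 0.01 once the averaged sync loss
+        # drops under 0.75.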
+ prog_bar = tqdm(enumerate(train_data_loader)) + for step, (x, indiv_mels, mel, gt) in prog_bar: + model.train() + optimizer.zero_grad() + + # Move data to CUDA device + x = x.to(device) + mel = mel.to(device) + indiv_mels = indiv_mels.to(device) + gt = gt.to(device) + + g = model(indiv_mels, x) + + if hparams.syncnet_wt > 0.: + sync_loss = get_sync_loss(mel, g) + else: + sync_loss = 0. + + l1loss = recon_loss(g, gt) + + loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss + loss.backward() + optimizer.step() + + if global_step % checkpoint_interval == 0: + save_sample_images(x, g, gt, global_step, checkpoint_dir) + + global_step += 1 + cur_session_steps = global_step - resumed_step + + running_l1_loss += l1loss.item() + if hparams.syncnet_wt > 0.: + running_sync_loss += sync_loss.item() + else: + running_sync_loss += 0. + + if global_step == 1 or global_step % checkpoint_interval == 0: + save_checkpoint( + model, optimizer, global_step, checkpoint_dir, global_epoch) + + if global_step == 1 or global_step % hparams.eval_interval == 0: + with torch.no_grad(): + average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir) + + if average_sync_loss < .75: + hparams.set_hparam('syncnet_wt', 0.01) # without image GAN a lesser weight is sufficient + + prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1), + running_sync_loss / (step + 1))) + + global_epoch += 1 + + +def eval_model(test_data_loader, global_step, device, model, checkpoint_dir): + eval_steps = 700 + print('Evaluating for {} steps'.format(eval_steps)) + sync_losses, recon_losses = [], [] + step = 0 + while 1: + for x, indiv_mels, mel, gt in test_data_loader: + step += 1 + model.eval() + + # Move data to CUDA device + x = x.to(device) + gt = gt.to(device) + indiv_mels = indiv_mels.to(device) + mel = mel.to(device) + + g = model(indiv_mels, x) + + sync_loss = get_sync_loss(mel, g) + l1loss = recon_loss(g, gt) + + sync_losses.append(sync_loss.item()) + recon_losses.append(l1loss.item()) + + if step > eval_steps: + averaged_sync_loss = sum(sync_losses) / len(sync_losses) + averaged_recon_loss = sum(recon_losses) / len(recon_losses) + + print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss)) + + return averaged_sync_loss + +def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch): + + checkpoint_path = join( + checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step)) + optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None + torch.save({ + "state_dict": model.state_dict(), + "optimizer": optimizer_state, + "global_step": step, + "global_epoch": epoch, + }, checkpoint_path) + print("Saved checkpoint:", checkpoint_path) + + +def _load(checkpoint_path): + if use_cuda: + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + +def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True): + global global_step + global global_epoch + + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] + new_s = {} + for k, v in s.items(): + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s) + if not reset_optimizer: + optimizer_state = checkpoint["optimizer"] + if optimizer_state is not None: + print("Load optimizer state from {}".format(path)) + optimizer.load_state_dict(checkpoint["optimizer"]) + if 
overwrite_global_states: + global_step = checkpoint["global_step"] + global_epoch = checkpoint["global_epoch"] + + return model + +if __name__ == "__main__": + checkpoint_dir = args.checkpoint_dir + + # Dataset and Dataloader setup + train_dataset = Dataset('train') + test_dataset = Dataset('val') + + train_data_loader = data_utils.DataLoader( + train_dataset, batch_size=hparams.batch_size, shuffle=True, + num_workers=hparams.num_workers) + + test_data_loader = data_utils.DataLoader( + test_dataset, batch_size=hparams.batch_size, + num_workers=4) + + device = torch.device("cuda" if use_cuda else "cpu") + + # Model + model = Wav2Lip().to(device) + print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) + + optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], + lr=hparams.initial_learning_rate) + + if args.checkpoint_path is not None: + load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False) + + load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False) + + if not os.path.exists(checkpoint_dir): + os.mkdir(checkpoint_dir) + + # Train! + train(device, model, train_data_loader, test_data_loader, optimizer, + checkpoint_dir=checkpoint_dir, + checkpoint_interval=hparams.checkpoint_interval, + nepochs=hparams.nepochs)
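+    # Example launch for this no-discriminator variant (hypothetical paths):
+    #   python wav2lip_train.py --data_root lrs2_preprocessed/ \
+    #       --checkpoint_dir checkpoints/ --syncnet_checkpoint_path checkpoints/lipsync_expert.pth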