jhj0517 committed on
Commit
52a96bb
2 Parent(s): 2bf87b3 c387a24

Merge pull request #15 from jhj0517/feature/image-restoration

README.md CHANGED
@@ -52,7 +52,10 @@ docker compose -f docker/docker-compose.yaml up
52
 
53
  Update the [`docker-compose.yaml`](https://github.com/jhj0517/AdvancedLivePortrait-WebUI/blob/master/docker/docker-compose.yaml) to match your environment if you're not using an Nvidia GPU.
54
 
55
- ## ❤️ Citation and Thanks
 
 
 
56
 1. The LivePortrait paper comes from:
57
  ```bibtex
58
  @article{guo2024liveportrait,
@@ -65,8 +68,6 @@ Update the [`docker-compose.yaml`](https://github.com/jhj0517/AdvancedLivePortra
65
 2. The models are safetensors converted by kijai: https://github.com/kijai/ComfyUI-LivePortraitKJ
66
  3. [ultralytics](https://github.com/ultralytics/ultralytics) is used to detect the face.
67
 4. This WebUI started from [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait); various facial expressions such as AAA, EEE, Eyebrow, and Wink were created by PowerHouseMan.
68
-
69
- ### 🌐 Translation
70
- Any PRs adding language translations to [`translation.yaml`](https://github.com/jhj0517/AdvancedLivePortrait-WebUI/blob/master/i18n/translation.yaml) would be greatly appreciated!
71
 
72
 
 
52
 
53
  Update the [`docker-compose.yaml`](https://github.com/jhj0517/AdvancedLivePortrait-WebUI/blob/master/docker/docker-compose.yaml) to match your environment if you're not using an Nvidia GPU.
54
 
55
+ ### 🌐 Translation
56
+ Any PRs adding language translations to [`translation.yaml`](https://github.com/jhj0517/AdvancedLivePortrait-WebUI/blob/master/i18n/translation.yaml) would be greatly appreciated!
57
+
58
+ ## ❤️ Acknowledgement
59
 1. The LivePortrait paper comes from:
60
  ```bibtex
61
  @article{guo2024liveportrait,
 
68
 2. The models are safetensors converted by kijai: https://github.com/kijai/ComfyUI-LivePortraitKJ
69
  3. [ultralytics](https://github.com/ultralytics/ultralytics) is used to detect the face.
70
 4. This WebUI started from [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait); various facial expressions such as AAA, EEE, Eyebrow, and Wink were created by PowerHouseMan.
71
+ 5. [RealESRGAN](https://github.com/xinntao/Real-ESRGAN) is used for image restoration.
 
 
72
 
73
 
app.py CHANGED
@@ -41,7 +41,9 @@ class App:
41
  gr.Slider(label=_("Sample Ratio"), minimum=-0.2, maximum=1.2, step=0.01, value=1, visible=False),
42
  gr.Dropdown(label=_("Sample Parts"), visible=False,
43
  choices=[part.value for part in SamplePart], value=SamplePart.ALL.value),
44
- gr.Slider(label=_("Face Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=2)
 
 
45
  ]
46
 
47
  @staticmethod
@@ -53,6 +55,8 @@ class App:
53
  gr.Slider(label=_("First frame eyes alignment factor"), minimum=0, maximum=1, step=0.01, value=1),
54
  gr.Slider(label=_("First frame mouth alignment factor"), minimum=0, maximum=1, step=0.01, value=1),
55
  gr.Slider(label=_("Face Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=2),
 
 
56
  ]
57
 
58
  def launch(self):
 
41
  gr.Slider(label=_("Sample Ratio"), minimum=-0.2, maximum=1.2, step=0.01, value=1, visible=False),
42
  gr.Dropdown(label=_("Sample Parts"), visible=False,
43
  choices=[part.value for part in SamplePart], value=SamplePart.ALL.value),
44
+ gr.Slider(label=_("Face Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=2),
45
+ gr.Checkbox(label=_("Enable Image Restoration"),
46
+ info=_("This enables image restoration with RealESRGAN but slows down the speed"), value=False)
47
  ]
48
 
49
  @staticmethod
 
55
  gr.Slider(label=_("First frame eyes alignment factor"), minimum=0, maximum=1, step=0.01, value=1),
56
  gr.Slider(label=_("First frame mouth alignment factor"), minimum=0, maximum=1, step=0.01, value=1),
57
  gr.Slider(label=_("Face Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=2),
58
+ gr.Checkbox(label=_("Enable Image Restoration"),
59
+ info=_("This enables image restoration with RealESRGAN but slows down the speed"), value=False)
60
  ]
61
 
62
  def launch(self):
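
The new `gr.Checkbox` is appended to each component list, so its boolean value can be forwarded into the inferencer's new `enable_image_restoration` parameter. Below is a minimal, hypothetical sketch of that kind of wiring; the component and function names are illustrative, not the app's actual ones:

```python
# Hypothetical sketch of forwarding a checkbox value to an `enable_image_restoration`
# keyword; names here are illustrative and not the project's real wiring.
import gradio as gr

def run(src_image_path, enable_image_restoration=False):
    # Stand-in for the inferencer call that now accepts enable_image_restoration.
    return f"restoration {'on' if enable_image_restoration else 'off'} for {src_image_path}"

with gr.Blocks() as demo:
    img = gr.Textbox(label="Source image path")
    restore = gr.Checkbox(label="Enable Image Restoration", value=False)
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(fn=run, inputs=[img, restore], outputs=out)
```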
i18n/translation.yaml CHANGED
@@ -32,6 +32,8 @@ en: # English
32
  First frame mouth alignment factor: First frame mouth alignment factor
33
  First frame eyes alignment factor: First frame eyes alignment factor
34
  Face Crop Factor: Face Crop Factor
 
 
35
 
36
  ko: # Korean
37
  Language: 언어
@@ -67,6 +69,8 @@ ko: # Korean
67
  First frame mouth alignment factor: 첫 프레임 입 반영 비율
68
  First frame eyes alignment factor: 첫 프레임 눈 반영 비율
69
  Face Crop Factor: 얼굴 크롭 비율
 
 
70
 
71
  ja: # Japanese
72
  Language: 言語
@@ -102,6 +106,8 @@ ja: # Japanese
102
  First frame mouth alignment factor: First frame mouth alignment factor
103
  First frame eyes alignment factor: First frame eyes alignment factor
104
  Face Crop Factor: Face Crop Factor
 
 
105
 
106
  es: # Spanish
107
  Language: Idioma
@@ -137,6 +143,8 @@ es: # Spanish
137
  First frame mouth alignment factor: First frame mouth alignment factor
138
  First frame eyes alignment factor: First frame eyes alignment factor
139
  Face Crop Factor: Face Crop Factor
 
 
140
 
141
  fr: # French
142
  Language: Langue
@@ -172,6 +180,8 @@ fr: # French
172
  First frame mouth alignment factor: First frame mouth alignment factor
173
  First frame eyes alignment factor: First frame eyes alignment factor
174
  Face Crop Factor: Face Crop Factor
 
 
175
 
176
  de: # German
177
  Language: Sprache
@@ -207,6 +217,8 @@ de: # German
207
  First frame mouth alignment factor: First frame mouth alignment factor
208
  First frame eyes alignment factor: First frame eyes alignment factor
209
  Face Crop Factor: Face Crop Factor
 
 
210
 
211
  zh: # Chinese
212
  Language: 语言
@@ -242,6 +254,8 @@ zh: # Chinese
242
  First frame mouth alignment factor: First frame mouth alignment factor
243
  First frame eyes alignment factor: First frame eyes alignment factor
244
  Face Crop Factor: Face Crop Factor
 
 
245
 
246
  uk: # Ukrainian
247
  Language: Мова
@@ -277,6 +291,8 @@ uk: # Ukrainian
277
  First frame mouth alignment factor: First frame mouth alignment factor
278
  First frame eyes alignment factor: First frame eyes alignment factor
279
  Face Crop Factor: Face Crop Factor
 
 
280
 
281
  ru: # Russian
282
  Language: Язык
@@ -312,6 +328,8 @@ ru: # Russian
312
  First frame mouth alignment factor: First frame mouth alignment factor
313
  First frame eyes alignment factor: First frame eyes alignment factor
314
  Face Crop Factor: Face Crop Factor
 
 
315
 
316
  tr: # Turkish
317
  Language: Dil
@@ -347,3 +365,5 @@ tr: # Turkish
347
  First frame mouth alignment factor: First frame mouth alignment factor
348
  First frame eyes alignment factor: First frame eyes alignment factor
349
  Face Crop Factor: Face Crop Factor
 
 
 
32
  First frame mouth alignment factor: First frame mouth alignment factor
33
  First frame eyes alignment factor: First frame eyes alignment factor
34
  Face Crop Factor: Face Crop Factor
35
+ Enable Image Restoration: Enable Image Restoration
36
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
37
 
38
  ko: # Korean
39
  Language: 언어
 
69
  First frame mouth alignment factor: 첫 프레임 입 반영 비율
70
  First frame eyes alignment factor: 첫 프레임 눈 반영 비율
71
  Face Crop Factor: 얼굴 크롭 비율
72
+ Enable Image Restoration: 화질 향상
73
+ This enables image restoration with RealESRGAN but slows down the speed: RealESRGAN 으로 화질을 향상 시킵니다. 속도는 느려집니다.
74
 
75
  ja: # Japanese
76
  Language: 言語
 
106
  First frame mouth alignment factor: First frame mouth alignment factor
107
  First frame eyes alignment factor: First frame eyes alignment factor
108
  Face Crop Factor: Face Crop Factor
109
+ Enable Image Restoration: Enable Image Restoration
110
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
111
 
112
  es: # Spanish
113
  Language: Idioma
 
143
  First frame mouth alignment factor: First frame mouth alignment factor
144
  First frame eyes alignment factor: First frame eyes alignment factor
145
  Face Crop Factor: Face Crop Factor
146
+ Enable Image Restoration: Enable Image Restoration
147
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
148
 
149
  fr: # French
150
  Language: Langue
 
180
  First frame mouth alignment factor: First frame mouth alignment factor
181
  First frame eyes alignment factor: First frame eyes alignment factor
182
  Face Crop Factor: Face Crop Factor
183
+ Enable Image Restoration: Enable Image Restoration
184
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
185
 
186
  de: # German
187
  Language: Sprache
 
217
  First frame mouth alignment factor: First frame mouth alignment factor
218
  First frame eyes alignment factor: First frame eyes alignment factor
219
  Face Crop Factor: Face Crop Factor
220
+ Enable Image Restoration: Enable Image Restoration
221
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
222
 
223
  zh: # Chinese
224
  Language: 语言
 
254
  First frame mouth alignment factor: First frame mouth alignment factor
255
  First frame eyes alignment factor: First frame eyes alignment factor
256
  Face Crop Factor: Face Crop Factor
257
+ Enable Image Restoration: Enable Image Restoration
258
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
259
 
260
  uk: # Ukrainian
261
  Language: Мова
 
291
  First frame mouth alignment factor: First frame mouth alignment factor
292
  First frame eyes alignment factor: First frame eyes alignment factor
293
  Face Crop Factor: Face Crop Factor
294
+ Enable Image Restoration: Enable Image Restoration
295
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
296
 
297
  ru: # Russian
298
  Language: Язык
 
328
  First frame mouth alignment factor: First frame mouth alignment factor
329
  First frame eyes alignment factor: First frame eyes alignment factor
330
  Face Crop Factor: Face Crop Factor
331
+ Enable Image Restoration: Enable Image Restoration
332
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
333
 
334
  tr: # Turkish
335
  Language: Dil
 
365
  First frame mouth alignment factor: First frame mouth alignment factor
366
  First frame eyes alignment factor: First frame eyes alignment factor
367
  Face Crop Factor: Face Crop Factor
368
+ Enable Image Restoration: Enable Image Restoration
369
+ This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
modules/image_restoration/__init__.py ADDED
File without changes
modules/image_restoration/real_esrgan/__init__.py ADDED
File without changes
modules/image_restoration/real_esrgan/model_downloader.py ADDED
@@ -0,0 +1,15 @@
1
+ from modules.live_portrait.model_downloader import download_model
2
+
3
+ MODELS_REALESRGAN_URL = {
4
+ "realesr-general-x4v3": "https://huggingface.co/jhj0517/realesr-general-x4v3/resolve/main/realesr-general-x4v3.pth",
5
+ "RealESRGAN_x2": "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth",
6
+ }
7
+
8
+ MODELS_REALESRGAN_SCALABILITY = {
9
+ "realesr-general-x4v3": [1, 2, 4],
10
+ "RealESRGAN_x2": [2]
11
+ }
12
+
13
+
14
+ def download_resrgan_model(file_path, url):
15
+ return download_model(file_path, url)
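
For reference, a hedged sketch of how the URL and scalability tables are typically used together with the downloader, mirroring what the inferencer's `load_model` does. The local `models/RealESRGAN` layout matches `MODELS_REAL_ESRGAN_DIR` from `paths.py`, but the working directory is an assumption:

```python
# Hedged sketch: resolve a model path, check the requested scale, and download on demand.
import os
from modules.image_restoration.real_esrgan.model_downloader import (
    MODELS_REALESRGAN_URL, MODELS_REALESRGAN_SCALABILITY, download_resrgan_model,
)

model_name = "realesr-general-x4v3"
requested_scale = 2
assert requested_scale in MODELS_REALESRGAN_SCALABILITY[model_name]

model_path = os.path.join("models", "RealESRGAN", model_name + ".pth")  # assumed layout
if not os.path.exists(model_path):
    download_resrgan_model(model_path, MODELS_REALESRGAN_URL[model_name])
```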
modules/image_restoration/real_esrgan/real_esrgan_inferencer.py ADDED
@@ -0,0 +1,120 @@
1
+ import os.path
2
+ import gradio as gr
3
+ import torch
4
+ import cv2
5
+ from typing import Optional, Literal
6
+
7
+ from modules.utils.paths import *
8
+ from modules.utils.image_helper import save_image
9
+ from .model_downloader import download_resrgan_model, MODELS_REALESRGAN_URL, MODELS_REALESRGAN_SCALABILITY
10
+ from .wrapper.rrdb_net import RRDBNet
11
+ from .wrapper.real_esrganer import RealESRGANer
12
+ from .wrapper.srvgg_net_compact import SRVGGNetCompact
13
+
14
+
15
+ class RealESRGANInferencer:
16
+ def __init__(self,
17
+ model_dir: str = MODELS_REAL_ESRGAN_DIR,
18
+ output_dir: str = OUTPUTS_DIR):
19
+ self.model_dir = model_dir
20
+ self.output_dir = output_dir
21
+ self.device = self.get_device()
22
+ self.arc = None
23
+ self.model = None
24
+ self.face_enhancer = None
25
+
26
+ self.available_models = list(MODELS_REALESRGAN_URL.keys())
27
+ self.default_model = self.available_models[0]
28
+ self.model_config = {
29
+ "model_name": self.default_model,
30
+ "scale": 1,
31
+ "half_precision": True
32
+ }
33
+
34
+ def load_model(self,
35
+ model_name: Optional[str] = None,
36
+ scale: Literal[1, 2, 4] = 1,
37
+ half_precision: bool = True,
38
+ progress: gr.Progress = gr.Progress()):
39
+ model_config = {
40
+ "model_name": model_name,
41
+ "scale": scale,
42
+ "half_precision": half_precision
43
+ }
44
+ if model_config == self.model_config and self.model is not None:
45
+ return
46
+ else:
47
+ self.model_config = model_config
48
+
49
+ if model_name is None:
50
+ model_name = self.default_model
51
+
52
+ model_path = os.path.join(self.model_dir, model_name)
53
+ if not model_name.endswith(".pth"):
54
+ model_path += ".pth"
55
+
56
+ if not os.path.exists(model_path):
57
+ progress(0, f"Downloading RealESRGAN model to : {model_path}")
58
+ download_resrgan_model(model_path, MODELS_REALESRGAN_URL[model_name])
59
+
60
+ name, ext = os.path.splitext(model_name)
61
+ assert scale in MODELS_REALESRGAN_SCALABILITY[name]
62
+ if name == 'RealESRGAN_x2': # x2 RRDBNet model
63
+ arc = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
64
+ netscale = 2
65
+ else: # x4 VGG-style model (S size) : "realesr-general-x4v3"
66
+ arc = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
67
+ netscale = 4
68
+
69
+ self.model = RealESRGANer(
70
+ scale=netscale,
71
+ model_path=model_path,
72
+ model=arc,
73
+ half=half_precision,
74
+ device=torch.device(self.get_device())
75
+ )
76
+
77
+ def restore_image(self,
78
+ img_path: str,
79
+ model_name: Optional[str] = None,
80
+ scale: int = 1,
81
+ half_precision: Optional[bool] = None,
82
+ overwrite: bool = True):
83
+ model_config = {
84
+ "model_name": self.model_config["model_name"],
85
+ "scale": scale,
86
+ "half_precision": half_precision
87
+ }
88
+ if half_precision is None:
+ half_precision = self.device == "cuda"
89
+
90
+ if self.model is None or self.model_config != model_config:
91
+ self.load_model(
92
+ model_name=self.default_model if model_name is None else model_name,
93
+ scale=scale,
94
+ half_precision=half_precision
95
+ )
96
+
97
+ try:
98
+ with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
99
+ output, img_mode = self.model.enhance(img_path, outscale=scale)
100
+ if img_mode == "RGB":
101
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
102
+
103
+ if overwrite:
104
+ output_path = img_path
105
+ else:
106
+ output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
107
+
108
+ output_path = save_image(output, output_path=output_path)
109
+ return output_path
110
+ except Exception as e:
111
+ raise
112
+
113
+ @staticmethod
114
+ def get_device():
115
+ if torch.cuda.is_available():
116
+ return "cuda"
117
+ elif torch.backends.mps.is_available():
118
+ return "mps"
119
+ else:
120
+ return "cpu"
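
A minimal usage sketch of the new inferencer, assuming a local image path: `restore_image` resolves the default model, downloads it on first use, and returns the saved output path.

```python
# Minimal usage sketch of RealESRGANInferencer; the input path is an assumption.
from modules.image_restoration.real_esrgan.real_esrgan_inferencer import RealESRGANInferencer

inferencer = RealESRGANInferencer()        # defaults: models/RealESRGAN, outputs/
restored_path = inferencer.restore_image(
    "outputs/example.png",                 # hypothetical source image
    scale=1,                               # keep the original resolution
    overwrite=False,                       # write a new file instead of replacing the input
)
print(restored_path)
```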
modules/image_restoration/real_esrgan/wrapper/__init__.py ADDED
File without changes
modules/image_restoration/real_esrgan/wrapper/real_esrganer.py ADDED
@@ -0,0 +1,303 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import queue
5
+ import threading
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+
10
+ class RealESRGANer():
11
+ """A helper class for upsampling images with RealESRGAN.
12
+
13
+ Args:
14
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
15
+ model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
16
+ model (nn.Module): The defined network. Default: None.
17
+ tile (int): Because very large images can run out of GPU memory, this tile option first crops
18
+ input images into tiles and then processes each of them. Finally, they are merged into one image.
19
+ 0 disables tiling. Default: 0.
20
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
21
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
22
+ half (bool): Whether to use half precision during inference. Default: False.
23
+ """
24
+
25
+ def __init__(self,
26
+ scale,
27
+ model_path,
28
+ dni_weight=None,
29
+ model=None,
30
+ tile=0,
31
+ tile_pad=10,
32
+ pre_pad=10,
33
+ half=False,
34
+ device=None,
35
+ gpu_id=None):
36
+ self.scale = scale
37
+ self.tile_size = tile
38
+ self.tile_pad = tile_pad
39
+ self.pre_pad = pre_pad
40
+ self.mod_scale = None
41
+ self.half = half
42
+
43
+ # initialize model
44
+ if gpu_id:
45
+ self.device = torch.device(
46
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
47
+ else:
48
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
49
+
50
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
51
+
52
+ # prefer to use params_ema
53
+ if 'params_ema' in loadnet:
54
+ keyname = 'params_ema'
55
+ else:
56
+ keyname = 'params'
57
+ model.load_state_dict(loadnet[keyname], strict=True)
58
+
59
+ model.eval()
60
+ self.model = model.to(self.device)
61
+ if self.half:
62
+ self.model = self.model.half()
63
+
64
+ def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
65
+ """Deep network interpolation.
66
+
67
+ ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
68
+ """
69
+ net_a = torch.load(net_a, map_location=torch.device(loc))
70
+ net_b = torch.load(net_b, map_location=torch.device(loc))
71
+ for k, v_a in net_a[key].items():
72
+ net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
73
+ return net_a
74
+
75
+ def pre_process(self, img):
76
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
77
+ """
78
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
79
+ self.img = img.unsqueeze(0).to(self.device)
80
+ if self.half:
81
+ self.img = self.img.half()
82
+
83
+ # pre_pad
84
+ if self.pre_pad != 0:
85
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
86
+ # mod pad for divisible borders
87
+ if self.scale == 2:
88
+ self.mod_scale = 2
89
+ elif self.scale == 1:
90
+ self.mod_scale = 4
91
+ if self.mod_scale is not None:
92
+ self.mod_pad_h, self.mod_pad_w = 0, 0
93
+ _, _, h, w = self.img.size()
94
+ if (h % self.mod_scale != 0):
95
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
96
+ if (w % self.mod_scale != 0):
97
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
98
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
99
+
100
+ def process(self):
101
+ # model inference
102
+ self.output = self.model(self.img)
103
+
104
+ def tile_process(self):
105
+ """It will first crop input images to tiles, and then process each tile.
106
+ Finally, all the processed tiles are merged into one image.
107
+
108
+ Modified from: https://github.com/ata4/esrgan-launcher
109
+ """
110
+ batch, channel, height, width = self.img.shape
111
+ output_height = height * self.scale
112
+ output_width = width * self.scale
113
+ output_shape = (batch, channel, output_height, output_width)
114
+
115
+ # start with black image
116
+ self.output = self.img.new_zeros(output_shape)
117
+ tiles_x = math.ceil(width / self.tile_size)
118
+ tiles_y = math.ceil(height / self.tile_size)
119
+
120
+ # loop over all tiles
121
+ for y in range(tiles_y):
122
+ for x in range(tiles_x):
123
+ # extract tile from input image
124
+ ofs_x = x * self.tile_size
125
+ ofs_y = y * self.tile_size
126
+ # input tile area on total image
127
+ input_start_x = ofs_x
128
+ input_end_x = min(ofs_x + self.tile_size, width)
129
+ input_start_y = ofs_y
130
+ input_end_y = min(ofs_y + self.tile_size, height)
131
+
132
+ # input tile area on total image with padding
133
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
134
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
135
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
136
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
137
+
138
+ # input tile dimensions
139
+ input_tile_width = input_end_x - input_start_x
140
+ input_tile_height = input_end_y - input_start_y
141
+ tile_idx = y * tiles_x + x + 1
142
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
143
+
144
+ # upscale tile
145
+ try:
146
+ with torch.no_grad():
147
+ output_tile = self.model(input_tile)
148
+ except RuntimeError as error:
149
+ print('Error', error)
150
+ print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
151
+
152
+ # output tile area on total image
153
+ output_start_x = input_start_x * self.scale
154
+ output_end_x = input_end_x * self.scale
155
+ output_start_y = input_start_y * self.scale
156
+ output_end_y = input_end_y * self.scale
157
+
158
+ # output tile area without padding
159
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
160
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
161
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
162
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
163
+
164
+ # put tile into output image
165
+ self.output[:, :, output_start_y:output_end_y,
166
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
167
+ output_start_x_tile:output_end_x_tile]
168
+
169
+ def post_process(self):
170
+ # remove extra pad
171
+ if self.mod_scale is not None:
172
+ _, _, h, w = self.output.size()
173
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
174
+ # remove prepad
175
+ if self.pre_pad != 0:
176
+ _, _, h, w = self.output.size()
177
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
178
+ return self.output
179
+
180
+ @torch.no_grad()
181
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
182
+ if isinstance(img, str):
183
+ img = cv2.imread(img)
184
+
185
+ h_input, w_input = img.shape[0:2]
186
+ # img: numpy
187
+ img = img.astype(np.float32)
188
+ if np.max(img) > 256: # 16-bit image
189
+ max_range = 65535
190
+ print('\tInput is a 16-bit image')
191
+ else:
192
+ max_range = 255
193
+ img = img / max_range
194
+ if len(img.shape) == 2: # gray image
195
+ img_mode = 'L'
196
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
197
+ elif img.shape[2] == 4: # RGBA image with alpha channel
198
+ img_mode = 'RGBA'
199
+ alpha = img[:, :, 3]
200
+ img = img[:, :, 0:3]
201
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
202
+ if alpha_upsampler == 'realesrgan':
203
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
204
+ else:
205
+ img_mode = 'RGB'
206
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
207
+
208
+ # ------------------- process image (without the alpha channel) ------------------- #
209
+ self.pre_process(img)
210
+ if self.tile_size > 0:
211
+ self.tile_process()
212
+ else:
213
+ self.process()
214
+ output_img = self.post_process()
215
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
216
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
217
+ if img_mode == 'L':
218
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
219
+
220
+ # ------------------- process the alpha channel if necessary ------------------- #
221
+ if img_mode == 'RGBA':
222
+ if alpha_upsampler == 'realesrgan':
223
+ self.pre_process(alpha)
224
+ if self.tile_size > 0:
225
+ self.tile_process()
226
+ else:
227
+ self.process()
228
+ output_alpha = self.post_process()
229
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
230
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
231
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
232
+ else: # use the cv2 resize for alpha channel
233
+ h, w = alpha.shape[0:2]
234
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
235
+
236
+ # merge the alpha channel
237
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
238
+ output_img[:, :, 3] = output_alpha
239
+
240
+ # ------------------------------ return ------------------------------ #
241
+ if max_range == 65535: # 16-bit image
242
+ output = (output_img * 65535.0).round().astype(np.uint16)
243
+ else:
244
+ output = (output_img * 255.0).round().astype(np.uint8)
245
+
246
+ if outscale is not None and outscale != float(self.scale):
247
+ output = cv2.resize(
248
+ output, (
249
+ int(w_input * outscale),
250
+ int(h_input * outscale),
251
+ ), interpolation=cv2.INTER_LANCZOS4)
252
+
253
+ return output, img_mode
254
+
255
+
256
+ class PrefetchReader(threading.Thread):
257
+ """Prefetch images.
258
+
259
+ Args:
260
+ img_list (list[str]): A list of image paths to be read.
261
+ num_prefetch_queue (int): Number of prefetch queue.
262
+ """
263
+
264
+ def __init__(self, img_list, num_prefetch_queue):
265
+ super().__init__()
266
+ self.que = queue.Queue(num_prefetch_queue)
267
+ self.img_list = img_list
268
+
269
+ def run(self):
270
+ for img_path in self.img_list:
271
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
272
+ self.que.put(img)
273
+
274
+ self.que.put(None)
275
+
276
+ def __next__(self):
277
+ next_item = self.que.get()
278
+ if next_item is None:
279
+ raise StopIteration
280
+ return next_item
281
+
282
+ def __iter__(self):
283
+ return self
284
+
285
+
286
+ class IOConsumer(threading.Thread):
287
+
288
+ def __init__(self, opt, que, qid):
289
+ super().__init__()
290
+ self._queue = que
291
+ self.qid = qid
292
+ self.opt = opt
293
+
294
+ def run(self):
295
+ while True:
296
+ msg = self._queue.get()
297
+ if isinstance(msg, str) and msg == 'quit':
298
+ break
299
+
300
+ output = msg['output']
301
+ save_path = msg['save_path']
302
+ cv2.imwrite(save_path, output)
303
+ print(f'IO worker {self.qid} is done.')
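
`RealESRGANer` can also be driven directly, which is essentially what the inferencer does. A hedged sketch pairing it with the compact VGG architecture follows; the weight path and input image are assumptions:

```python
# Hedged sketch: using RealESRGANer directly with SRVGGNetCompact, mirroring the inferencer.
import torch
from modules.image_restoration.real_esrgan.wrapper.srvgg_net_compact import SRVGGNetCompact
from modules.image_restoration.real_esrgan.wrapper.real_esrganer import RealESRGANer

arc = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
                      upscale=4, act_type='prelu')
upsampler = RealESRGANer(
    scale=4,
    model_path="models/RealESRGAN/realesr-general-x4v3.pth",  # assumed location
    model=arc,
    tile=0,                                  # 0 disables tiling; e.g. 512 for very large inputs
    half=torch.cuda.is_available(),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
output, img_mode = upsampler.enhance("outputs/example.png", outscale=2)  # BGR ndarray
```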
modules/image_restoration/real_esrgan/wrapper/rrdb_net.py ADDED
@@ -0,0 +1,182 @@
1
+ from torch import nn as nn
2
+ import torch
3
+ from torch.nn import init as init
4
+ from torch.nn import functional as F
5
+ from torch.nn.modules.batchnorm import _BatchNorm
6
+
7
+
8
+ class ResidualDenseBlock(nn.Module):
9
+ """Residual Dense Block.
10
+
11
+ Used in RRDB block in ESRGAN.
12
+
13
+ Args:
14
+ num_feat (int): Channel number of intermediate features.
15
+ num_grow_ch (int): Channels for each growth.
16
+ """
17
+
18
+ def __init__(self, num_feat=64, num_grow_ch=32):
19
+ super(ResidualDenseBlock, self).__init__()
20
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
21
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
22
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
23
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
25
+
26
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
27
+
28
+ # initialization
29
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
30
+
31
+ def forward(self, x):
32
+ x1 = self.lrelu(self.conv1(x))
33
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
34
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
35
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
36
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
37
+ # Empirically, we use 0.2 to scale the residual for better performance
38
+ return x5 * 0.2 + x
39
+
40
+
41
+ class RRDB(nn.Module):
42
+ """Residual in Residual Dense Block.
43
+
44
+ Used in RRDB-Net in ESRGAN.
45
+
46
+ Args:
47
+ num_feat (int): Channel number of intermediate features.
48
+ num_grow_ch (int): Channels for each growth.
49
+ """
50
+
51
+ def __init__(self, num_feat, num_grow_ch=32):
52
+ super(RRDB, self).__init__()
53
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
54
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
55
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
56
+
57
+ def forward(self, x):
58
+ out = self.rdb1(x)
59
+ out = self.rdb2(out)
60
+ out = self.rdb3(out)
61
+ # Empirically, we use 0.2 to scale the residual for better performance
62
+ return out * 0.2 + x
63
+
64
+
65
+ class RRDBNet(nn.Module):
66
+ """Networks consisting of Residual in Residual Dense Block, which is used
67
+ in ESRGAN.
68
+
69
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
70
+
71
+ We extend ESRGAN for scale x2 and scale x1.
72
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
73
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
74
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
75
+
76
+ Args:
77
+ num_in_ch (int): Channel number of inputs.
78
+ num_out_ch (int): Channel number of outputs.
79
+ num_feat (int): Channel number of intermediate features.
80
+ Default: 64
81
+ num_block (int): Block number in the trunk network. Default: 23.
82
+ num_grow_ch (int): Channels for each growth. Default: 32.
83
+ """
84
+
85
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
86
+ super(RRDBNet, self).__init__()
87
+ self.scale = scale
88
+ if scale == 2:
89
+ num_in_ch = num_in_ch * 4
90
+ elif scale == 1:
91
+ num_in_ch = num_in_ch * 16
92
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
93
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
94
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
95
+ # upsample
96
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
98
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
99
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
100
+
101
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
102
+
103
+ def forward(self, x):
104
+ if self.scale == 2:
105
+ feat = pixel_unshuffle(x, scale=2)
106
+ elif self.scale == 1:
107
+ feat = pixel_unshuffle(x, scale=4)
108
+ else:
109
+ feat = x
110
+ feat = self.conv_first(feat)
111
+ body_feat = self.conv_body(self.body(feat))
112
+ feat = feat + body_feat
113
+ # upsample
114
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
115
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
116
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
117
+ return out
118
+
119
+
120
+ def make_layer(basic_block, num_basic_block, **kwarg):
121
+ """Make layers by stacking the same blocks.
122
+
123
+ Args:
124
+ basic_block (nn.module): nn.module class for basic block.
125
+ num_basic_block (int): number of blocks.
126
+
127
+ Returns:
128
+ nn.Sequential: Stacked blocks in nn.Sequential.
129
+ """
130
+ layers = []
131
+ for _ in range(num_basic_block):
132
+ layers.append(basic_block(**kwarg))
133
+ return nn.Sequential(*layers)
134
+
135
+
136
+ def pixel_unshuffle(x, scale):
137
+ """ Pixel unshuffle.
138
+
139
+ Args:
140
+ x (Tensor): Input feature with shape (b, c, hh, hw).
141
+ scale (int): Downsample ratio.
142
+
143
+ Returns:
144
+ Tensor: the pixel unshuffled feature.
145
+ """
146
+ b, c, hh, hw = x.size()
147
+ out_channel = c * (scale**2)
148
+ assert hh % scale == 0 and hw % scale == 0
149
+ h = hh // scale
150
+ w = hw // scale
151
+ x_view = x.view(b, c, h, scale, w, scale)
152
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
153
+
154
+ @torch.no_grad()
155
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
156
+ """Initialize network weights.
157
+
158
+ Args:
159
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
160
+ scale (float): Scale initialized weights, especially for residual
161
+ blocks. Default: 1.
162
+ bias_fill (float): The value to fill bias. Default: 0
163
+ kwargs (dict): Other arguments for initialization function.
164
+ """
165
+ if not isinstance(module_list, list):
166
+ module_list = [module_list]
167
+ for module in module_list:
168
+ for m in module.modules():
169
+ if isinstance(m, nn.Conv2d):
170
+ init.kaiming_normal_(m.weight, **kwargs)
171
+ m.weight.data *= scale
172
+ if m.bias is not None:
173
+ m.bias.data.fill_(bias_fill)
174
+ elif isinstance(m, nn.Linear):
175
+ init.kaiming_normal_(m.weight, **kwargs)
176
+ m.weight.data *= scale
177
+ if m.bias is not None:
178
+ m.bias.data.fill_(bias_fill)
179
+ elif isinstance(m, _BatchNorm):
180
+ init.constant_(m.weight, 1)
181
+ if m.bias is not None:
182
+ m.bias.data.fill_(bias_fill)
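
`pixel_unshuffle` trades spatial size for channels (spatial dimensions shrink by `scale`, channels grow by `scale**2`), which is what lets this RRDBNet variant support x1/x2 inputs. A quick, illustrative shape check:

```python
# Illustrative shape check of pixel_unshuffle and a scale-2 RRDBNet.
import torch
from modules.image_restoration.real_esrgan.wrapper.rrdb_net import pixel_unshuffle, RRDBNet

x = torch.randn(1, 3, 64, 64)
print(pixel_unshuffle(x, scale=2).shape)   # torch.Size([1, 12, 32, 32])

net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
with torch.no_grad():
    print(net(x).shape)                    # torch.Size([1, 3, 128, 128]) -> x2 overall
```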
modules/image_restoration/real_esrgan/wrapper/srvgg_net_compact.py ADDED
@@ -0,0 +1,67 @@
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+
4
+
5
+ class SRVGGNetCompact(nn.Module):
6
+ """A compact VGG-style network structure for super-resolution.
7
+
8
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
9
+ conducted on the HR feature space.
10
+
11
+ Args:
12
+ num_in_ch (int): Channel number of inputs. Default: 3.
13
+ num_out_ch (int): Channel number of outputs. Default: 3.
14
+ num_feat (int): Channel number of intermediate features. Default: 64.
15
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
16
+ upscale (int): Upsampling factor. Default: 4.
17
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
18
+ """
19
+
20
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
21
+ super(SRVGGNetCompact, self).__init__()
22
+ self.num_in_ch = num_in_ch
23
+ self.num_out_ch = num_out_ch
24
+ self.num_feat = num_feat
25
+ self.num_conv = num_conv
26
+ self.upscale = upscale
27
+ self.act_type = act_type
28
+
29
+ self.body = nn.ModuleList()
30
+ # the first conv
31
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
32
+ # the first activation
33
+ if act_type == 'relu':
34
+ activation = nn.ReLU(inplace=True)
35
+ elif act_type == 'prelu':
36
+ activation = nn.PReLU(num_parameters=num_feat)
37
+ elif act_type == 'leakyrelu':
38
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
39
+ self.body.append(activation)
40
+
41
+ # the body structure
42
+ for _ in range(num_conv):
43
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
44
+ # activation
45
+ if act_type == 'relu':
46
+ activation = nn.ReLU(inplace=True)
47
+ elif act_type == 'prelu':
48
+ activation = nn.PReLU(num_parameters=num_feat)
49
+ elif act_type == 'leakyrelu':
50
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
51
+ self.body.append(activation)
52
+
53
+ # the last conv
54
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
55
+ # upsample
56
+ self.upsampler = nn.PixelShuffle(upscale)
57
+
58
+ def forward(self, x):
59
+ out = x
60
+ for i in range(0, len(self.body)):
61
+ out = self.body[i](out)
62
+
63
+ out = self.upsampler(out)
64
+ # add the nearest upsampled image, so that the network learns the residual
65
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
66
+ out += base
67
+ return out
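
`SRVGGNetCompact` always upsamples by `upscale` via PixelShuffle and adds a nearest-neighbour upsampled copy of the input as a residual base. A short, illustrative shape check:

```python
# Illustrative shape check of the compact VGG-style super-resolution network.
import torch
from modules.image_restoration.real_esrgan.wrapper.srvgg_net_compact import SRVGGNetCompact

net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
                      upscale=4, act_type='prelu')
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    print(net(x).shape)   # torch.Size([1, 3, 128, 128])
```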
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -1,17 +1,11 @@
1
  import logging
2
- import os
3
- import cv2
4
  import time
5
  import copy
6
  import dill
7
- import torch
8
  from ultralytics import YOLO
9
  import safetensors.torch
10
  import gradio as gr
11
- from gradio_i18n import Translate, gettext as _
12
  from ultralytics.utils import LOGGER as ultralytics_logger
13
- from enum import Enum
14
- from typing import Union, List, Dict, Tuple
15
 
16
  from modules.utils.paths import *
17
  from modules.utils.image_helper import *
@@ -27,6 +21,7 @@ from modules.live_portrait.warping_network import WarpingNetwork
27
  from modules.live_portrait.motion_extractor import MotionExtractor
28
  from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
29
  from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
 
30
 
31
 
32
  class LivePortraitInferencer:
@@ -69,6 +64,11 @@ class LivePortraitInferencer:
69
  self.psi_list = None
70
  self.d_info = None
71
 
 
 
 
 
 
72
  def load_models(self,
73
  model_type: str = ModelType.HUMAN.value,
74
  progress=gr.Progress()):
@@ -161,6 +161,7 @@ class LivePortraitInferencer:
161
  sample_ratio: float = 1,
162
  sample_parts: str = SamplePart.ALL.value,
163
  crop_factor: float = 2.3,
 
164
  src_image: Optional[str] = None,
165
  sample_image: Optional[str] = None,) -> None:
166
  if isinstance(model_type, ModelType):
@@ -232,8 +233,11 @@ class LivePortraitInferencer:
232
  out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
233
 
234
  temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
235
- save_image(numpy_array=crop_out, output_path=temp_out_img_path)
236
- save_image(numpy_array=out, output_path=out_img_path)
 
 
 
237
 
238
  return out
239
  except Exception as e:
@@ -244,6 +248,7 @@ class LivePortraitInferencer:
244
  retargeting_eyes: float = 1,
245
  retargeting_mouth: float = 1,
246
  crop_factor: float = 2.3,
 
247
  src_image: Optional[str] = None,
248
  driving_vid_path: Optional[str] = None,
249
  progress: gr.Progress = gr.Progress()
@@ -317,11 +322,18 @@ class LivePortraitInferencer:
317
  np.uint8)
318
 
319
  out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png")
320
- save_image(out, out_frame_path)
 
 
 
321
 
322
  progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..")
323
 
324
- video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR, frame_rate=vid_info.frame_rate, output_dir=os.path.join(self.output_dir, "videos"))
 
 
 
 
325
 
326
  return video_path
327
  except Exception as e:
 
1
  import logging
 
 
2
  import time
3
  import copy
4
  import dill
 
5
  from ultralytics import YOLO
6
  import safetensors.torch
7
  import gradio as gr
 
8
  from ultralytics.utils import LOGGER as ultralytics_logger
 
 
9
 
10
  from modules.utils.paths import *
11
  from modules.utils.image_helper import *
 
21
  from modules.live_portrait.motion_extractor import MotionExtractor
22
  from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
23
  from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
24
+ from modules.image_restoration.real_esrgan.real_esrgan_inferencer import RealESRGANInferencer
25
 
26
 
27
  class LivePortraitInferencer:
 
64
  self.psi_list = None
65
  self.d_info = None
66
 
67
+ self.resrgan_inferencer = RealESRGANInferencer(
68
+ model_dir=os.path.join(self.model_dir, "RealESRGAN"),
69
+ output_dir=self.output_dir
70
+ )
71
+
72
  def load_models(self,
73
  model_type: str = ModelType.HUMAN.value,
74
  progress=gr.Progress()):
 
161
  sample_ratio: float = 1,
162
  sample_parts: str = SamplePart.ALL.value,
163
  crop_factor: float = 2.3,
164
+ enable_image_restoration: bool = False,
165
  src_image: Optional[str] = None,
166
  sample_image: Optional[str] = None,) -> None:
167
  if isinstance(model_type, ModelType):
 
233
  out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
234
 
235
  temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
236
+ cropped_out_img_path = save_image(numpy_array=crop_out, output_path=temp_out_img_path)
237
+ out_img_path = save_image(numpy_array=out, output_path=out_img_path)
238
+
239
+ if enable_image_restoration:
240
+ out = self.resrgan_inferencer.restore_image(out_img_path)
241
 
242
  return out
243
  except Exception as e:
 
248
  retargeting_eyes: float = 1,
249
  retargeting_mouth: float = 1,
250
  crop_factor: float = 2.3,
251
+ enable_image_restoration: bool = False,
252
  src_image: Optional[str] = None,
253
  driving_vid_path: Optional[str] = None,
254
  progress: gr.Progress = gr.Progress()
 
322
  np.uint8)
323
 
324
  out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png")
325
+ out_frame_path = save_image(out, out_frame_path)
326
+
327
+ if enable_image_restoration:
328
+ out_frame_path = self.resrgan_inferencer.restore_image(out_frame_path)
329
 
330
  progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..")
331
 
332
+ video_path = create_video_from_frames(
333
+ TEMP_VIDEO_OUT_FRAMES_DIR,
334
+ frame_rate=vid_info.frame_rate,
335
+ output_dir=os.path.join(self.output_dir, "videos")
336
+ )
337
 
338
  return video_path
339
  except Exception as e:
modules/utils/paths.py CHANGED
@@ -2,9 +2,10 @@ import functools
2
  import os
3
 
4
 
5
- PROJECT_ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "..")
6
  MODELS_DIR = os.path.join(PROJECT_ROOT_DIR, "models")
7
  MODELS_ANIMAL_DIR = os.path.join(MODELS_DIR, "animal")
 
8
  OUTPUTS_DIR = os.path.join(PROJECT_ROOT_DIR, "outputs")
9
  OUTPUTS_VIDEOS_DIR = os.path.join(OUTPUTS_DIR, "videos")
10
  TEMP_DIR = os.path.join(OUTPUTS_DIR, "temp")
@@ -29,6 +30,9 @@ MODEL_ANIMAL_PATHS = {
29
  # Just animal detection model not the face, needs better model
30
  "yolo_v5s_animal_det": os.path.join(MODELS_ANIMAL_DIR, "yolo_v5s_animal_det.n2x")
31
  }
 
 
 
32
  MASK_TEMPLATES = os.path.join(PROJECT_ROOT_DIR, "modules", "utils", "resources", "mask_template.png")
33
  I18N_YAML_PATH = os.path.join(PROJECT_ROOT_DIR, "i18n", "translation.yaml")
34
 
@@ -52,6 +56,7 @@ def init_dirs():
52
  for dir_path in [
53
  MODELS_DIR,
54
  MODELS_ANIMAL_DIR,
 
55
  OUTPUTS_DIR,
56
  EXP_OUTPUT_DIR,
57
  TEMP_DIR,
 
2
  import os
3
 
4
 
5
+ PROJECT_ROOT_DIR = os.path.normpath(os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", ".."))
6
  MODELS_DIR = os.path.join(PROJECT_ROOT_DIR, "models")
7
  MODELS_ANIMAL_DIR = os.path.join(MODELS_DIR, "animal")
8
+ MODELS_REAL_ESRGAN_DIR = os.path.join(MODELS_DIR, "RealESRGAN")
9
  OUTPUTS_DIR = os.path.join(PROJECT_ROOT_DIR, "outputs")
10
  OUTPUTS_VIDEOS_DIR = os.path.join(OUTPUTS_DIR, "videos")
11
  TEMP_DIR = os.path.join(OUTPUTS_DIR, "temp")
 
30
  # Just animal detection model not the face, needs better model
31
  "yolo_v5s_animal_det": os.path.join(MODELS_ANIMAL_DIR, "yolo_v5s_animal_det.n2x")
32
  }
33
+ MODEL_REAL_ESRGAN_PATH = {
34
+ "realesr-general-x4v3": os.path.join(MODELS_REAL_ESRGAN_DIR, "realesr-general-x4v3.pth")
35
+ }
36
  MASK_TEMPLATES = os.path.join(PROJECT_ROOT_DIR, "modules", "utils", "resources", "mask_template.png")
37
  I18N_YAML_PATH = os.path.join(PROJECT_ROOT_DIR, "i18n", "translation.yaml")
38
 
 
56
  for dir_path in [
57
  MODELS_DIR,
58
  MODELS_ANIMAL_DIR,
59
+ MODELS_REAL_ESRGAN_DIR,
60
  OUTPUTS_DIR,
61
  EXP_OUTPUT_DIR,
62
  TEMP_DIR,
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  --extra-index-url https://download.pytorch.org/whl/cu124
2
  torch
3
  torchvision
@@ -15,7 +16,6 @@ dill
15
  gradio
16
  gradio-i18n
17
 
18
-
19
  # Tests
20
  # pytest
21
  # scikit-image
 
1
+ # AdvancedLivePortrait
2
  --extra-index-url https://download.pytorch.org/whl/cu124
3
  torch
4
  torchvision
 
16
  gradio
17
  gradio-i18n
18
 
 
19
  # Tests
20
  # pytest
21
  # scikit-image
tests/test_image_restoration.py ADDED
@@ -0,0 +1,31 @@
1
+ import os
2
+ import pytest
3
+
4
+ from test_config import *
5
+ from modules.live_portrait.live_portrait_inferencer import LivePortraitInferencer
6
+
7
+
8
+ @pytest.mark.parametrize(
9
+ "input_image",
10
+ [
11
+ TEST_IMAGE_PATH
12
+ ]
13
+ )
14
+ def test_image_restoration(
15
+ input_image: str,
16
+ ):
17
+ if not os.path.exists(TEST_IMAGE_PATH):
18
+ download_image(
19
+ TEST_IMAGE_URL,
20
+ TEST_IMAGE_PATH
21
+ )
22
+
23
+ inferencer = LivePortraitInferencer()
24
+
25
+ restored_output = inferencer.resrgan_inferencer.restore_image(
26
+ input_image,
27
+ overwrite=False
28
+ )
29
+
30
+ assert os.path.exists(restored_output)
31
+ assert are_images_different(input_image, restored_output)