Spaces:
Runtime error
Runtime error
support MPS (Mac M1) device; Happy Lunar New Year!
Browse files- README.md +1 -1
- basicsr/setup.py +3 -2
- basicsr/utils/misc.py +24 -1
- basicsr/utils/realesrgan_utils.py +10 -7
- facelib/detection/retinaface/retinaface.py +4 -2
- facelib/detection/yolov5face/face_detector.py +7 -8
- facelib/utils/face_restoration_helper.py +3 -1
- inference_codeformer.py +9 -10
- web-demos/hugging_face/app.py +9 -6
- web-demos/replicate/predict.py +5 -3
README.md
CHANGED
@@ -146,4 +146,4 @@ This project is licensed under <a rel="license" href="https://github.com/sczhou/
|
|
146 |
This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
|
147 |
|
148 |
### Contact
|
149 |
-
If you have any
|
|
|
146 |
This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
|
147 |
|
148 |
### Contact
|
149 |
+
If you have any questions, please feel free to reach me out at `shangchenzhou@gmail.com`.
|
basicsr/setup.py
CHANGED
@@ -6,8 +6,8 @@ import os
|
|
6 |
import subprocess
|
7 |
import sys
|
8 |
import time
|
9 |
-
import torch
|
10 |
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
|
|
|
11 |
|
12 |
version_file = './basicsr/version.py'
|
13 |
|
@@ -87,7 +87,8 @@ def make_cuda_ext(name, module, sources, sources_cuda=None):
|
|
87 |
define_macros = []
|
88 |
extra_compile_args = {'cxx': []}
|
89 |
|
90 |
-
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
|
|
|
91 |
define_macros += [('WITH_CUDA', None)]
|
92 |
extension = CUDAExtension
|
93 |
extra_compile_args['nvcc'] = [
|
|
|
6 |
import subprocess
|
7 |
import sys
|
8 |
import time
|
|
|
9 |
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
|
10 |
+
from utils.misc import gpu_is_available
|
11 |
|
12 |
version_file = './basicsr/version.py'
|
13 |
|
|
|
87 |
define_macros = []
|
88 |
extra_compile_args = {'cxx': []}
|
89 |
|
90 |
+
# if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
|
91 |
+
if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1':
|
92 |
define_macros += [('WITH_CUDA', None)]
|
93 |
extension = CUDAExtension
|
94 |
extra_compile_args['nvcc'] = [
|
basicsr/utils/misc.py
CHANGED
@@ -1,13 +1,36 @@
|
|
1 |
-
import numpy as np
|
2 |
import os
|
|
|
3 |
import random
|
4 |
import time
|
5 |
import torch
|
|
|
6 |
from os import path as osp
|
7 |
|
8 |
from .dist_util import master_only
|
9 |
from .logger import get_root_logger
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def set_random_seed(seed):
|
13 |
"""Set random seeds."""
|
|
|
|
|
1 |
import os
|
2 |
+
import re
|
3 |
import random
|
4 |
import time
|
5 |
import torch
|
6 |
+
import numpy as np
|
7 |
from os import path as osp
|
8 |
|
9 |
from .dist_util import master_only
|
10 |
from .logger import get_root_logger
|
11 |
|
12 |
+
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
|
13 |
+
torch.__version__)[0][:3])] >= [1, 12, 0]
|
14 |
+
|
15 |
+
def gpu_is_available():
|
16 |
+
if IS_HIGH_VERSION:
|
17 |
+
if torch.backends.mps.is_available():
|
18 |
+
return True
|
19 |
+
return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
|
20 |
+
|
21 |
+
def get_device(gpu_id=None):
|
22 |
+
if gpu_id is None:
|
23 |
+
gpu_str = ''
|
24 |
+
elif isinstance(gpu_id, int):
|
25 |
+
gpu_str = f':{gpu_id}'
|
26 |
+
else:
|
27 |
+
raise TypeError('Input should be int value.')
|
28 |
+
|
29 |
+
if IS_HIGH_VERSION:
|
30 |
+
if torch.backends.mps.is_available():
|
31 |
+
return torch.device('mps'+gpu_str)
|
32 |
+
return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
|
33 |
+
|
34 |
|
35 |
def set_random_seed(seed):
|
36 |
"""Set random seeds."""
|
basicsr/utils/realesrgan_utils.py
CHANGED
@@ -5,12 +5,12 @@ import os
|
|
5 |
import queue
|
6 |
import threading
|
7 |
import torch
|
8 |
-
from basicsr.utils.download_util import load_file_from_url
|
9 |
from torch.nn import functional as F
|
|
|
|
|
10 |
|
11 |
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
12 |
|
13 |
-
|
14 |
class RealESRGANer():
|
15 |
"""A helper class for upsampling images with RealESRGAN.
|
16 |
|
@@ -44,11 +44,14 @@ class RealESRGANer():
|
|
44 |
self.half = half
|
45 |
|
46 |
# initialize model
|
47 |
-
if gpu_id:
|
48 |
-
|
49 |
-
|
50 |
-
else:
|
51 |
-
|
|
|
|
|
|
|
52 |
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
53 |
if model_path.startswith('https://'):
|
54 |
model_path = load_file_from_url(
|
|
|
5 |
import queue
|
6 |
import threading
|
7 |
import torch
|
|
|
8 |
from torch.nn import functional as F
|
9 |
+
from basicsr.utils.download_util import load_file_from_url
|
10 |
+
from basicsr.utils.misc import get_device
|
11 |
|
12 |
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
13 |
|
|
|
14 |
class RealESRGANer():
|
15 |
"""A helper class for upsampling images with RealESRGAN.
|
16 |
|
|
|
44 |
self.half = half
|
45 |
|
46 |
# initialize model
|
47 |
+
# if gpu_id:
|
48 |
+
# self.device = torch.device(
|
49 |
+
# f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
|
50 |
+
# else:
|
51 |
+
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
52 |
+
|
53 |
+
self.device = get_device(gpu_id) if device is None else device
|
54 |
+
|
55 |
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
56 |
if model_path.startswith('https://'):
|
57 |
model_path = load_file_from_url(
|
facelib/detection/retinaface/retinaface.py
CHANGED
@@ -11,11 +11,13 @@ from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, m
|
|
11 |
from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
|
12 |
py_cpu_nms)
|
13 |
|
14 |
-
|
|
|
|
|
15 |
|
16 |
|
17 |
def generate_config(network_name):
|
18 |
-
|
19 |
cfg_mnet = {
|
20 |
'name': 'mobilenet0.25',
|
21 |
'min_sizes': [[16, 32], [64, 128], [256, 512]],
|
|
|
11 |
from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
|
12 |
py_cpu_nms)
|
13 |
|
14 |
+
from basicsr.utils.misc import get_device
|
15 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
device = get_device()
|
17 |
|
18 |
|
19 |
def generate_config(network_name):
|
20 |
+
|
21 |
cfg_mnet = {
|
22 |
'name': 'mobilenet0.25',
|
23 |
'min_sizes': [[16, 32], [64, 128], [256, 512]],
|
facelib/detection/yolov5face/face_detector.py
CHANGED
@@ -1,13 +1,10 @@
|
|
1 |
-
import copy
|
2 |
-
import os
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
import cv2
|
6 |
-
import
|
|
|
7 |
import torch
|
8 |
-
|
9 |
|
10 |
-
from
|
11 |
from facelib.detection.yolov5face.models.yolo import Model
|
12 |
from facelib.detection.yolov5face.utils.datasets import letterbox
|
13 |
from facelib.detection.yolov5face.utils.general import (
|
@@ -17,7 +14,9 @@ from facelib.detection.yolov5face.utils.general import (
|
|
17 |
scale_coords_landmarks,
|
18 |
)
|
19 |
|
20 |
-
IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9
|
|
|
|
|
21 |
|
22 |
|
23 |
def isListempty(inList):
|
|
|
|
|
|
|
|
|
|
|
1 |
import cv2
|
2 |
+
import copy
|
3 |
+
import re
|
4 |
import torch
|
5 |
+
import numpy as np
|
6 |
|
7 |
+
from pathlib import Path
|
8 |
from facelib.detection.yolov5face.models.yolo import Model
|
9 |
from facelib.detection.yolov5face.utils.datasets import letterbox
|
10 |
from facelib.detection.yolov5face.utils.general import (
|
|
|
14 |
scale_coords_landmarks,
|
15 |
)
|
16 |
|
17 |
+
# IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
|
18 |
+
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
|
19 |
+
torch.__version__)[0][:3])] >= [1, 9, 0]
|
20 |
|
21 |
|
22 |
def isListempty(inList):
|
facelib/utils/face_restoration_helper.py
CHANGED
@@ -7,6 +7,7 @@ from torchvision.transforms.functional import normalize
|
|
7 |
from facelib.detection import init_detection_model
|
8 |
from facelib.parsing import init_parsing_model
|
9 |
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
|
|
|
10 |
|
11 |
|
12 |
def get_largest_face(det_faces, h, w):
|
@@ -97,7 +98,8 @@ class FaceRestoreHelper(object):
|
|
97 |
self.pad_input_imgs = []
|
98 |
|
99 |
if device is None:
|
100 |
-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
101 |
else:
|
102 |
self.device = device
|
103 |
|
|
|
7 |
from facelib.detection import init_detection_model
|
8 |
from facelib.parsing import init_parsing_model
|
9 |
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
|
10 |
+
from basicsr.utils.misc import get_device
|
11 |
|
12 |
|
13 |
def get_largest_face(det_faces, h, w):
|
|
|
98 |
self.pad_input_imgs = []
|
99 |
|
100 |
if device is None:
|
101 |
+
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
102 |
+
self.device = get_device()
|
103 |
else:
|
104 |
self.device = device
|
105 |
|
inference_codeformer.py
CHANGED
@@ -6,9 +6,9 @@ import torch
|
|
6 |
from torchvision.transforms.functional import normalize
|
7 |
from basicsr.utils import imwrite, img2tensor, tensor2img
|
8 |
from basicsr.utils.download_util import load_file_from_url
|
|
|
9 |
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
10 |
from facelib.utils.misc import is_gray
|
11 |
-
import torch.nn.functional as F
|
12 |
|
13 |
from basicsr.utils.registry import ARCH_REGISTRY
|
14 |
|
@@ -19,9 +19,7 @@ pretrain_model_url = {
|
|
19 |
def set_realesrgan():
|
20 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
21 |
from basicsr.utils.realesrgan_utils import RealESRGANer
|
22 |
-
|
23 |
-
cuda_is_available = torch.cuda.is_available()
|
24 |
-
half = True if cuda_is_available else False
|
25 |
model = RRDBNet(
|
26 |
num_in_ch=3,
|
27 |
num_out_ch=3,
|
@@ -37,10 +35,10 @@ def set_realesrgan():
|
|
37 |
tile=args.bg_tile,
|
38 |
tile_pad=40,
|
39 |
pre_pad=0,
|
40 |
-
half=
|
41 |
)
|
42 |
|
43 |
-
if not
|
44 |
import warnings
|
45 |
warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
|
46 |
'The unoptimized RealESRGAN is slow on CPU. '
|
@@ -49,7 +47,8 @@ def set_realesrgan():
|
|
49 |
return upsampler
|
50 |
|
51 |
if __name__ == '__main__':
|
52 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
53 |
parser = argparse.ArgumentParser()
|
54 |
|
55 |
parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
|
@@ -79,10 +78,10 @@ if __name__ == '__main__':
|
|
79 |
# ------------------------ input & output ------------------------
|
80 |
w = args.fidelity_weight
|
81 |
input_video = False
|
82 |
-
if args.input_path.endswith(('jpg', 'png')): # input single img path
|
83 |
input_img_list = [args.input_path]
|
84 |
result_root = f'results/test_img_{w}'
|
85 |
-
elif args.input_path.endswith(('mp4', 'mov', 'avi')): # input video path
|
86 |
from basicsr.utils.video_util import VideoReader, VideoWriter
|
87 |
input_img_list = []
|
88 |
vidreader = VideoReader(args.input_path)
|
@@ -100,7 +99,7 @@ if __name__ == '__main__':
|
|
100 |
if args.input_path.endswith('/'): # solve when path ends with /
|
101 |
args.input_path = args.input_path[:-1]
|
102 |
# scan all the jpg and png images
|
103 |
-
input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[
|
104 |
result_root = f'results/{os.path.basename(args.input_path)}_{w}'
|
105 |
|
106 |
if not args.output_path is None: # set output path
|
|
|
6 |
from torchvision.transforms.functional import normalize
|
7 |
from basicsr.utils import imwrite, img2tensor, tensor2img
|
8 |
from basicsr.utils.download_util import load_file_from_url
|
9 |
+
from basicsr.utils.misc import gpu_is_available, get_device
|
10 |
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
11 |
from facelib.utils.misc import is_gray
|
|
|
12 |
|
13 |
from basicsr.utils.registry import ARCH_REGISTRY
|
14 |
|
|
|
19 |
def set_realesrgan():
|
20 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
21 |
from basicsr.utils.realesrgan_utils import RealESRGANer
|
22 |
+
|
|
|
|
|
23 |
model = RRDBNet(
|
24 |
num_in_ch=3,
|
25 |
num_out_ch=3,
|
|
|
35 |
tile=args.bg_tile,
|
36 |
tile_pad=40,
|
37 |
pre_pad=0,
|
38 |
+
half=torch.cuda.is_available(), # need to set False in CPU/MPS mode
|
39 |
)
|
40 |
|
41 |
+
if not gpu_is_available(): # CPU
|
42 |
import warnings
|
43 |
warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
|
44 |
'The unoptimized RealESRGAN is slow on CPU. '
|
|
|
47 |
return upsampler
|
48 |
|
49 |
if __name__ == '__main__':
|
50 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
51 |
+
device = get_device()
|
52 |
parser = argparse.ArgumentParser()
|
53 |
|
54 |
parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
|
|
|
78 |
# ------------------------ input & output ------------------------
|
79 |
w = args.fidelity_weight
|
80 |
input_video = False
|
81 |
+
if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
|
82 |
input_img_list = [args.input_path]
|
83 |
result_root = f'results/test_img_{w}'
|
84 |
+
elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
|
85 |
from basicsr.utils.video_util import VideoReader, VideoWriter
|
86 |
input_img_list = []
|
87 |
vidreader = VideoReader(args.input_path)
|
|
|
99 |
if args.input_path.endswith('/'): # solve when path ends with /
|
100 |
args.input_path = args.input_path[:-1]
|
101 |
# scan all the jpg and png images
|
102 |
+
input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
|
103 |
result_root = f'results/{os.path.basename(args.input_path)}_{w}'
|
104 |
|
105 |
if not args.output_path is None: # set output path
|
web-demos/hugging_face/app.py
CHANGED
@@ -13,15 +13,16 @@ import gradio as gr
|
|
13 |
|
14 |
from torchvision.transforms.functional import normalize
|
15 |
|
|
|
16 |
from basicsr.utils import imwrite, img2tensor, tensor2img
|
17 |
from basicsr.utils.download_util import load_file_from_url
|
18 |
-
from
|
19 |
-
from facelib.utils.misc import is_gray
|
20 |
-
from basicsr.archs.rrdbnet_arch import RRDBNet
|
21 |
from basicsr.utils.realesrgan_utils import RealESRGANer
|
22 |
-
|
23 |
from basicsr.utils.registry import ARCH_REGISTRY
|
24 |
|
|
|
|
|
|
|
25 |
|
26 |
os.system("pip freeze")
|
27 |
|
@@ -65,7 +66,8 @@ def imread(img_path):
|
|
65 |
|
66 |
# set enhancer with RealESRGAN
|
67 |
def set_realesrgan():
|
68 |
-
half = True if torch.cuda.is_available() else False
|
|
|
69 |
model = RRDBNet(
|
70 |
num_in_ch=3,
|
71 |
num_out_ch=3,
|
@@ -86,7 +88,8 @@ def set_realesrgan():
|
|
86 |
return upsampler
|
87 |
|
88 |
upsampler = set_realesrgan()
|
89 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
90 |
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
|
91 |
dim_embd=512,
|
92 |
codebook_size=1024,
|
|
|
13 |
|
14 |
from torchvision.transforms.functional import normalize
|
15 |
|
16 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
17 |
from basicsr.utils import imwrite, img2tensor, tensor2img
|
18 |
from basicsr.utils.download_util import load_file_from_url
|
19 |
+
from basicsr.utils.misc import gpu_is_available, get_device
|
|
|
|
|
20 |
from basicsr.utils.realesrgan_utils import RealESRGANer
|
|
|
21 |
from basicsr.utils.registry import ARCH_REGISTRY
|
22 |
|
23 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
24 |
+
from facelib.utils.misc import is_gray
|
25 |
+
|
26 |
|
27 |
os.system("pip freeze")
|
28 |
|
|
|
66 |
|
67 |
# set enhancer with RealESRGAN
|
68 |
def set_realesrgan():
|
69 |
+
# half = True if torch.cuda.is_available() else False
|
70 |
+
half = True if gpu_is_available() else False
|
71 |
model = RRDBNet(
|
72 |
num_in_ch=3,
|
73 |
num_out_ch=3,
|
|
|
88 |
return upsampler
|
89 |
|
90 |
upsampler = set_realesrgan()
|
91 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
92 |
+
device = get_device()
|
93 |
codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
|
94 |
dim_embd=512,
|
95 |
codebook_size=1024,
|
web-demos/replicate/predict.py
CHANGED
@@ -14,12 +14,13 @@ try:
|
|
14 |
except Exception:
|
15 |
print('please install cog package')
|
16 |
|
17 |
-
from basicsr.utils import imwrite, img2tensor, tensor2img
|
18 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
|
19 |
from basicsr.utils.realesrgan_utils import RealESRGANer
|
|
|
20 |
from basicsr.utils.registry import ARCH_REGISTRY
|
21 |
-
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
22 |
|
|
|
23 |
|
24 |
class Predictor(BasePredictor):
|
25 |
def setup(self):
|
@@ -159,7 +160,8 @@ def imread(img_path):
|
|
159 |
|
160 |
|
161 |
def set_realesrgan():
|
162 |
-
if not torch.cuda.is_available(): # CPU
|
|
|
163 |
import warnings
|
164 |
|
165 |
warnings.warn(
|
|
|
14 |
except Exception:
|
15 |
print('please install cog package')
|
16 |
|
|
|
17 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
18 |
+
from basicsr.utils import imwrite, img2tensor, tensor2img
|
19 |
from basicsr.utils.realesrgan_utils import RealESRGANer
|
20 |
+
from basicsr.utils.misc import gpu_is_available
|
21 |
from basicsr.utils.registry import ARCH_REGISTRY
|
|
|
22 |
|
23 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
24 |
|
25 |
class Predictor(BasePredictor):
|
26 |
def setup(self):
|
|
|
160 |
|
161 |
|
162 |
def set_realesrgan():
|
163 |
+
# if not torch.cuda.is_available(): # CPU
|
164 |
+
if not gpu_is_available(): # CPU
|
165 |
import warnings
|
166 |
|
167 |
warnings.warn(
|