sczhou commited on
Commit
07c8cc6
·
1 Parent(s): 7a584fd

support MPS (Mac M1) device; Happy Lunar New Year!

Browse files
README.md CHANGED
@@ -146,4 +146,4 @@ This project is licensed under <a rel="license" href="https://github.com/sczhou/
146
  This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
147
 
148
  ### Contact
149
- If you have any question, please feel free to reach me out at `shangchenzhou@gmail.com`.
 
146
  This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
147
 
148
  ### Contact
149
+ If you have any questions, please feel free to reach me out at `shangchenzhou@gmail.com`.
basicsr/setup.py CHANGED
@@ -6,8 +6,8 @@ import os
6
  import subprocess
7
  import sys
8
  import time
9
- import torch
10
  from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
 
11
 
12
  version_file = './basicsr/version.py'
13
 
@@ -87,7 +87,8 @@ def make_cuda_ext(name, module, sources, sources_cuda=None):
87
  define_macros = []
88
  extra_compile_args = {'cxx': []}
89
 
90
- if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
 
91
  define_macros += [('WITH_CUDA', None)]
92
  extension = CUDAExtension
93
  extra_compile_args['nvcc'] = [
 
6
  import subprocess
7
  import sys
8
  import time
 
9
  from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
10
+ from utils.misc import gpu_is_available
11
 
12
  version_file = './basicsr/version.py'
13
 
 
87
  define_macros = []
88
  extra_compile_args = {'cxx': []}
89
 
90
+ # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
91
+ if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1':
92
  define_macros += [('WITH_CUDA', None)]
93
  extension = CUDAExtension
94
  extra_compile_args['nvcc'] = [
basicsr/utils/misc.py CHANGED
@@ -1,13 +1,36 @@
1
- import numpy as np
2
  import os
 
3
  import random
4
  import time
5
  import torch
 
6
  from os import path as osp
7
 
8
  from .dist_util import master_only
9
  from .logger import get_root_logger
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def set_random_seed(seed):
13
  """Set random seeds."""
 
 
1
  import os
2
+ import re
3
  import random
4
  import time
5
  import torch
6
+ import numpy as np
7
  from os import path as osp
8
 
9
  from .dist_util import master_only
10
  from .logger import get_root_logger
11
 
12
+ IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
13
+ torch.__version__)[0][:3])] >= [1, 12, 0]
14
+
15
+ def gpu_is_available():
16
+ if IS_HIGH_VERSION:
17
+ if torch.backends.mps.is_available():
18
+ return True
19
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
20
+
21
+ def get_device(gpu_id=None):
22
+ if gpu_id is None:
23
+ gpu_str = ''
24
+ elif isinstance(gpu_id, int):
25
+ gpu_str = f':{gpu_id}'
26
+ else:
27
+ raise TypeError('Input should be int value.')
28
+
29
+ if IS_HIGH_VERSION:
30
+ if torch.backends.mps.is_available():
31
+ return torch.device('mps'+gpu_str)
32
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
33
+
34
 
35
  def set_random_seed(seed):
36
  """Set random seeds."""
basicsr/utils/realesrgan_utils.py CHANGED
@@ -5,12 +5,12 @@ import os
5
  import queue
6
  import threading
7
  import torch
8
- from basicsr.utils.download_util import load_file_from_url
9
  from torch.nn import functional as F
 
 
10
 
11
  # ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
 
13
-
14
  class RealESRGANer():
15
  """A helper class for upsampling images with RealESRGAN.
16
 
@@ -44,11 +44,14 @@ class RealESRGANer():
44
  self.half = half
45
 
46
  # initialize model
47
- if gpu_id:
48
- self.device = torch.device(
49
- f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
50
- else:
51
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
 
 
 
52
  # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
53
  if model_path.startswith('https://'):
54
  model_path = load_file_from_url(
 
5
  import queue
6
  import threading
7
  import torch
 
8
  from torch.nn import functional as F
9
+ from basicsr.utils.download_util import load_file_from_url
10
+ from basicsr.utils.misc import get_device
11
 
12
  # ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
 
 
14
  class RealESRGANer():
15
  """A helper class for upsampling images with RealESRGAN.
16
 
 
44
  self.half = half
45
 
46
  # initialize model
47
+ # if gpu_id:
48
+ # self.device = torch.device(
49
+ # f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
50
+ # else:
51
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
52
+
53
+ self.device = get_device(gpu_id) if device is None else device
54
+
55
  # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
56
  if model_path.startswith('https://'):
57
  model_path = load_file_from_url(
facelib/detection/retinaface/retinaface.py CHANGED
@@ -11,11 +11,13 @@ from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, m
11
  from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
12
  py_cpu_nms)
13
 
14
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
15
 
16
 
17
  def generate_config(network_name):
18
-
19
  cfg_mnet = {
20
  'name': 'mobilenet0.25',
21
  'min_sizes': [[16, 32], [64, 128], [256, 512]],
 
11
  from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
12
  py_cpu_nms)
13
 
14
+ from basicsr.utils.misc import get_device
15
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ device = get_device()
17
 
18
 
19
  def generate_config(network_name):
20
+
21
  cfg_mnet = {
22
  'name': 'mobilenet0.25',
23
  'min_sizes': [[16, 32], [64, 128], [256, 512]],
facelib/detection/yolov5face/face_detector.py CHANGED
@@ -1,13 +1,10 @@
1
- import copy
2
- import os
3
- from pathlib import Path
4
-
5
  import cv2
6
- import numpy as np
 
7
  import torch
8
- from torch import nn
9
 
10
- from facelib.detection.yolov5face.models.common import Conv
11
  from facelib.detection.yolov5face.models.yolo import Model
12
  from facelib.detection.yolov5face.utils.datasets import letterbox
13
  from facelib.detection.yolov5face.utils.general import (
@@ -17,7 +14,9 @@ from facelib.detection.yolov5face.utils.general import (
17
  scale_coords_landmarks,
18
  )
19
 
20
- IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9, 0)
 
 
21
 
22
 
23
  def isListempty(inList):
 
 
 
 
 
1
  import cv2
2
+ import copy
3
+ import re
4
  import torch
5
+ import numpy as np
6
 
7
+ from pathlib import Path
8
  from facelib.detection.yolov5face.models.yolo import Model
9
  from facelib.detection.yolov5face.utils.datasets import letterbox
10
  from facelib.detection.yolov5face.utils.general import (
 
14
  scale_coords_landmarks,
15
  )
16
 
17
+ # IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
18
+ IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
19
+ torch.__version__)[0][:3])] >= [1, 9, 0]
20
 
21
 
22
  def isListempty(inList):
facelib/utils/face_restoration_helper.py CHANGED
@@ -7,6 +7,7 @@ from torchvision.transforms.functional import normalize
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
  from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
 
10
 
11
 
12
  def get_largest_face(det_faces, h, w):
@@ -97,7 +98,8 @@ class FaceRestoreHelper(object):
97
  self.pad_input_imgs = []
98
 
99
  if device is None:
100
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
101
  else:
102
  self.device = device
103
 
 
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
  from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
10
+ from basicsr.utils.misc import get_device
11
 
12
 
13
  def get_largest_face(det_faces, h, w):
 
98
  self.pad_input_imgs = []
99
 
100
  if device is None:
101
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102
+ self.device = get_device()
103
  else:
104
  self.device = device
105
 
inference_codeformer.py CHANGED
@@ -6,9 +6,9 @@ import torch
6
  from torchvision.transforms.functional import normalize
7
  from basicsr.utils import imwrite, img2tensor, tensor2img
8
  from basicsr.utils.download_util import load_file_from_url
 
9
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
10
  from facelib.utils.misc import is_gray
11
- import torch.nn.functional as F
12
 
13
  from basicsr.utils.registry import ARCH_REGISTRY
14
 
@@ -19,9 +19,7 @@ pretrain_model_url = {
19
  def set_realesrgan():
20
  from basicsr.archs.rrdbnet_arch import RRDBNet
21
  from basicsr.utils.realesrgan_utils import RealESRGANer
22
-
23
- cuda_is_available = torch.cuda.is_available()
24
- half = True if cuda_is_available else False
25
  model = RRDBNet(
26
  num_in_ch=3,
27
  num_out_ch=3,
@@ -37,10 +35,10 @@ def set_realesrgan():
37
  tile=args.bg_tile,
38
  tile_pad=40,
39
  pre_pad=0,
40
- half=half, # need to set False in CPU mode
41
  )
42
 
43
- if not cuda_is_available: # CPU
44
  import warnings
45
  warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
46
  'The unoptimized RealESRGAN is slow on CPU. '
@@ -49,7 +47,8 @@ def set_realesrgan():
49
  return upsampler
50
 
51
  if __name__ == '__main__':
52
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
53
  parser = argparse.ArgumentParser()
54
 
55
  parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
@@ -79,10 +78,10 @@ if __name__ == '__main__':
79
  # ------------------------ input & output ------------------------
80
  w = args.fidelity_weight
81
  input_video = False
82
- if args.input_path.endswith(('jpg', 'png')): # input single img path
83
  input_img_list = [args.input_path]
84
  result_root = f'results/test_img_{w}'
85
- elif args.input_path.endswith(('mp4', 'mov', 'avi')): # input video path
86
  from basicsr.utils.video_util import VideoReader, VideoWriter
87
  input_img_list = []
88
  vidreader = VideoReader(args.input_path)
@@ -100,7 +99,7 @@ if __name__ == '__main__':
100
  if args.input_path.endswith('/'): # solve when path ends with /
101
  args.input_path = args.input_path[:-1]
102
  # scan all the jpg and png images
103
- input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jp][pn]g')))
104
  result_root = f'results/{os.path.basename(args.input_path)}_{w}'
105
 
106
  if not args.output_path is None: # set output path
 
6
  from torchvision.transforms.functional import normalize
7
  from basicsr.utils import imwrite, img2tensor, tensor2img
8
  from basicsr.utils.download_util import load_file_from_url
9
+ from basicsr.utils.misc import gpu_is_available, get_device
10
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
11
  from facelib.utils.misc import is_gray
 
12
 
13
  from basicsr.utils.registry import ARCH_REGISTRY
14
 
 
19
  def set_realesrgan():
20
  from basicsr.archs.rrdbnet_arch import RRDBNet
21
  from basicsr.utils.realesrgan_utils import RealESRGANer
22
+
 
 
23
  model = RRDBNet(
24
  num_in_ch=3,
25
  num_out_ch=3,
 
35
  tile=args.bg_tile,
36
  tile_pad=40,
37
  pre_pad=0,
38
+ half=torch.cuda.is_available(), # need to set False in CPU/MPS mode
39
  )
40
 
41
+ if not gpu_is_available(): # CPU
42
  import warnings
43
  warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
44
  'The unoptimized RealESRGAN is slow on CPU. '
 
47
  return upsampler
48
 
49
  if __name__ == '__main__':
50
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+ device = get_device()
52
  parser = argparse.ArgumentParser()
53
 
54
  parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
 
78
  # ------------------------ input & output ------------------------
79
  w = args.fidelity_weight
80
  input_video = False
81
+ if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
82
  input_img_list = [args.input_path]
83
  result_root = f'results/test_img_{w}'
84
+ elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
85
  from basicsr.utils.video_util import VideoReader, VideoWriter
86
  input_img_list = []
87
  vidreader = VideoReader(args.input_path)
 
99
  if args.input_path.endswith('/'): # solve when path ends with /
100
  args.input_path = args.input_path[:-1]
101
  # scan all the jpg and png images
102
+ input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
103
  result_root = f'results/{os.path.basename(args.input_path)}_{w}'
104
 
105
  if not args.output_path is None: # set output path
web-demos/hugging_face/app.py CHANGED
@@ -13,15 +13,16 @@ import gradio as gr
13
 
14
  from torchvision.transforms.functional import normalize
15
 
 
16
  from basicsr.utils import imwrite, img2tensor, tensor2img
17
  from basicsr.utils.download_util import load_file_from_url
18
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
19
- from facelib.utils.misc import is_gray
20
- from basicsr.archs.rrdbnet_arch import RRDBNet
21
  from basicsr.utils.realesrgan_utils import RealESRGANer
22
-
23
  from basicsr.utils.registry import ARCH_REGISTRY
24
 
 
 
 
25
 
26
  os.system("pip freeze")
27
 
@@ -65,7 +66,8 @@ def imread(img_path):
65
 
66
  # set enhancer with RealESRGAN
67
  def set_realesrgan():
68
- half = True if torch.cuda.is_available() else False
 
69
  model = RRDBNet(
70
  num_in_ch=3,
71
  num_out_ch=3,
@@ -86,7 +88,8 @@ def set_realesrgan():
86
  return upsampler
87
 
88
  upsampler = set_realesrgan()
89
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
90
  codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
91
  dim_embd=512,
92
  codebook_size=1024,
 
13
 
14
  from torchvision.transforms.functional import normalize
15
 
16
+ from basicsr.archs.rrdbnet_arch import RRDBNet
17
  from basicsr.utils import imwrite, img2tensor, tensor2img
18
  from basicsr.utils.download_util import load_file_from_url
19
+ from basicsr.utils.misc import gpu_is_available, get_device
 
 
20
  from basicsr.utils.realesrgan_utils import RealESRGANer
 
21
  from basicsr.utils.registry import ARCH_REGISTRY
22
 
23
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
24
+ from facelib.utils.misc import is_gray
25
+
26
 
27
  os.system("pip freeze")
28
 
 
66
 
67
  # set enhancer with RealESRGAN
68
  def set_realesrgan():
69
+ # half = True if torch.cuda.is_available() else False
70
+ half = True if gpu_is_available() else False
71
  model = RRDBNet(
72
  num_in_ch=3,
73
  num_out_ch=3,
 
88
  return upsampler
89
 
90
  upsampler = set_realesrgan()
91
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
+ device = get_device()
93
  codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
94
  dim_embd=512,
95
  codebook_size=1024,
web-demos/replicate/predict.py CHANGED
@@ -14,12 +14,13 @@ try:
14
  except Exception:
15
  print('please install cog package')
16
 
17
- from basicsr.utils import imwrite, img2tensor, tensor2img
18
  from basicsr.archs.rrdbnet_arch import RRDBNet
 
19
  from basicsr.utils.realesrgan_utils import RealESRGANer
 
20
  from basicsr.utils.registry import ARCH_REGISTRY
21
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
22
 
 
23
 
24
  class Predictor(BasePredictor):
25
  def setup(self):
@@ -159,7 +160,8 @@ def imread(img_path):
159
 
160
 
161
  def set_realesrgan():
162
- if not torch.cuda.is_available(): # CPU
 
163
  import warnings
164
 
165
  warnings.warn(
 
14
  except Exception:
15
  print('please install cog package')
16
 
 
17
  from basicsr.archs.rrdbnet_arch import RRDBNet
18
+ from basicsr.utils import imwrite, img2tensor, tensor2img
19
  from basicsr.utils.realesrgan_utils import RealESRGANer
20
+ from basicsr.utils.misc import gpu_is_available
21
  from basicsr.utils.registry import ARCH_REGISTRY
 
22
 
23
+ from facelib.utils.face_restoration_helper import FaceRestoreHelper
24
 
25
  class Predictor(BasePredictor):
26
  def setup(self):
 
160
 
161
 
162
  def set_realesrgan():
163
+ # if not torch.cuda.is_available(): # CPU
164
+ if not gpu_is_available(): # CPU
165
  import warnings
166
 
167
  warnings.warn(