Xintao commited on
Commit
201489f
1 Parent(s): 8f15d2c

force use cuda

Browse files
Files changed (2) hide show
  1. gfpgan_utils.py +1 -1
  2. realesrgan_utils.py +4 -3
gfpgan_utils.py CHANGED
@@ -33,7 +33,7 @@ class GFPGANer():
33
  self.bg_upsampler = bg_upsampler
34
 
35
  # initialize model
36
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
37
  # initialize the GFP-GAN
38
  self.gfpgan = GFPGANv1Clean(
39
  out_size=512,
 
33
  self.bg_upsampler = bg_upsampler
34
 
35
  # initialize model
36
+ self.device = torch.device('cuda') if device is None else device
37
  # initialize the GFP-GAN
38
  self.gfpgan = GFPGANv1Clean(
39
  out_size=512,
realesrgan_utils.py CHANGED
@@ -1,9 +1,10 @@
1
- import cv2
2
  import math
3
- import numpy as np
4
  import os
5
  import queue
6
  import threading
 
 
 
7
  import torch
8
  from basicsr.utils.download_util import load_file_from_url
9
  from torch.nn import functional as F
@@ -35,7 +36,7 @@ class RealESRGANer():
35
  self.half = half
36
 
37
  # initialize model
38
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
  # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
40
  if model_path.startswith('https://'):
41
  model_path = load_file_from_url(
 
 
1
  import math
 
2
  import os
3
  import queue
4
  import threading
5
+
6
+ import cv2
7
+ import numpy as np
8
  import torch
9
  from basicsr.utils.download_util import load_file_from_url
10
  from torch.nn import functional as F
 
36
  self.half = half
37
 
38
  # initialize model
39
+ self.device = torch.device('cuda')
40
  # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
41
  if model_path.startswith('https://'):
42
  model_path = load_file_from_url(