danielsapit commited on
Commit
7b0d035
1 Parent(s): 3cf845b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -9,6 +9,14 @@ import utils_image as util
9
  from network_fbcnn import FBCNN as net
10
  import requests
11
 
 
 
 
 
 
 
 
 
12
 
13
  def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_shift, state):
14
 
@@ -27,13 +35,14 @@ def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_s
27
  #model_pool = '/FBCNN/model_zoo' # fixed
28
  #model_path = os.path.join(model_pool, model_name)
29
  model_path = model_name
 
30
  if os.path.exists(model_path):
31
  print(f'loading model from {model_path}')
32
  else:
33
  os.makedirs(os.path.dirname(model_path), exist_ok=True)
34
  url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
35
  r = requests.get(url, allow_redirects=True)
36
- open(model_path, 'wb').write(r.content)
37
 
38
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
 
9
  from network_fbcnn import FBCNN as net
10
  import requests
11
 
12
+ for model_path in ['fbcnn_gray.pth','fbcnn_color.pth']:
13
+ if os.path.exists(path):
14
+ print(f'loading model from {model_path}')
15
+ else:
16
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
17
+ url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
18
+ r = requests.get(url, allow_redirects=True)
19
+ open(model_path, 'wb').write(r.content)
20
 
21
  def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_shift, state):
22
 
 
35
  #model_pool = '/FBCNN/model_zoo' # fixed
36
  #model_path = os.path.join(model_pool, model_name)
37
  model_path = model_name
38
+ """
39
  if os.path.exists(model_path):
40
  print(f'loading model from {model_path}')
41
  else:
42
  os.makedirs(os.path.dirname(model_path), exist_ok=True)
43
  url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
44
  r = requests.get(url, allow_redirects=True)
45
+ open(model_path, 'wb').write(r.content)"""
46
 
47
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48