victorisgeek commited on
Commit
98b2f02
1 Parent(s): bcea725

Upload 4 files

Browse files
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .download import download_file
2
+ from .weights_urls import get_model_url
utils/download.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import zipfile
4
+
5
+ from tqdm import tqdm
6
+
7
+
8
+ def download_file(url: str, save_dir='./', overwrite=False, unzip=True):
9
+ os.makedirs(save_dir, exist_ok=True)
10
+ file_name = url.split('/')[-1]
11
+ file_path = os.path.join(save_dir, file_name)
12
+
13
+ if os.path.exists(file_path) and not overwrite:
14
+ pass
15
+ else:
16
+ print('Downloading file {} from {}...'.format(file_path, url))
17
+
18
+ r = requests.get(url, stream=True)
19
+ print(r.status_code)
20
+ if r.status_code != 200:
21
+ raise RuntimeError('Failed downloading url {}!'.format(url))
22
+ total_length = r.headers.get('content-length')
23
+ with open(file_path, 'wb') as f:
24
+ if total_length is None: # no content length header
25
+ for chunk in r.iter_content(chunk_size=1024):
26
+ if chunk: # filter out keep-alive new chunks
27
+ f.write(chunk)
28
+ else:
29
+ total_length = int(total_length)
30
+ print('file length: ', int(total_length / 1024. + 0.5))
31
+ for chunk in tqdm(r.iter_content(chunk_size=1024),
32
+ total=int(total_length / 1024. + 0.5),
33
+ unit='KB',
34
+ unit_scale=False,
35
+ dynamic_ncols=True):
36
+ f.write(chunk)
37
+ if unzip and file_path.endswith('.zip'):
38
+ save_dir = file_path.split('.')[0]
39
+ if os.path.isdir(save_dir) and os.path.exists(save_dir):
40
+ pass
41
+ else:
42
+ with zipfile.ZipFile(file_path, 'r') as zip_ref:
43
+ zip_ref.extractall(save_dir)
44
+
45
+ return save_dir, file_path
utils/utils.py ADDED
File without changes
utils/weights_urls.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ WEIGHT_URLS = {
2
+ 'buffalo_l':
3
+ 'https://github.com/justld/dofaker/releases/download/v0.1/buffalo_l.zip',
4
+ 'buffalo_s':
5
+ 'https://github.com/justld/dofaker/releases/download/v0.1/buffalo_s.zip',
6
+ 'buffalo_sc':
7
+ 'https://github.com/justld/dofaker/releases/download/v0.1/buffalo_sc.zip',
8
+ 'inswapper':
9
+ 'https://github.com/justld/dofaker/releases/download/v0.1/inswapper_128.onnx',
10
+ 'gfpgan':
11
+ 'https://github.com/justld/dofaker/releases/download/v0.1/GFPGANv1.3.onnx',
12
+ 'bsrgan':
13
+ 'https://github.com/justld/dofaker/releases/download/v0.1/bsrgan_4.onnx',
14
+ 'openpose_body':
15
+ 'https://github.com/justld/dofaker/releases/download/v0.1/openpose_body.onnx',
16
+ 'pose_transfer':
17
+ 'https://github.com/justld/dofaker/releases/download/v0.1/pose_transfer.onnx',
18
+ }
19
+
20
+
21
+ def get_model_url(model_name):
22
+ return WEIGHT_URLS[model_name]