miguelmuzo commited on
Commit
473c902
·
verified ·
1 Parent(s): c5a8345

Create drive.py

Browse files
Files changed (1) hide show
  1. utils/drive.py +110 -0
utils/drive.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # URL helpers, see https://github.com/NVlabs/stylegan
2
+ # ------------------------------------------------------------------------------------------
3
+
4
+ import requests
5
+ import html
6
+ import hashlib
7
+ import gdown
8
+ import glob
9
+ import os
10
+ import io
11
+ from typing import Any
12
+ import re
13
+ import uuid
14
+
15
+ weight_dic = {'afhqwild.pt': 'https://drive.google.com/file/d/14OnzO4QWaAytKXVqcfWo_o2MzoR4ygnr/view?usp=sharing',
16
+ 'afhqdog.pt': 'https://drive.google.com/file/d/16v6jPtKVlvq8rg2Sdi3-R9qZEVDgvvEA/view?usp=sharing',
17
+ 'afhqcat.pt': 'https://drive.google.com/file/d/1HXLER5R3EMI8DSYDBZafoqpX4EtyOf2R/view?usp=sharing',
18
+ 'ffhq.pt': 'https://drive.google.com/file/d/1AT6bNR2ppK8f2ETL_evT27f3R_oyWNHS/view?usp=sharing',
19
+ 'metfaces.pt': 'https://drive.google.com/file/d/16wM2PwVWzaMsRgPExvRGsq6BWw_muKbf/view?usp=sharing',
20
+ 'seg.pth': 'https://drive.google.com/file/d/1lIKvQaFKHT5zC7uS4p17O9ZpfwmwlS62/view?usp=sharing'}
21
+
22
+
23
+ def download_weight(weight_path):
24
+ gdown.download(weight_dic[os.path.basename(weight_path)],
25
+ output=weight_path, fuzzy=True)
26
+
27
+
28
+ def is_url(obj: Any) -> bool:
29
+ """Determine whether the given object is a valid URL string."""
30
+ if not isinstance(obj, str) or not "://" in obj:
31
+ return False
32
+ try:
33
+ res = requests.compat.urlparse(obj)
34
+ if not res.scheme or not res.netloc or not "." in res.netloc:
35
+ return False
36
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
37
+ if not res.scheme or not res.netloc or not "." in res.netloc:
38
+ return False
39
+ except:
40
+ return False
41
+ return True
42
+
43
+
44
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True,
45
+ return_path: bool = False) -> Any:
46
+ """Download the given URL and return a binary-mode file object to access the data."""
47
+ assert is_url(url)
48
+ assert num_attempts >= 1
49
+
50
+ # Lookup from cache.
51
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
52
+ if cache_dir is not None:
53
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
54
+ if len(cache_files) == 1:
55
+ if (return_path):
56
+ return cache_files[0]
57
+ else:
58
+ return open(cache_files[0], "rb")
59
+
60
+ # Download.
61
+ url_name = None
62
+ url_data = None
63
+ with requests.Session() as session:
64
+ if verbose:
65
+ print("Downloading %s ..." % url, end="", flush=True)
66
+ for attempts_left in reversed(range(num_attempts)):
67
+ try:
68
+ with session.get(url) as res:
69
+ res.raise_for_status()
70
+ if len(res.content) == 0:
71
+ raise IOError("No data received")
72
+
73
+ if len(res.content) < 8192:
74
+ content_str = res.content.decode("utf-8")
75
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
76
+ links = [html.unescape(link) for link in content_str.split('"') if
77
+ "export=download" in link]
78
+ if len(links) == 1:
79
+ url = requests.compat.urljoin(url, links[0])
80
+ raise IOError("Google Drive virus checker nag")
81
+ if "Google Drive - Quota exceeded" in content_str:
82
+ raise IOError("Google Drive quota exceeded")
83
+
84
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
85
+ url_name = match[1] if match else url
86
+ url_data = res.content
87
+ if verbose:
88
+ print(" done")
89
+ break
90
+ except:
91
+ if not attempts_left:
92
+ if verbose:
93
+ print(" failed")
94
+ raise
95
+ if verbose:
96
+ print(".", end="", flush=True)
97
+
98
+ # Save to cache.
99
+ if cache_dir is not None:
100
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
101
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
102
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
103
+ os.makedirs(cache_dir, exist_ok=True)
104
+ with open(temp_file, "wb") as f:
105
+ f.write(url_data)
106
+ os.replace(temp_file, cache_file) # atomic
107
+ if (return_path): return cache_file
108
+
109
+ # Return data as file object.
110
+ return io.BytesIO(url_data)