# CycleGAN / util/get_data.py
# Source: Hugging Face file view (user Yanguan, commit 58da73e, 3.12 kB, virus scan: clean).
# from __future__ import print_function
import tarfile
from pathlib import Path
from warnings import warn
from zipfile import ZipFile
import requests
from bs4 import BeautifulSoup
class GetData:
    """Download and unpack a CycleGAN or pix2pix dataset archive.

    Constructing an instance immediately calls :meth:`get`, which (because no
    dataset name is passed) lists the archives found on the dataset server's
    index page and interactively prompts for the one to download.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix' (case-insensitive).
        save_path (str) -- Directory the dataset is downloaded and extracted into.
        verbose (bool)  -- If True, print additional progress information.

    Raises:
        ValueError -- if ``technique`` is not one of the known sources.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan', save_path='./datasets')  # options will be displayed.

    Alternatively, you can use the bash scripts 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    # Directory-index pages listing the downloadable archives per technique.
    _URLS = {
        "cyclegan": "http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/",
        "pix2pix": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/",
    }
    # Archive suffixes accepted both when listing options and when unpacking.
    _ARCHIVE_SUFFIXES = (".tar.gz", ".zip")
    # Seconds requests waits for the connection / each read before giving up.
    _REQUEST_TIMEOUT = 60

    def __init__(self, technique="CycleGAN", save_path="./datasets", verbose=True):
        self.url = self._URLS.get(technique.lower())
        if self.url is None:
            # Fail fast instead of later calling requests.get(None).
            raise ValueError(
                "Unknown technique {0!r}; expected one of: {1}.".format(
                    technique, ", ".join(sorted(self._URLS))
                )
            )
        self._verbose = verbose
        self.get(save_path=save_path)

    def _print(self, text: str):
        """Print ``text`` only when the instance was created with verbose=True."""
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Return the archive file names linked from a directory-index page.

        ``r`` is an HTTP response whose ``.text`` holds the index HTML. The
        ``href`` is used rather than the link text because index pages may
        truncate long names in the visible text.
        """
        # "html.parser" ships with Python; "lxml" would be an extra dependency.
        soup = BeautifulSoup(r.text, "html.parser")
        options = []
        for anchor in soup.find_all("a", href=True):
            name = anchor["href"].rsplit("/", 1)[-1]
            if name.endswith(GetData._ARCHIVE_SUFFIXES):
                options.append(name)
        return options

    def _present_options(self):
        """List the archives on the dataset server and prompt for one.

        Returns:
            The chosen archive's file name as listed on the index page.

        Raises:
            RuntimeError -- if the index page links to no archives.
            requests.HTTPError -- if the index page cannot be fetched.
        """
        self._print(self.url)
        r = requests.get(self.url, timeout=self._REQUEST_TIMEOUT)
        r.raise_for_status()
        options = self._get_options(r)
        if not options:
            raise RuntimeError("No dataset archives found at {0}.".format(self.url))
        print("Options:\n")
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        # Re-prompt until the answer is a valid index; int() alone would accept
        # negative numbers (silently picking from the end) or crash on text.
        while True:
            choice = input(
                "\nPlease enter the number of the " "dataset above you wish to download:"
            )
            try:
                index = int(choice)
            except ValueError:
                index = -1
            if 0 <= index < len(options):
                return options[index]
            print(
                "Invalid choice {0!r}; enter a number between 0 and {1}.".format(
                    choice, len(options) - 1
                )
            )

    def _download_data(self, dataset_url: str, dataset_path: Path):
        """Download ``dataset_url`` into ``dataset_path`` and extract it there.

        The downloaded archive file is deleted after successful extraction.

        Raises:
            requests.HTTPError -- if the download request fails.
            ValueError -- if the file is not a .tar.gz or .zip archive.
        """
        dataset_path = Path(dataset_path)
        dataset_path.mkdir(parents=True, exist_ok=True)
        archive_path = dataset_path / dataset_url.rsplit("/", 1)[-1]
        self._print(dataset_url)
        # Stream to disk so large archives are not held fully in memory.
        with requests.get(
            dataset_url, stream=True, timeout=self._REQUEST_TIMEOUT
        ) as response:
            response.raise_for_status()
            with open(archive_path, "wb") as fh:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    fh.write(chunk)
        self._print("--> 下载完成 ")
        self._print("Unpacking Data...")
        self._unpack(archive_path, dataset_path)
        archive_path.unlink()

    @staticmethod
    def _unpack(archive_path, dest_dir):
        """Extract a ``.tar.gz`` or ``.zip`` archive into the directory ``dest_dir``.

        Raises:
            ValueError -- if the archive has any other extension.
        """
        name = Path(archive_path).name
        if name.endswith(".tar.gz"):
            with tarfile.open(archive_path) as archive:
                if hasattr(tarfile, "data_filter"):
                    # Reject absolute paths, '..' members and special files (PEP 706).
                    archive.extractall(dest_dir, filter="data")
                else:
                    archive.extractall(dest_dir)
        elif name.endswith(".zip"):
            with ZipFile(archive_path, "r") as archive:
                archive.extractall(dest_dir)
        else:
            raise ValueError("Unknown File Type: {0}.".format(archive_path))

    def get(self, save_path: str, dataset=None):
        """Download and extract a dataset unless its folder already exists.

        Parameters:
            save_path -- directory that holds the extracted datasets.
            dataset   -- archive file name on the server; if None, the user is
                         prompted to choose one from the server's listing.

        Returns:
            ``Path`` of ``save_path/<archive name up to its first '.'>``.
        """
        save_path_ = Path(save_path)
        selected_dataset = self._present_options() if dataset is None else dataset
        save_path_full = save_path_ / selected_dataset.split(".")[0]
        self._print(str(save_path_full))
        if save_path_full.is_dir():
            warn("\n'{0}' already exists.".format(save_path_full))
        else:
            self._print("Downloading Data...")
            # self.url already ends with '/'; strip it to avoid a '//' in the URL.
            url = "{0}/{1}".format(self.url.rstrip("/"), selected_dataset)
            self._download_data(url, save_path_)
        return save_path_full