|
|
|
|
|
import tarfile |
|
from pathlib import Path |
|
from warnings import warn |
|
from zipfile import ZipFile |
|
|
|
import requests |
|
from bs4 import BeautifulSoup |
|
|
|
|
|
class GetData(object):
    """Download and unpack CycleGAN or pix2pix datasets.

    Constructing an instance immediately calls :meth:`get`. Unless ``dataset``
    is given, the archives available on the server are listed and the user is
    prompted to pick one interactively.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix' (case-insensitive).
        save_path (str) -- Directory the dataset is downloaded and unpacked into.
        verbose (bool)  -- If True, print additional information.
        dataset (str)   -- Optional archive name ending in '.zip' or '.tar.gz';
                           when given, the interactive prompt is skipped.

    Raises:
        ValueError -- if ``technique`` is not one of the known techniques.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan', save_path='./datasets')  # options will be displayed.

    Alternatively, you can use the bash scripts 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    # Archive suffixes that _download_data knows how to unpack.
    _ARCHIVE_SUFFIXES = (".tar.gz", ".zip")

    def __init__(
        self, technique="CycleGAN", save_path="./datasets", verbose=True, dataset=None
    ):
        url_dict = {
            "pix2pix": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/",
            "cyclegan": "http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/",
        }
        self.url = url_dict.get(technique.lower())
        if self.url is None:
            # Fail fast: a None URL would otherwise surface later as an obscure
            # error from requests.get().
            raise ValueError(
                "Unknown technique {0!r}; expected one of: {1}.".format(
                    technique, ", ".join(sorted(url_dict))
                )
            )
        self._verbose = verbose
        self.get(save_path=save_path, dataset=dataset)

    def _print(self, text: str):
        """Print ``text`` only when the instance was created with verbose=True."""
        if self._verbose:
            print(text)

    @staticmethod
    def _strip_archive_suffix(name: str) -> str:
        """Return ``name`` without its archive suffix ('a.b.tar.gz' -> 'a.b').

        Names without a recognised archive suffix fall back to everything
        before the first '.'.
        """
        for suffix in GetData._ARCHIVE_SUFFIXES:
            if name.endswith(suffix):
                return name[: -len(suffix)]
        return name.split(".")[0]

    @staticmethod
    def _get_options(r):
        """Return the archive names linked from the HTML index page in ``r``.

        Parameters:
            r -- a ``requests`` response whose ``.text`` is the index page HTML.
        """
        soup = BeautifulSoup(r.text, "lxml")
        options = []
        for anchor in soup.find_all("a", href=True):
            # Take the name from the href rather than the visible link text:
            # directory listings may truncate long file names in the text.
            name = anchor["href"].rstrip("/").rsplit("/", 1)[-1]
            if name.endswith(GetData._ARCHIVE_SUFFIXES):
                options.append(name)
        return options

    def _present_options(self):
        """List the archives available at ``self.url`` and ask the user to pick one.

        Returns:
            The chosen archive name.

        Raises:
            requests.HTTPError -- if the index page cannot be fetched.
            RuntimeError -- if the index page links to no archives.
        """
        self._print(self.url)
        r = requests.get(self.url)
        r.raise_for_status()
        options = self._get_options(r)
        if not options:
            raise RuntimeError("No datasets found at {0}.".format(self.url))
        print("Options:\n")
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        # Re-prompt instead of crashing on non-numeric or out-of-range input.
        while True:
            choice = input(
                "\nPlease enter the number of the " "dataset above you wish to download:"
            )
            try:
                index = int(choice)
            except ValueError:
                index = -1
            if 0 <= index < len(options):
                return options[index]
            print("Invalid choice: {0!r}".format(choice))

    def _download_data(self, dataset_url: str, dataset_path: Path):
        """Download the archive at ``dataset_url`` and unpack it into ``dataset_path``.

        The archive file is saved inside ``dataset_path`` under its URL basename
        and left in place after unpacking.

        Raises:
            ValueError -- if the archive is neither '.tar.gz' nor '.zip'.
        """
        import urllib.parse
        import urllib.request

        dataset_path = Path(dataset_path)
        dataset_path.mkdir(parents=True, exist_ok=True)

        # Derive the file name from the URL path; Path() is not reliable for URLs.
        archive_name = urllib.parse.urlsplit(dataset_url).path.rsplit("/", 1)[-1]
        archive_path = dataset_path / archive_name

        self._print(dataset_url)
        urllib.request.urlretrieve(dataset_url, str(archive_path))
        # "下载完成" means "download complete".
        self._print("--> 下载完成 ")

        is_tar = archive_name.endswith(".tar.gz")
        if not is_tar and not archive_name.endswith(".zip"):
            raise ValueError("Unknown File Type: {0}.".format(archive_path))

        self._print("Unpacking Data...")
        # Extract into the destination directory, not onto the archive file itself.
        if is_tar:
            with tarfile.open(archive_path) as archive:
                # The archive arrives over plain http. Where supported (3.12+ and
                # security backports), the "data" filter rejects absolute paths,
                # '..' traversal and special files.
                if hasattr(tarfile, "data_filter"):
                    archive.extractall(dataset_path, filter="data")
                else:
                    archive.extractall(dataset_path)
        else:
            # ZipFile.extractall already sanitises absolute and '..' member names.
            with ZipFile(archive_path, "r") as archive:
                archive.extractall(dataset_path)

    def get(self, save_path: str, dataset=None):
        """Download and unpack a dataset unless it is already present.

        Parameters:
            save_path (str) -- Directory to save the dataset in.
            dataset (str)   -- Archive name to fetch (e.g. ending in '.zip' or
                               '.tar.gz'). If None, the user is prompted to choose.

        Returns:
            Path to the dataset directory (``save_path/<archive stem>``).
            If that directory already exists, a warning is issued and nothing
            is downloaded.
        """
        save_root = Path(save_path)
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = save_root.joinpath(self._strip_archive_suffix(selected_dataset))
        self._print(str(save_path_full))

        if save_path_full.is_dir():
            warn("\n'{0}' already exists.".format(save_path_full))
        else:
            self._print("Downloading Data...")
            # Base URLs may or may not end with '/'; avoid a double slash.
            url = "{0}/{1}".format(self.url.rstrip("/"), selected_dataset)
            self._download_data(url, dataset_path=save_root)
        return save_path_full
|
|