from __future__ import print_function

import os
import tarfile
from os.path import abspath, basename, isdir, join
from warnings import warn
from zipfile import ZipFile

import requests
from bs4 import BeautifulSoup


class GetData(object):
    """Download and unpack CycleGAN or pix2pix datasets.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix' (case-insensitive).
        verbose (bool)  -- If True, print additional information.

    Raises:
        ValueError -- if `technique` is not one of the supported names.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    Alternatively, you can use bash scripts: 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    # Archive types that can be listed and unpacked. Shared by the option
    # filter and the downloader so the two can never disagree.
    _ARCHIVE_EXTENSIONS = ('.zip', '.tar.gz')

    # Number of bytes written to disk per streamed chunk (1 MiB).
    _CHUNK_SIZE = 1 << 20

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        key = technique.lower()
        if key not in url_dict:
            # Fail here rather than later inside `requests` with a None URL.
            raise ValueError(
                "Unknown technique: '{0}'. Expected one of: {1}.".format(
                    technique, ', '.join(sorted(url_dict))))
        self.url = url_dict[key]
        self._verbose = verbose

    def _print(self, text):
        """Print `text` only when the instance was created with verbose=True."""
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Extract downloadable archive names from a directory-listing response.

        Parameters:
            r -- a response object whose `.text` attribute holds the HTML listing.

        Returns:
            options (list) -- the text of every link ending in a supported
                              archive extension, in page order.
        """
        soup = BeautifulSoup(r.text, 'lxml')
        return [link.text for link in soup.find_all('a', href=True)
                if link.text.endswith(GetData._ARCHIVE_EXTENSIONS)]

    def _present_options(self):
        """List the available datasets and ask the user to pick one by number.

        Returns:
            option (str) -- the archive name the user selected.

        Raises:
            requests.HTTPError -- if the listing page cannot be fetched.
            ValueError         -- if no datasets are listed, or the input is
                                  not an integer.
            IndexError         -- if the number is outside the listed range.
        """
        r = requests.get(self.url)
        r.raise_for_status()
        options = self._get_options(r)
        if not options:
            raise ValueError(
                "No downloadable datasets were found at {0}.".format(self.url))

        # Always printed (not routed through _print): the prompt below is
        # meaningless unless the user can see the numbered list.
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        index = int(choice)
        # Reject negatives explicitly; Python indexing would silently wrap them.
        if not 0 <= index < len(options):
            raise IndexError(
                "Choice {0} is out of range; expected a number between "
                "0 and {1}.".format(index, len(options) - 1))
        return options[index]

    def _download_data(self, dataset_url, save_path):
        """Download an archive into `save_path`, unpack it there, then delete it.

        Parameters:
            dataset_url (str) -- full URL of a '.zip' or '.tar.gz' archive.
            save_path (str)   -- directory to unpack into (created if missing).

        Raises:
            ValueError         -- if the URL does not end in a supported extension.
            requests.HTTPError -- if the server answers with an error status.
        """
        base = basename(dataset_url)
        # Validate before any network or disk work so an unsupported file is
        # never downloaded and left behind.
        if not base.endswith(self._ARCHIVE_EXTENSIONS):
            raise ValueError("Unknown File Type: {0}.".format(base))

        if not isdir(save_path):
            os.makedirs(save_path)

        temp_save_path = join(save_path, base)
        try:
            # Stream to disk in chunks so the whole archive is never held
            # in memory.
            r = requests.get(dataset_url, stream=True)
            try:
                r.raise_for_status()
                with open(temp_save_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=self._CHUNK_SIZE):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            finally:
                r.close()

            self._print("Unpacking Data...")
            if base.endswith('.tar.gz'):
                obj = tarfile.open(temp_save_path)
            else:
                obj = ZipFile(temp_save_path, 'r')
            try:
                # NOTE(security): extractall() trusts the member paths stored
                # in the archive. Only point this at trusted dataset hosts.
                obj.extractall(save_path)
            finally:
                obj.close()
        finally:
            # Remove the temporary archive on success and on failure alike.
            if os.path.exists(temp_save_path):
                os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """
        Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                            Note: this must include the file extension.
                            If None, options will be presented for you
                            to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data,
                                    computed as `save_path` joined with the
                                    dataset name up to its first '.'.
        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            # One of the base URLs already ends in '/'; strip it so the join
            # does not produce a double slash.
            url = "{0}/{1}".format(self.url.rstrip('/'), selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)