Mountchicken's picture
Upload 704 files
9bf4bd7
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import shutil
import ssl
import urllib.request as request
from typing import Dict, List, Optional, Tuple
from mmengine import mkdir_or_exist
from mmocr.registry import DATA_OBTAINERS
from mmocr.utils import check_integrity, is_archive
# NOTE(review): this module-level assignment replaces the ssl module's default
# HTTPS context factory with one that does NOT verify certificates. It takes
# effect as soon as this module is imported and applies process-wide to any
# stdlib HTTPS client that relies on the default context (e.g. urllib), not
# only to the dataset downloads in this file. Presumably a workaround for
# dataset hosts with broken certificates — confirm it is still needed, since
# it also leaves downloaded archives open to tampering in transit.
ssl._create_default_https_context = ssl._create_unverified_context
@DATA_OBTAINERS.register_module()
class NaiveDataObtainer:
    """A naive pipeline for obtaining a dataset.

    For every entry of ``files`` the pipeline runs
    download -> extract -> move, and finally removes the empty directories
    left under ``data_root``.

    Args:
        files (list[dict]): A list of file information. Each dict is read
            for the optional keys ``url``, ``save_name`` and ``md5``, plus an
            optional ``mapping`` (list of ``(src, dst)`` pairs relative to
            ``data_root``) passed to :meth:`move`.
        cache_path (str): The path to cache the downloaded files.
        data_root (str): The root path of the dataset. It is usually set
            automatically and users do not need to set it manually in the
            config file in most cases.
        task (str): The task of the dataset. It is usually set automatically
            and users do not need to set it manually in the config file in
            most cases.
    """

    def __init__(self, files: List[Dict], cache_path: str, data_root: str,
                 task: str) -> None:
        self.files = files
        self.cache_path = cache_path
        self.data_root = data_root
        self.task = task
        # Create the output layout and the download cache up front.
        mkdir_or_exist(self.data_root)
        mkdir_or_exist(osp.join(self.data_root, f'{task}_imgs'))
        mkdir_or_exist(osp.join(self.data_root, 'annotations'))
        mkdir_or_exist(self.cache_path)

    def __call__(self) -> None:
        """Download, extract and move every file, then remove empty dirs."""
        for file in self.files:
            save_name = file.get('save_name', None)
            url = file.get('url', None)
            md5 = file.get('md5', None)
            download_path = osp.join(
                self.cache_path,
                osp.basename(url) if save_name is None else save_name)
            # Only (re-)download when the cached copy fails the integrity
            # check.
            if not check_integrity(download_path, md5):
                self.download(url=url, dst_path=download_path)
            # Extract downloaded archives (or copy plain files) to data root.
            self.extract(src_path=download_path, dst_path=self.data_root)
            # Move & rename dataset files.
            if 'mapping' in file:
                self.move(mapping=file['mapping'])
        self.clean()

    def download(self, url: Optional[str], dst_path: str) -> None:
        """Download file from given url with progress bar.

        Args:
            url (str, optional): The url to download the file. If None, the
                file must already exist at ``dst_path`` (e.g. downloaded
                manually), and that existing file is used as is.
            dst_path (str): The destination path to save the file.

        Raises:
            FileNotFoundError: If ``url`` is None and ``dst_path`` does not
                exist.
            NotImplementedError: If ``url`` is a magnet link.
        """
        file_name = osp.basename(dst_path)

        def progress(down: int, block: int, size: int) -> None:
            """Show download progress (``urlretrieve`` reporthook).

            Args:
                down (int): Number of blocks transferred so far.
                block (int): Block size in bytes.
                size (int): Total size of the file in bytes; non-positive
                    when the server did not report it.
            """
            downloaded = down * block
            if size > 0:
                percent = min(100. * downloaded / size, 100)
                print(f'\rDownloading {file_name}: {percent:.2f}%', end='')
            else:
                # Total size unknown: report bytes instead of dividing by a
                # non-positive size (negative % or ZeroDivisionError).
                print(f'\rDownloading {file_name}: {downloaded} bytes',
                      end='')

        if url is None:
            if not osp.exists(dst_path):
                raise FileNotFoundError(
                    'Direct url is not available for this dataset.'
                    ' Please manually download the required files'
                    ' following the guides.')
            # Nothing to fetch a fresh copy from: keep the file that is
            # already in place rather than crashing on ``None.startswith``.
            print(f'No url is available to download {osp.abspath(dst_path)}'
                  '; using the existing file as is.')
            return

        if url.startswith('magnet'):
            raise NotImplementedError('Please use any BitTorrent client to '
                                      'download the following magnet link to '
                                      f'{osp.abspath(dst_path)} and '
                                      f'try again.\nLink: {url}')
        print('Downloading...')
        print(f'URL: {url}')
        print(f'Destination: {osp.abspath(dst_path)}')
        print('If you stuck here for a long time, please check your network, '
              'or manually download the file to the destination path and '
              'run the script again.')
        request.urlretrieve(url, dst_path, progress)
        print('')

    def extract(self,
                src_path: str,
                dst_path: Optional[str],
                delete: bool = False) -> None:
        """Extract zip/tar/tar.gz files.

        Non-archive inputs are copied into ``dst_path`` instead. An archive
        is extracted into ``<dst_path>/<archive name up to its first dot>``.
        A ``.finish`` marker is written there afterwards so that later runs
        can skip it.

        Args:
            src_path (str): Path to the archive (or plain file/directory).
            dst_path (str, optional): Path to the destination folder. If
                None, the archive is extracted next to ``src_path``.
            delete (bool, optional): Whether to delete the archive after
                extraction. Defaults to False.
        """
        if not is_archive(src_path):
            # Copy the file to the destination folder if it is not an archive.
            if osp.isfile(src_path):
                shutil.copy(src_path, dst_path)
            else:
                # NOTE(review): copytree fails if ``dst_path`` already exists.
                # __call__ passes ``data_root``, which __init__ creates, so
                # this branch looks unusable from there — confirm intent.
                shutil.copytree(src_path, dst_path)
            return

        zip_name = osp.basename(src_path).split('.')[0]
        if dst_path is None:
            dst_path = osp.join(osp.dirname(src_path), zip_name)
        else:
            dst_path = osp.join(dst_path, zip_name)

        if self._is_extracted(dst_path, zip_name):
            # Record the decision so the prompt is not shown again.
            open(osp.join(dst_path, '.finish'), 'w').close()
            print(f'{zip_name} has been extracted. Skip')
            return

        mkdir_or_exist(dst_path)
        print(f'Extracting: {osp.basename(src_path)}')
        if src_path.endswith('.zip'):
            import zipfile
            with zipfile.ZipFile(src_path, 'r') as zip_ref:
                zip_ref.extractall(dst_path)
        elif src_path.endswith(('.tar.gz', '.tar')):
            import tarfile
            mode = 'r:gz' if src_path.endswith('.tar.gz') else 'r:'
            with tarfile.open(src_path, mode) as tar_ref:
                self._safe_extract_tar(tar_ref, dst_path)
        open(osp.join(dst_path, '.finish'), 'w').close()
        if delete:
            os.remove(src_path)

    @staticmethod
    def _is_extracted(dst_path: str, zip_name: str) -> bool:
        """Return whether extraction into ``dst_path`` can be skipped.

        Extraction is skipped when a ``.finish`` marker exists. It is also
        skipped when the folder is non-empty and the user declines to unzip
        again, which is the default answer.
        """
        if not osp.exists(dst_path):
            return False
        names = set(os.listdir(dst_path))
        if '.finish' in names:
            return True
        if not names:
            return False
        while True:
            answer = input(f'{dst_path} already exists when extracting '
                           f'{zip_name}, unzip again? (y/N) ') or 'N'
            if answer.lower() in ('y', 'n'):
                # Compare case-insensitively so the default 'N' means "skip".
                return answer.lower() == 'n'

    @staticmethod
    def _safe_extract_tar(tar_ref, dst_path: str) -> None:
        """Extract all members of an open tar archive into ``dst_path``.

        Archives come from the network. Members whose paths would resolve
        outside ``dst_path`` (absolute names, ``..`` components) are
        therefore rejected instead of being written.
        """
        import tarfile
        if hasattr(tarfile, 'data_filter'):
            # PEP 706 extraction filter (Python 3.12+, backported to some
            # earlier patch releases). It also rejects unsafe links/modes.
            tar_ref.extractall(dst_path, filter='data')
            return
        root = osp.realpath(dst_path)
        for member in tar_ref.getmembers():
            target = osp.realpath(osp.join(root, member.name))
            if target != root and not target.startswith(root + os.sep):
                raise RuntimeError(
                    f'Refusing to extract {member.name!r}: it would be '
                    f'written outside {dst_path}.')
        tar_ref.extractall(dst_path)

    def move(self, mapping: List[Tuple[str, str]]) -> None:
        """Rename and move dataset files one by one.

        Both paths of each pair are relative to ``data_root``.

        - If ``src`` contains ``*``, it is treated as a glob pattern. Every
          match is moved into the directory ``dst`` (created if needed), and
          matches whose name already exists in ``dst`` are left in place.
        - Otherwise ``src`` is moved to ``dst`` only when ``src`` exists and
          ``dst`` does not.

        Args:
            mapping (List[Tuple[str, str]]): A list of tuples, each
                tuple contains the source file name and the destination file
                name.
        """
        for src, dst in mapping:
            src = osp.join(self.data_root, src)
            dst = osp.join(self.data_root, dst)
            if '*' in src:
                mkdir_or_exist(dst)
                for f in glob.glob(src):
                    # shutil.move(f, dst) lands at dst/<basename(f)> and
                    # raises if that path already exists, so check exactly
                    # that path (normpath drops a trailing separator, as
                    # shutil does).
                    target = osp.join(dst, osp.basename(osp.normpath(f)))
                    if not osp.exists(target):
                        shutil.move(f, dst)
            elif osp.exists(src) and not osp.exists(dst):
                mkdir_or_exist(osp.dirname(dst))
                shutil.move(src, dst)

    def clean(self) -> None:
        """Remove empty directories under ``data_root``.

        The tree is walked bottom-up, and each directory's emptiness is
        checked on disk when it is visited. A directory whose only children
        were empty directories removed earlier in the same walk is therefore
        removed too. ``data_root`` itself is kept.
        """
        for root, _, _ in os.walk(self.data_root, topdown=False):
            # The dirnames os.walk reports for ``root`` were listed before its
            # children were visited (and possibly removed), so re-check.
            if root != self.data_root and not os.listdir(root):
                os.rmdir(root)