"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import io
import json
import logging
import os
import pickle
import re
import shutil
import tarfile
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr

from lavis.common.dist_utils import download_cached_file
from lavis.common.registry import registry

from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)


def now():
    from datetime import datetime

    # NOTE: the trailing [:-1] drops the last digit of the minute,
    # yielding a timestamp with 10-minute resolution (e.g. "20221205123").
    return datetime.now().strftime("%Y%m%d%H%M")[:-1]


def get_cache_path(rel_path):
    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


def get_redirected_url(url: str):
    """
    Given a URL, return the URL it redirects to,
    or the original URL if there is no redirection.
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            if response.history:
                return response.url
            else:
                return url
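
# Illustrative usage (hypothetical URL; the function issues a GET and
# follows redirects, so it requires network access):
#
#   final_url = get_redirected_url("http://example.com/shortlink")
#   # -> the redirect target if the server redirects, otherwise the input URL.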


def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a Google Drive view URL
    into a direct download URL.
    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    splits = view_url.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from Google Drive.
    Downloading a URL from Google Drive requires confirmation when
    the file is too big (Google Drive notifies that antivirus checks
    cannot be performed on such files).
    """
    import requests

    with requests.Session() as session:

        # First request: look for a "download_warning" cookie, which carries
        # the confirmation token required for large files.
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Second request: perform the actual (confirmed) download.
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))
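
# Illustrative usage (hypothetical file id and paths):
#
#   download_google_drive_url(
#       "https://drive.google.com/uc?export=download&id=<FILE_ID>",
#       output_path="/tmp/downloads",
#       output_file_name="data.tar.gz",
#   )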


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")
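
# Illustrative behavior (hypothetical ids):
#
#   _get_google_drive_file_id("https://drive.google.com/file/d/abc123/view")
#   # -> "abc123"
#   _get_google_drive_file_id("https://example.com/file/d/abc123")
#   # -> None (not a Google Drive host)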


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # response.read() returns bytes, so the sentinel must be b"".
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    pbar.update(len(chunk))
                    fh.write(chunk)


def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # Check if the file is already present locally and verified.
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # Expand any redirect chain before inspecting the URL.
    url = get_redirected_url(url)

    # Check if the file is hosted on Google Drive.
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    # Download the file, falling back from https to http on failure.
    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:
        if url[:5] == "https":
            url = url.replace("https:", "http:")
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise e

    # Verify the checksum of the downloaded file.
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
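
# Illustrative usage (hypothetical URL and paths):
#
#   download_url(
#       "https://example.com/datasets/annotations.zip",
#       root="/tmp/data",
#       filename="annotations.zip",
#       md5=None,  # pass an MD5 string to verify the download
#   )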


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
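
# Illustrative usage (hypothetical URL): downloads into download_root,
# then extracts into extract_root (here the same directory):
#
#   download_and_extract_archive(
#       "https://example.com/datasets/images.tar.gz",
#       download_root="/tmp/data",
#       remove_finished=True,  # delete the archive after extraction
#   )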


def cache_url(url: str, cache_dir: str) -> str:
    """
    Download the remote resource at `url` and cache it locally under
    `cache_dir`, mirroring the URL's path structure. The resource is only
    downloaded if it has not been requested before. Returns the path to
    the cached file.
    """
    parsed_url = urlparse(url)
    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
    makedir(dirname)
    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        if not os.path.isfile(cached):
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached
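
# Illustrative usage (hypothetical URL): the first call downloads the
# file; later calls return the cached path without re-downloading:
#
#   cached_path = cache_url(
#       "https://example.com/annotations/train.json", cache_dir="/tmp/cache"
#   )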


def create_file_symlink(file1, file2):
    """
    Create a symlink at file2 pointing to file1.
    Useful during model checkpointing to symlink
    the latest successful checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json, .yaml
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing a Boolean value for append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")


def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
    For .npy files, we support reading in mmap_mode.
    If reading with mmap_mode fails, we load the data without it.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                # mmap through a g_pathmgr file handle failed; retry with
                # the raw path, then fall back to loading without mmap.
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                try:
                    data = np.load(
                        filename,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
                    logging.info("Successfully loaded without g_pathmgr")
                except Exception:
                    logging.info(
                        "Could not mmap without g_pathmgr. Trying without mmap"
                    )
                    with g_pathmgr.open(filename, "rb") as fopen:
                        data = np.load(
                            fopen, allow_pickle=allow_pickle, encoding="latin1"
                        )
            except Exception:
                logging.info("Could not mmap. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data
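
# Illustrative round trip (hypothetical path). Note that with
# append_to_json=True the file is JSON-lines, which json.load can only
# parse when it holds a single record:
#
#   save_file({"acc": 0.9}, "/tmp/metrics.json", append_to_json=False)
#   data = load_file("/tmp/metrics.json")  # -> {"acc": 0.9}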


def abspath(resource_path: str):
    """
    Make a path absolute, but take into account prefixes like
    "http://" or "manifold://"
    """
    regex = re.compile(r"^\w+://")
    if regex.match(resource_path) is None:
        return os.path.abspath(resource_path)
    else:
        return resource_path
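
# Illustrative behavior (hypothetical paths):
#
#   abspath("checkpoints/model.pth")   # -> absolute path under the cwd
#   abspath("manifold://bucket/path")  # -> unchanged (prefixed resource)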


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success


def is_url(input_url):
    """
    Check if an input string is a URL. Looks for http(s):// and
    ignores case.
    """
    is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
    return is_url
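
# Illustrative behavior:
#
#   is_url("HTTPS://example.com")  # -> True (case-insensitive)
#   is_url("ftp://example.com")    # -> False (only http/https)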


def download_and_untar(url):
    cached_file = download_cached_file(url, check_hash=False, progress=True)

    # Assumes the archive extracts into a directory named after the
    # archive basename (everything before the first ".").
    untarred_dir = os.path.basename(url).split(".")[0]
    parent_dir = os.path.dirname(cached_file)

    full_dir = os.path.join(parent_dir, untarred_dir)

    if not os.path.exists(full_dir):
        with tarfile.open(cached_file) as tar:
            tar.extractall(parent_dir)

    return full_dir
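
# Illustrative usage (hypothetical URL): for ".../vqa_data.tar.gz" the
# result is "<cache_dir>/vqa_data", assuming the archive extracts into a
# directory named after its basename:
#
#   data_dir = download_and_untar("https://example.com/files/vqa_data.tar.gz")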


def cleanup_dir(dir):
    """
    Utility for deleting a directory. Useful for cleaning the storage space
    that contains various training artifacts like checkpoints, data, etc.
    """
    if os.path.exists(dir):
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
        logging.info(f"Deleted contents of directory: {dir}")


def get_file_size(filename):
    """
    Given a file, return its size in MB.
    """
    size_in_mb = os.path.getsize(filename) / float(1024**2)
    return size_in_mb