|
""" |
|
Copyright (c) 2022, salesforce.com, inc. |
|
All rights reserved. |
|
SPDX-License-Identifier: BSD-3-Clause |
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
""" |
|
|
|
import io |
|
import json |
|
import logging |
|
import os |
|
import pickle |
|
import re |
|
import shutil |
|
import urllib |
|
import urllib.error |
|
import urllib.request |
|
from typing import Optional |
|
from urllib.parse import urlparse |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import yaml |
|
from iopath.common.download import download |
|
from iopath.common.file_io import file_lock, g_pathmgr |
|
from .registry import registry |
|
from torch.utils.model_zoo import tqdm |
|
from torchvision.datasets.utils import ( |
|
check_integrity, |
|
download_file_from_google_drive, |
|
extract_archive, |
|
) |
|
|
|
|
|
def now():
    """
    Return the current local time as a compact timestamp string.

    Returns:
        str: The time formatted with ``"%Y%m%d%H%M"`` (year, month, day,
        hour, minute), e.g. ``"202401312359"``.
    """
    from datetime import datetime

    current = datetime.now()
    return current.strftime("%Y%m%d%H%M")
|
|
|
|
|
def is_url(url_or_filename):
    """
    Tell whether the given string is an HTTP(S) URL.

    Args:
        url_or_filename (str): Candidate URL or local path.

    Returns:
        bool: True when the scheme reported by ``urlparse`` is ``http`` or
        ``https``, False otherwise.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
|
|
|
|
def get_cache_path(rel_path):
    """
    Resolve a path relative to the registered cache root.

    Args:
        rel_path (str): Path relative to the ``"cache_root"`` registry entry.

    Returns:
        str: The joined path with a leading ``~`` expanded.
    """
    cache_root = registry.get_path("cache_root")
    joined = os.path.join(cache_root, rel_path)
    return os.path.expanduser(joined)
|
|
|
|
|
def get_abs_path(rel_path):
    """
    Resolve a path relative to the registered library root.

    Args:
        rel_path (str): Path relative to the ``"library_root"`` registry entry.

    Returns:
        str: ``rel_path`` joined onto the library root.
    """
    library_root = registry.get_path("library_root")
    return os.path.join(library_root, rel_path)
|
|
|
|
|
def load_json(filename):
    """
    Read a JSON document from a local file.

    Args:
        filename (str): Path of the file to read (opened in text mode).

    Returns:
        The decoded JSON value.
    """
    with open(filename, "r") as handle:
        parsed = json.load(handle)
    return parsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Args:
        dir_path (str): Directory to create; handed to ``g_pathmgr``.

    Returns:
        bool: True if the directory already existed or was created, False if
        the existence check or the creation raised an exception.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        # Best effort: report and return False instead of raising. Only
        # ``Exception`` is caught so that KeyboardInterrupt / SystemExit still
        # propagate instead of being silently swallowed.
        print(f"Error creating directory: {dir_path}")
    return is_success
|
|
|
|
|
def get_redirected_url(url: str):
    """
    Follow HTTP redirects for ``url`` and report where it ends up.

    Args:
        url (str): URL to request.

    Returns:
        str: The final response URL when the response has a redirect
        history, otherwise the original ``url`` unchanged.
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            was_redirected = bool(response.history)
            return response.url if was_redirected else url
|
|
|
|
|
def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of google drive
    to a download URL for google drive

    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp

    Query strings (e.g. ``?usp=sharing``), fragments and a trailing slash
    after ``view`` are ignored.

    Args:
        view_url (str): Google Drive "view" URL whose path ends in
            ``/<file_id>/view``.

    Returns:
        str: The corresponding direct-download URL.

    Raises:
        ValueError: If the URL path does not end in ``/<file_id>/view``.
    """
    # Work on the path component only so "?usp=sharing" and similar suffixes
    # do not hide the trailing "view" segment.
    segments = urlparse(view_url).path.rstrip("/").split("/")
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if len(segments) < 2 or segments[-1] != "view" or not segments[-2]:
        raise ValueError(f"Not a google drive view URL: {view_url}")
    file_id = segments[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"
|
|
|
|
|
def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from google drive.

    Downloading an URL from google drive requires confirmation when
    the size of the file is too big (google drive notifies that
    anti-viral checks cannot be performed on such files).

    Args:
        url (str): Google Drive download URL.
        output_path (str): Directory to write into (created via ``makedir``).
        output_file_name (str): Name of the file written inside
            ``output_path``.
    """
    import requests

    with requests.Session() as session:
        # First request: pick up any "download_warning*" cookie and append its
        # value to the URL as the confirmation token.
        with session.get(url, stream=True, allow_redirects=True) as probe:
            for cookie_name, cookie_value in probe.cookies.items():
                if cookie_name.startswith("download_warning"):
                    url = url + "&confirm=" + cookie_value

        # Second request: stream the content to disk with a progress bar.
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            destination = os.path.join(output_path, output_file_name)
            expected_bytes = int(response.headers.get("Content-length", 0))
            with open(destination, "wb") as sink:
                from tqdm import tqdm

                with tqdm(total=expected_bytes) as progress_bar:
                    blocks = response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE)
                    for block in blocks:
                        sink.write(block)
                        progress_bar.update(len(block))
|
|
|
|
|
def _get_google_drive_file_id(url: str) -> Optional[str]: |
|
parts = urlparse(url) |
|
|
|
if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: |
|
return None |
|
|
|
match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path) |
|
if match is None: |
|
return None |
|
|
|
return match.group("id") |
|
|
|
|
|
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: |
|
with open(filename, "wb") as fh: |
|
with urllib.request.urlopen( |
|
urllib.request.Request(url, headers={"User-Agent": "vissl"}) |
|
) as response: |
|
with tqdm(total=response.length) as pbar: |
|
for chunk in iter(lambda: response.read(chunk_size), ""): |
|
if not chunk: |
|
break |
|
pbar.update(chunk_size) |
|
fh.write(chunk) |
|
|
|
|
|
def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
            If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check

    Raises:
        urllib.error.URLError, IOError: If the download fails and no
            https -> http retry applies (or the retry fails as well).
        RuntimeError: If ``check_integrity`` rejects the downloaded file.
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # Skip the download when the target already passes the integrity check.
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # Resolve redirects so the Google Drive detection sees the final URL.
    url = get_redirected_url(url)

    # Google Drive files are delegated to the torchvision helper.
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError):
        if url.startswith("https:"):
            # NOTE(security): retrying over plain http drops transport
            # security; the md5 check below is the only remaining safeguard.
            # Rewrite only the leading scheme, never an "https:" that happens
            # to appear later in the URL (e.g. inside a query parameter).
            url = url.replace("https:", "http:", 1)
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise

    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
|
|
|
|
|
def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    """
    Download an archive with ``download_url`` and unpack it.

    Args:
        url (str): URL of the archive.
        download_root (str): Directory the archive is downloaded to
            (``~`` is expanded).
        extract_root (str, optional): Directory to extract into. Defaults to
            ``download_root``.
        filename (str, optional): Name to save the archive under. Defaults to
            the basename of ``url``.
        md5 (str, optional): Checksum forwarded to ``download_url``.
        remove_finished (bool): Forwarded to ``extract_archive``.
    """
    download_root = os.path.expanduser(download_root)
    extract_root = download_root if extract_root is None else extract_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
|
|
|
|
|
def cache_url(url: str, cache_dir: str) -> str:
    """
    Download a remote resource into a local cache, once.

    The cached file lives under ``cache_dir`` in a sub-directory mirroring
    the directory part of the URL path. The download only happens when the
    cached file is not already present; the check runs under a file lock.

    Args:
        url (str): Remote resource to cache.
        cache_dir (str): Root directory of the local cache.

    Returns:
        str: Path of the cached file.
    """
    remote_dir = os.path.dirname(urlparse(url).path.lstrip("/"))
    dirname = os.path.join(cache_dir, remote_dir)
    makedir(dirname)

    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        already_cached = os.path.isfile(cached)
        if not already_cached:
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)

    logging.info(f"URL {url} cached in {cached}")
    return cached
|
|
|
|
|
|
|
def create_file_symlink(file1, file2):
    """
    Point ``file2`` at ``file1`` through a symlink, replacing any existing
    ``file2``. Handy during checkpointing to keep a link to the latest
    successful checkpoint.

    Failures are logged and not raised.

    Args:
        file1 (str): Link target.
        file2 (str): Path of the symlink to (re)create.
    """
    try:
        link_exists = g_pathmgr.exists(file2)
        if link_exists:
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as err:
        logging.info(f"Could NOT create symlink. Error: {err}")
|
|
|
|
|
def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to save data in a format chosen by file extension.

    Supported extensions: .pkl, .pickle, .npy, .json, .yaml

    Args:
        data: Object to serialize.
        filename (str): Destination path; its extension selects the format.
        append_to_json (bool): For .json only. When true (default) the JSON
            line is appended to the file, otherwise the file is rewritten.
        verbose (bool): Log before and after saving.

    Raises:
        Exception: If the extension is not one of the supported ones.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")

    _, file_ext = os.path.splitext(filename)
    if file_ext == ".pkl" or file_ext == ".pickle":
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        # Append and rewrite differ only in the open mode.
        json_mode = "a" if append_to_json else "w"
        with g_pathmgr.open(filename, json_mode) as fopen:
            fopen.write(json.dumps(data, sort_keys=True) + "\n")
            fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            fopen.write(yaml.dump(data))
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")
|
|
|
|
|
def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.

    Supported:
        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv

    For the npy files, we support reading the files in mmap_mode.
    If the mmap_mode of reading is not successful, we load data without the
    mmap_mode.

    Args:
        filename (str): Path of the file to read; its extension selects the
            loader.
        mmap_mode: Forwarded to ``np.load`` for .npy files; when falsy the
            file is loaded without memory mapping.
        verbose (bool): Log the file name before loading.
        allow_pickle (bool): Forwarded to ``np.load`` for .npy files.

    Returns:
        The loaded data: list of lines for .txt, the unpickled object for
        .pkl/.pickle, the ``np.load`` result for .npy, the decoded document
        for .json/.yaml, and the ``pd.read_csv`` result for .csv.

    Raises:
        Exception: If the extension is not one of the supported ones.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        # NOTE(security): unpickling executes arbitrary code; only load
        # pickle files from trusted sources.
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            # Attempt 1: mmap through a g_pathmgr file handle.
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                # Attempt 2: mmap by handing the path directly to numpy.
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                # NOTE(review): this handler is a sibling of the ValueError
                # handler on the same ``try``. It runs when attempt 1 raises
                # something other than ValueError; an error raised inside the
                # ValueError handler above propagates and is NOT routed here,
                # despite what the log message suggests.
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        # NOTE(security): FullLoader is not intended for untrusted input;
        # consider yaml.safe_load if files may come from outside.
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data
|
|
|
|
|
def abspath(resource_path: str):
    """
    Return an absolute version of a local path while leaving resources that
    carry a ``scheme://`` prefix (such as "http://" or "manifold://")
    untouched.

    Args:
        resource_path (str): Local path or prefixed resource identifier.

    Returns:
        str: ``resource_path`` itself when it starts with ``<word>://``,
        otherwise ``os.path.abspath(resource_path)``.
    """
    has_scheme_prefix = re.compile(r"^\w+://").match(resource_path) is not None
    if has_scheme_prefix:
        return resource_path
    return os.path.abspath(resource_path)
|
|
|
|
|
def makedir(dir_path):
    """
    Create the directory if it does not exist.

    Args:
        dir_path (str): Directory to create; handed to ``g_pathmgr``.

    Returns:
        bool: True if the directory already existed or was created, False if
        the existence check or the creation raised an exception.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        # Best effort: log and return False instead of raising. Only
        # ``Exception`` is caught so that KeyboardInterrupt / SystemExit still
        # propagate instead of being silently swallowed.
        logging.info(f"Error creating directory: {dir_path}")
    return is_success
|
|
|
|
|
def is_url(input_url):
    """
    Check whether the input string is a URL, i.e. starts with ``http://`` or
    ``https://`` (matched case-insensitively).

    Args:
        input_url (str): String to test.

    Returns:
        bool: True if the prefix matches, False otherwise.
    """
    prefix_match = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE)
    return prefix_match is not None
|
|
|
|
|
def cleanup_dir(dir):
    """
    Delete a directory tree if it exists. Useful for reclaiming the storage
    used by training artifacts such as checkpoints and data.

    Args:
        dir (str): Directory to remove recursively.
    """
    directory_present = os.path.exists(dir)
    if directory_present:
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
    logging.info(f"Deleted contents of directory: {dir}")
|
|
|
|
|
def get_file_size(filename):
    """
    Report the size of a file in MB (bytes divided by 1024**2).

    Args:
        filename (str): Path of the file to measure.

    Returns:
        float: File size in MB.
    """
    bytes_per_mb = float(1024**2)
    return os.path.getsize(filename) / bytes_per_mb
|
|
|
from typing import Dict, List, Protocol, Tuple |
|
|
|
import torch |
|
from torch.func import functional_call |
|
|
|
from vllm.multimodal import BatchedTensors |
|
from vllm.utils import is_pin_memory_available |
|
|
|
|
|
def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: BatchedTensors,
                            image_token_id: int) -> torch.Tensor:
    """
    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.

    Args:
        input_ids: Token ids; positions equal to `image_token_id` are the
            placeholders to overwrite.
        inputs_embeds: Embeddings indexed by the same positions as
            `input_ids`.
        vision_embeddings: Either a single tensor whose last dimension is the
            embedding size (all leading dimensions are flattened into one
            token axis), or a sequence of tensors that are concatenated along
            their first dimension.
        image_token_id: Id of the placeholder image token.

    Returns:
        `inputs_embeds`, the same object that was passed in.

    Raises:
        ValueError: If the number of vision tokens differs from the number of
            placeholder positions.

    Note:
        This updates `inputs_embeds` in place.
    """
    mask = (input_ids == image_token_id)
    num_expected_tokens = mask.sum()

    if isinstance(vision_embeddings, torch.Tensor):
        # Every leading dimension contributes tokens, not just the first two,
        # so collapse all of them instead of assuming a 3-D input.
        *token_dims, embed_dim = vision_embeddings.shape
        flat_embeddings = vision_embeddings.flatten(0, -2)
        total_tokens = flat_embeddings.shape[0]
        if num_expected_tokens != total_tokens:
            expr = " x ".join(map(str, token_dims))
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = flat_embeddings
    else:
        size_per_batch = [t.shape[0] for t in vision_embeddings]
        total_tokens = sum(size_per_batch)
        if num_expected_tokens != total_tokens:
            expr = ' + '.join(map(str, size_per_batch))
            raise ValueError(
                f"Attempted to assign {expr} = {total_tokens} "
                f"image tokens to {num_expected_tokens} placeholders")

        inputs_embeds[mask] = torch.cat(vision_embeddings)

    return inputs_embeds
|
|