# QA-CLIP / clip /utils.py
# kunyi
# Upload 30 files
# f76d30f
# raw
# history blame
# No virus
# 6.72 kB
# Code modified from https://github.com/openai/CLIP
import ast
import json
import os
import urllib
import urllib.request
from pathlib import Path
from typing import Union, List

import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, InterpolationMode
from tqdm import tqdm

from clip import _tokenizer
from clip.model import convert_weights, CLIP, restore_model
# Names exported by ``from <this module> import *``; each one is a function defined in this file.
__all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]
# Maps each released model name to the URL of its checkpoint file.
# The keys are what available_models() returns and what load_from_name() accepts as `name`.
_MODELS = {
"ViT-B-16": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt",
"ViT-L-14": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt",
"RN50": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt",
}
# Per-model metadata used by load_from_name():
#   "struct"           -- "<vision model>@<text model>" string handed to create_model()
#   "input_resolution" -- image size handed to image_transform()
_MODEL_INFO = {
    "ViT-B-16": dict(
        struct="ViT-B-16@RoBERTa-wwm-ext-base-chinese",
        input_resolution=224,
    ),
    "ViT-L-14": dict(
        struct="ViT-L-14@RoBERTa-wwm-ext-base-chinese",
        input_resolution=224,
    ),
    "RN50": dict(
        struct="RN50@RBT3-chinese",
        input_resolution=224,
    ),
}
def _download(url: str, root: str):
    """Download ``url`` into the directory ``root`` and return the local path.

    The local file name is the last path component of ``url``. If a regular
    file with that name already exists in ``root`` it is reused and nothing
    is downloaded (no integrity check is performed on it).

    Parameters
    ----------
    url : str
        Address of the file to fetch.
    root : str
        Directory to store the file in; created if it does not exist.

    Returns
    -------
    str
        Path of the downloaded (or already present) file.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file.
    """
    os.makedirs(root, exist_ok=True)
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.isfile(download_target):
        return download_target
    if os.path.exists(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    # Stream into a temporary sibling file and rename it only once the
    # transfer has finished. Otherwise an interrupted download would leave a
    # truncated file at the final path, and the early return above would
    # treat it as a complete checkpoint on the next call.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            # The server may omit Content-Length; tqdm accepts total=None
            # and then shows a counter without a percentage.
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                      unit_divisor=1024) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
        os.replace(partial_target, download_target)
    except BaseException:
        # Clean up the partial file on any failure (including Ctrl-C),
        # then let the original exception propagate unchanged.
        if os.path.exists(partial_target):
            os.remove(partial_target)
        raise
    return download_target
def _convert_image_to_rgb(image):
    """Return the result of ``image.convert("RGB")``."""
    rgb_image = image.convert("RGB")
    return rgb_image
def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    # Iterating the registry yields its keys in insertion order.
    return [model_name for model_name in _MODELS]
def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None,
                   input_resolution: int = None):
    """Build a model from a registered model name or a local checkpoint file.

    Parameters
    ----------
    name : str
        Either a key of ``_MODELS`` (see ``available_models()``) or the path
        of an existing checkpoint file.
    device : Union[str, torch.device]
        Device to place the model on. For ``"cpu"`` the model is converted
        with ``model.float()`` instead of being moved.
    download_root : str
        Directory for downloaded checkpoints; defaults to ``~/.cache/clip``.
        Only used when ``name`` is a registered model name.
    vision_model_name, text_model_name : str
        Required when ``name`` is a file path; joined as
        ``"<vision>@<text>"`` and passed to ``create_model``.
    input_resolution : int
        Required when ``name`` is a file path; passed to ``image_transform``.

    Returns
    -------
    A ``(model, transform)`` tuple, where ``transform`` is the result of
    ``image_transform`` for the model's input resolution.

    Raises
    ------
    RuntimeError
        If ``name`` is neither a registered model nor an existing file.
    AssertionError
        If ``name`` is a file but any of the three extra arguments is missing.
    """
    if name in _MODELS:
        cache_dir = download_root or os.path.expanduser("~/.cache/clip")
        model_path = _download(_MODELS[name], cache_dir)
        info = _MODEL_INFO[name]
        model_name = info['struct']
        model_input_resolution = info['input_resolution']
    elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
        model_path = name
        model_name = f'{vision_model_name}@{text_model_name}'
        model_input_resolution = input_resolution
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    with open(model_path, 'rb') as checkpoint_file:
        checkpoint = torch.load(checkpoint_file, map_location="cpu")

    model = create_model(model_name, checkpoint)
    if str(device) != "cpu":
        model.to(device)
    else:
        model.float()
    return model, image_transform(model_input_resolution)
def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
         bert_path=None, use_flash_attention=False):
    """Load CLIP and BERT model weights

    Parameters
    ----------
    model
        Model object handed to ``restore_model``.
    device : Union[str, torch.device]
        Device the restored model is moved to; for ``"cpu"`` the model is
        additionally converted with ``model.float()``.
    clip_path, bert_path
        Optional checkpoint paths; each is read with ``torch.load`` onto the
        CPU, or replaced by ``None`` when not given.
    use_flash_attention : bool
        Forwarded unchanged to ``restore_model``.

    Returns
    -------
    The ``model`` object that was passed in.
    """
    bert_state_dict = None
    if bert_path:
        bert_state_dict = torch.load(bert_path, map_location="cpu")

    clip_state_dict = None
    if clip_path:
        clip_state_dict = torch.load(clip_path, map_location="cpu")

    # presumably restore_model returns the same model it was given -- verify in clip.model
    restored = restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention)
    restored.to(device)

    if str(device) == "cpu":
        model.float()
    return model
def tokenize(texts: Union[str, List[str]], context_length: int = 52) -> torch.LongTensor:
    """
    Convert text into fixed-width rows of token ids.

    Each text is wrapped as ``[CLS] + ids + [SEP]``, with the ids truncated to
    ``context_length - 2`` entries; unused trailing positions stay 0.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all baseline models use 52 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    max_word_ids = context_length - 2  # room left after [CLS] and [SEP]
    all_tokens = [
        [
            _tokenizer.vocab['[CLS]'],
            *_tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[:max_word_ids],
            _tokenizer.vocab['[SEP]'],
        ]
        for text in texts
    ]

    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    for row_index, token_ids in enumerate(all_tokens):
        n_tokens = len(token_ids)
        assert n_tokens <= context_length
        result[row_index, :n_tokens] = torch.tensor(token_ids)
    return result
def _convert_to_rgb(image):
    """Return the result of ``image.convert('RGB')``; used as a step in ``image_transform``."""
    converted = image.convert('RGB')
    return converted
def image_transform(image_size=224):
    """Build the image preprocessing pipeline.

    The returned ``Compose`` applies, in order: a bicubic resize to
    ``(image_size, image_size)``, conversion to RGB, ``ToTensor()`` and a
    per-channel ``Normalize``.

    Parameters
    ----------
    image_size : int
        Side length of the square output image.

    Returns
    -------
    A ``torchvision.transforms.Compose`` object.
    """
    target_size = (image_size, image_size)
    # Per-channel normalisation constants, kept exactly as in the original code.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        Resize(target_size, interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)
def create_model(model_name, checkpoint=None):
    """Instantiate a CLIP model from its JSON configs and optionally load weights.

    Parameters
    ----------
    model_name : str
        ``"<vision model>@<text model>"``. Each part selects a file
        ``model_configs/<part>.json`` next to this module (any ``/`` in a
        part is replaced by ``-``).
    checkpoint : dict, optional
        When truthy, its ``"state_dict"`` entry is loaded into the model.

    Returns
    -------
    The constructed ``CLIP`` model.

    Raises
    ------
    ValueError
        If ``model_name`` does not split into exactly two parts on ``'@'``,
        or if a string-valued ``vision_layers`` is not a Python literal.
    AssertionError
        If either config file does not exist.
    KeyError
        If a truthy ``checkpoint`` has no ``"state_dict"`` entry.
    """
    vision_model, text_model = model_name.split('@')
    # Initialize the model.
    vision_model_config_file = Path(
        __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file), \
        f"vision model config not found: {vision_model_config_file}"

    text_model_config_file = Path(
        __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file), \
        f"text model config not found: {text_model_config_file}"

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        # Text-model settings are merged on top of the vision-model settings.
        for k, v in json.load(ft).items():
            model_info[k] = v

    if isinstance(model_info['vision_layers'], str):
        # A config may store the layer spec as a string (e.g. "[3,4,6,3]").
        # Parse it as a literal instead of eval() so a config file can never
        # execute arbitrary code.
        model_info['vision_layers'] = ast.literal_eval(model_info['vision_layers'])
    print('Model info', model_info)

    model = CLIP(**model_info)
    # presumably casts weights to half precision, as in OpenAI CLIP -- confirm in clip.model
    convert_weights(model)

    if checkpoint:
        sd = checkpoint["state_dict"]
        # If the first key starts with "module", strip the leading "module."
        # from every key and drop the "bert.pooler" entries before loading.
        # presumably that prefix comes from a DataParallel-style wrapper -- TODO confirm
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
        model.load_state_dict(sd)
    return model