import json
import os
from pathlib import Path
from typing import Union, List
import urllib.request

import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, InterpolationMode
from tqdm import tqdm

from clip import _tokenizer
from clip.model import convert_weights, CLIP, restore_model

__all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"]

_MODELS = {
    "ViT-B-16": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt",
    "ViT-L-14": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt",
    "RN50": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt",
}
_MODEL_INFO = {
    "ViT-B-16": {
        "struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
        "input_resolution": 224
    },
    "ViT-L-14": {
        "struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
        "input_resolution": 224
    },
    "RN50": {
        "struct": "RN50@RBT3-chinese",
        "input_resolution": 224
    },
}


def _download(url: str, root: str):
    """Download `url` into `root` and return the local file path.

    If the file has already been downloaded, the existing copy is reused.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None,
                   input_resolution: int = None):
    """Load a named pretrained model (downloading it if necessary) or a local checkpoint file.

    Returns the model together with the matching image preprocessing transform.
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
        model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
    elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, \
            "Please specify 'vision_model_name', 'text_model_name', and 'input_resolution' when loading from a local checkpoint"
        model_path = name
        model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        # Load the checkpoint on CPU first; the model is moved to the target device below.
        checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(model_name, checkpoint)
    if str(device) == "cpu":
        model.float()
    else:
        model.to(device)
    return model, image_transform(model_input_resolution)

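# Typical usage, as a sketch (assumes `from PIL import Image`, a local image file, and that the
# CLIP model class exposes `encode_image` / `encode_text`):
#
#   model, preprocess = load_from_name("ViT-B-16", device="cuda", download_root="./")
#   model.eval()
#   image = preprocess(Image.open("example.jpg")).unsqueeze(0).to("cuda")
#   text = tokenize(["一只猫", "一只狗"]).to("cuda")
#   with torch.no_grad():
#       image_features = model.encode_image(image)
#       text_features = model.encode_text(text)

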
def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None,
         bert_path=None, use_flash_attention=False):
    """Load separately saved CLIP and BERT weights into an existing model and move it to `device`."""

    bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None
    clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None

    restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention).to(device)

    if str(device) == "cpu":
        model.float()
    return model

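# A minimal sketch of restoring separately saved weights (the checkpoint paths are hypothetical):
#
#   model = create_model("ViT-B-16@RoBERTa-wwm-ext-base-chinese")
#   model = load(model, device="cpu", clip_path="ckpt/clip.pt", bert_path="ckpt/bert.pt")

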
def tokenize(texts: Union[str, List[str]], context_length: int = 52) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all baseline models use 52 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
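    # For example, tokenize(["一只猫", "一只狗"]) yields a LongTensor of shape [2, 52]:
    # each row holds the [CLS] id, up to context_length - 2 token ids, the [SEP] id,
    # and zero padding for the remaining positions.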
    if isinstance(texts, str):
        texts = [texts]

    all_tokens = []
    for text in texts:
        all_tokens.append([_tokenizer.vocab['[CLS]']] +
                          _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[:context_length - 2] +
                          [_tokenizer.vocab['[SEP]']])

    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        assert len(tokens) <= context_length
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(image_size=224):
    transform = Compose([
        Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return transform


def create_model(model_name, checkpoint=None):
    """Build a CLIP model from a 'vision@text' structure name and optionally load a checkpoint into it."""
    vision_model, text_model = model_name.split('@')

    vision_model_config_file = Path(
        __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(
        __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        for k, v in json.load(ft).items():
            model_info[k] = v
        if isinstance(model_info['vision_layers'], str):
            # vision_layers may be stored as a string in the config; parse it into a Python object.
            model_info['vision_layers'] = eval(model_info['vision_layers'])
    print('Model info', model_info)
    model = CLIP(**model_info)
    convert_weights(model)
    if checkpoint:
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith('module'):
            # Strip the DataParallel/DistributedDataParallel 'module.' prefix and drop unused bert.pooler weights.
            sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
        model.load_state_dict(sd)
    return model