import json
import logging
import os
import re
from copy import deepcopy
from pathlib import Path

import torch

from .model import CLAP, convert_weights_to_fp16
from .openai import load_openai_model
from .pretrained import get_pretrained_url, download_pretrained
from .transform import image_transform

_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}


def _natural_key(string_):
    """Sort key that orders embedded integers numerically rather than lexically."""
    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
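    # For example (hypothetical registry keys):
    #     sorted(["HTSAT-10", "HTSAT-2"], key=_natural_key)
    #     -> ["HTSAT-2", "HTSAT-10"]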


def _rescan_model_configs():
    """Scan _MODEL_CONFIG_PATHS for *.json model configs and rebuild the registry."""
    global _MODEL_CONFIGS

    config_ext = (".json",)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    for cf in config_files:
        if os.path.basename(cf)[0] == ".":
            continue  # skip hidden files

        with open(cf, "r") as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    # keep the registry in natural-sort order by model name
    _MODEL_CONFIGS = {
        k: v
        for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
    }


_rescan_model_configs()
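
# At this point _MODEL_CONFIGS maps each config file's stem to its parsed JSON.
# A registered config must contain at least "embed_dim", "audio_cfg" and
# "text_cfg"; a minimal sketch (illustrative values, not a shipped config):
#
#     {"embed_dim": 512, "audio_cfg": {...}, "text_cfg": {...}}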


def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    if skip_params:
        # strip the "module." prefix left by DistributedDataParallel, if present
        if next(iter(state_dict.items()))[0].startswith("module"):
            state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict
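    # For example (illustrative key name; len("module.") == 7, hence k[7:]):
    #     {"module.logit_scale_a": t} -> {"logit_scale_a": t}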


def create_model(
    amodel_name: str,
    tmodel_name: str,
    pretrained: str = "",
    precision: str = "fp32",
    device: torch.device = torch.device("cpu"),
    jit: bool = False,
    force_quick_gelu: bool = False,
    openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
    skip_params=True,
    pretrained_audio: str = "",
    pretrained_text: str = "",
    enable_fusion: bool = False,
    fusion_type: str = "None",
):
    amodel_name = amodel_name.replace("/", "-")
    pretrained_orig = pretrained
    pretrained = pretrained.lower()
    if pretrained == "openai":
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        logging.info("Loading pretrained ViT-B-16 text encoder from OpenAI.")

        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model = load_openai_model(
            "ViT-B-16",
            model_cfg,
            device=device,
            jit=jit,
            cache_dir=openai_model_cache_dir,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )

        # OpenAI checkpoints load in fp16; cast back for amp/fp32 execution.
        if precision == "amp" or precision == "fp32":
            model = model.float()
    else:
        if amodel_name in _MODEL_CONFIGS:
            logging.info(f"Loading {amodel_name} model config.")
            model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
        else:
            logging.error(
                f"Model config for {amodel_name} not found; available models {list_models()}."
            )
            raise RuntimeError(f"Model config for {amodel_name} not found.")

        if force_quick_gelu:
            model_cfg["quick_gelu"] = True

        model_cfg["text_cfg"]["model_type"] = tmodel_name
        model_cfg["enable_fusion"] = enable_fusion
        model_cfg["fusion_type"] = fusion_type
        model = CLAP(**model_cfg)

        if pretrained:
            checkpoint_path = ""
            url = get_pretrained_url(amodel_name, pretrained)
            if url:
                checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
            elif os.path.exists(pretrained_orig):
                checkpoint_path = pretrained_orig
            if checkpoint_path:
                logging.info(
                    f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
                )
                ckpt = load_state_dict(checkpoint_path, skip_params=skip_params)
                model.load_state_dict(ckpt)
                param_names = [n for n, p in model.named_parameters()]
                for n in param_names:
                    print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
            else:
                logging.warning(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )
                raise RuntimeError(
                    f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
                )

        if pretrained_audio:
            if amodel_name.startswith("PANN"):
                if "Cnn14_mAP" in pretrained_audio:
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["model"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if (
                            "spectrogram_extractor" not in key
                            and "logmel_extractor" not in key
                        ):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key] = v
                elif os.path.basename(pretrained_audio).startswith("PANN"):
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model"):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
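                    # After this loop the keys carry CLAP's prefix, e.g.
                    #     "sed_model.bn0.weight" -> "audio_branch.bn0.weight"
                    # ("bn0.weight" is an illustrative suffix; len("sed_model.") == 10).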
                elif os.path.basename(pretrained_audio).startswith("finetuned"):
                    # assumed to already use CLAP-style keys; loaded without remapping
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                else:
                    raise ValueError("Unknown audio checkpoint")
            elif amodel_name.startswith("HTSAT"):
                if "HTSAT_AudioSet_Saved" in pretrained_audio:
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model") and (
                            "spectrogram_extractor" not in key
                            and "logmel_extractor" not in key
                        ):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith("HTSAT"):
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                    audio_ckpt = audio_ckpt["state_dict"]
                    keys = list(audio_ckpt.keys())
                    for key in keys:
                        if key.startswith("sed_model"):
                            v = audio_ckpt.pop(key)
                            audio_ckpt["audio_branch." + key[10:]] = v
                elif os.path.basename(pretrained_audio).startswith("finetuned"):
                    audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
                else:
                    raise ValueError("Unknown audio checkpoint")
            else:
                raise ValueError(
                    "This audio encoder pretrained checkpoint is not supported."
                )

            model.load_state_dict(audio_ckpt, strict=False)
            logging.info(
                f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
            )
            param_names = [n for n, p in model.named_parameters()]
            for n in param_names:
                print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")

    model.to(device=device)
    if precision == "fp16":
        assert device.type != "cpu"
        convert_weights_to_fp16(model)

    if jit:
        model = torch.jit.script(model)

    return model, model_cfg
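
# A minimal usage sketch. "HTSAT-base" and "roberta" are illustrative names;
# substitute a config that actually exists in model_configs/ and a text tower
# CLAP supports:
#
#     model, model_cfg = create_model(
#         "HTSAT-base",
#         "roberta",
#         precision="fp32",
#         device=torch.device("cpu"),
#     )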


def create_model_and_transforms(
    amodel_name: str,
    tmodel_name: str,
    pretrained: str = "",
    precision: str = "fp32",
    device: torch.device = torch.device("cpu"),
    jit: bool = False,
    force_quick_gelu: bool = False,
):
    model, _ = create_model(
        amodel_name,
        tmodel_name,
        pretrained,
        precision,
        device,
        jit,
        force_quick_gelu=force_quick_gelu,
    )
    preprocess_train = image_transform(model.visual.image_size, is_train=True)
    preprocess_val = image_transform(model.visual.image_size, is_train=False)
    return model, preprocess_train, preprocess_val


def list_models():
    """enumerate available model architectures based on config files"""
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """add model config path or file and update registry"""
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()
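
# Example (hypothetical path): register an out-of-tree config directory and
# confirm the new architectures were picked up:
#
#     add_model_config("/data/my_clap_configs")
#     print(list_models())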