""" OpenAI pretrained model functions

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import os
import warnings
from typing import Union, List

import torch

from .model import build_model_from_openai_state_dict
from .pretrained import (
    get_pretrained_url,
    list_pretrained_tag_models,
    download_pretrained,
)

__all__ = ["list_openai_models", "load_openai_model"]


def list_openai_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list_pretrained_tag_models("openai")


def load_openai_model(
    name: str,
    model_cfg,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    jit: bool = True,
    cache_dir: str = os.path.expanduser("~/.cache/clip"),
    enable_fusion: bool = False,
    fusion_type: str = "None",
):
    """Load a CLIP model, preserve its pretrained text branch, and set it in the CLAP model

    Parameters
    ----------
    name : str
        A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict
    model_cfg : dict
        The CLAP model config used when rebuilding the model from a state dict
    device : Union[str, torch.device]
        The device to put the loaded model on
    jit : bool
        Whether to load the optimized JIT model (default) or the more hackable non-JIT model.

    Returns
    -------
    model : torch.nn.Module
        The CLAP model
    """
    if get_pretrained_url(name, "openai"):
        model_path = download_pretrained(
            get_pretrained_url(name, "openai"), root=cache_dir
        )
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {list_openai_models()}"
        )
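
    # The checkpoint on disk may be either a TorchScript (JIT) archive or a
    # plain pickled state dict; try the JIT loader first and fall back to
    # torch.load on failure.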
    try:
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        try:
            model = build_model_from_openai_state_dict(
                state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
            ).to(device)
        except KeyError:
            # Checkpoints nested under a "state_dict" key prefix every entry;
            # strip the first 7 characters (e.g. a "module." prefix left by
            # distributed training wrappers).
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(
                sd, model_cfg, enable_fusion, fusion_type
            ).to(device)

        if str(device) == "cpu":
            model.float()
        return model

    # patch the device names
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
    )
    device_node = [
        n
        for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

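    # A TorchScript module exposes its main graph via .graph; some modules
    # additionally carry a second compiled entry point, forward1, so the
    # patcher rewrites device constants in both graphs.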
    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith(
                    "cuda"
                ):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_audio)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(
            lambda: torch.ones([]).float(), example_inputs=[]
        )
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    # dtype can be the second or third argument to aten::to()
                    for i in [1, 2]:
                        # constant 5 is the float16 (Half) dtype; retarget
                        # those casts to float32
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_audio)
        patch_float(model.encode_text)
        model.float()

    model.audio_branch.audio_length = model.audio_cfg.audio_length
    return model
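

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's public API. It assumes a
    # CLAP model config JSON on disk (the path below is hypothetical) and
    # loads the first available OpenAI checkpoint on CPU via the non-JIT
    # path. Because this file uses relative imports, run it as a module,
    # e.g. `python -m <package>.openai`.
    import json

    with open("model_configs/HTSAT-tiny.json") as f:  # hypothetical path
        model_cfg = json.load(f)

    print("Available OpenAI models:", list_openai_models())
    model = load_openai_model(
        list_openai_models()[0], model_cfg, device="cpu", jit=False
    )
    print("Loaded:", type(model).__name__)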