deepanway's picture
Upload files
f1069cc
raw history blame
No virus
5.09 kB
""" OpenAI pretrained model functions
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import Union, List
import torch
from .model import build_model_from_openai_state_dict
from .pretrained import (
get_pretrained_url,
list_pretrained_tag_models,
download_pretrained,
)
# Names exported by `from <module> import *`.
__all__ = [
    "list_openai_models",
    "load_openai_model",
]

# Base cache directory. The AUDIOLDM_CACHE_DIR environment variable, when set,
# replaces the "~/.cache" default; "~" is not expanded at this point.
CACHE_DIR = os.environ.get("AUDIOLDM_CACHE_DIR", "~/.cache")
def list_openai_models() -> List[str]:
    """List the model names registered under the ``"openai"`` pretrained tag.

    Returns
    -------
    List[str]
        Whatever ``list_pretrained_tag_models("openai")`` reports.
    """
    openai_tag = "openai"
    return list_pretrained_tag_models(openai_tag)
def _node_get(node, key):
    """Return attribute ``key`` of a TorchScript graph node, whatever its kind.

    ``node.kindOf(key)`` names the typed accessor for the attribute and that
    accessor is then called with the key. This is used instead of
    ``node[key]`` because ``torch._C.Node.__getitem__`` is not available on
    every PyTorch release.
    """
    accessor = node.kindOf(key)
    return getattr(node, accessor)(key)


def _module_graphs(module):
    """Collect the TorchScript graphs attached to ``module``.

    Includes ``module.graph`` when the attribute exists and can be read, plus
    ``module.forward1.graph`` when the module has a ``forward1`` attribute.
    """
    try:
        graphs = [module.graph] if hasattr(module, "graph") else []
    except RuntimeError:
        # Reading ``.graph`` can raise; treat that module as having no graph.
        graphs = []

    if hasattr(module, "forward1"):
        graphs.append(module.forward1.graph)
    return graphs


def load_openai_model(
    name: str,
    model_cfg,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    jit=True,
    cache_dir=os.path.expanduser(f"{CACHE_DIR}/clip"),
    enable_fusion: bool = False,
    fusion_type: str = "None",
):
    """Load an OpenAI checkpoint and return the model built or loaded from it.

    Parameters
    ----------
    name : str
        A model name for which ``get_pretrained_url(name, "openai")`` returns a
        URL (see ``list_openai_models()``), or the path to a checkpoint file.
    model_cfg
        Forwarded unchanged to ``build_model_from_openai_state_dict`` on the
        non-JIT path; unused on the JIT path.
    device : Union[str, torch.device]
        The device to put the loaded model on. The default is chosen once, when
        this function is defined: ``"cuda"`` if CUDA is available, else ``"cpu"``.
    jit : bool
        Whether to return the JIT archive itself (default) or a model rebuilt
        from its weights. Forced to ``False``, with a warning, when the file is
        not a JIT archive.
    cache_dir : str
        Directory passed as ``root`` to ``download_pretrained`` when ``name``
        resolves to a URL. The default is computed once, at definition time.
    enable_fusion : bool
        Forwarded to ``build_model_from_openai_state_dict``.
    fusion_type : str
        Forwarded to ``build_model_from_openai_state_dict``.

    Returns
    -------
    model : torch.nn.Module
        On the non-JIT path, the model returned by
        ``build_model_from_openai_state_dict`` moved to ``device``. On the JIT
        path, the loaded archive with its device (and, on CPU, dtype) constants
        patched. In both cases the model is converted to float32 on CPU.

    Raises
    ------
    RuntimeError
        If ``name`` is neither a known model name nor an existing file.
    KeyError
        If building the model fails with ``KeyError`` and there is no
        ``"state_dict"`` wrapper to fall back on.
    """
    # Resolve the checkpoint location: known model name first, then local file.
    url = get_pretrained_url(name, "openai")
    if url:
        model_path = download_pretrained(url, root=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {list_openai_models()}"
        )

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # from trusted sources should reach this call.
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # ``state_dict`` is None exactly when the JIT archive loaded, in which
        # case ``model`` is bound and supplies the weights.
        weights = state_dict if state_dict is not None else model.state_dict()
        try:
            model = build_model_from_openai_state_dict(
                weights, model_cfg, enable_fusion, fusion_type
            ).to(device)
        except KeyError:
            if state_dict is None:
                # Weights came from a JIT archive: nothing to unwrap, so
                # surface the original error instead of a TypeError on None.
                raise
            # Fall back to a checkpoint that nests the weights under
            # "state_dict"; the first 7 characters of every key are dropped
            # (presumably a "module." prefix — confirm against the checkpoints).
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(
                sd, model_cfg, enable_fusion, fusion_type
            ).to(device)

        if str(device) == "cpu":
            model.float()
        return model

    # patch the device names
    # Trace a tiny function to obtain a constant node carrying the target device.
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
    )
    device_node = [
        n
        for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

    def patch_device(module):
        # Overwrite every constant whose value starts with "cuda" with the
        # traced device constant.
        for graph in _module_graphs(module):
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(
                    _node_get(node, "value")
                ).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_audio)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        # Trace ``.float()`` to obtain the constant node holding its dtype argument.
        float_holder = torch.jit.trace(
            lambda: torch.ones([]).float(), example_inputs=[]
        )
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            for graph in _module_graphs(module):
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    # dtype can be the second or third argument to aten::to()
                    for i in (1, 2):
                        # 5 is presumably the dtype code for float16 — TODO confirm.
                        if _node_get(inputs[i].node(), "value") == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_audio)
        patch_float(model.encode_text)
        model.float()

    model.audio_branch.audio_length = model.audio_cfg.audio_length
    return model