import logging
import os
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, \
    get_cast_dtype, set_model_preprocess_cfg, convert_to_custom_text_state_dict, resize_pos_embed
from open_clip.coca_model import CoCa
from open_clip.openai import load_openai_model
from open_clip.pretrained import get_pretrained_cfg, download_pretrained, \
    list_pretrained_tags_by_model, download_pretrained_from_hf
from open_clip.transform import image_transform_v2, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from open_clip.factory import get_model_config, _get_hf_config, list_models, load_state_dict, load_checkpoint

HF_HUB_PREFIX = 'hf-hub:'
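
# NOTE: this module closely mirrors open_clip.factory.create_model / create_model_from_pretrained.
# The main addition appears to be the `load_ckpt` flag, which builds the model architecture but
# skips loading checkpoint weights when set to False (assumption: the caller restores weights
# elsewhere, e.g. from a combined state dict).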
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        load_ckpt: Optional[bool] = True,
        **model_kwargs,
):
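    """Instantiate a CLIP / CustomTextCLIP / CoCa model from a model config.

    `model_name` may be a config name or an `hf-hub:` reference. If `pretrained` is given
    (the 'openai' tag, a pretrained tag, or a local path), the matching checkpoint is located
    and, when `load_ckpt` is True, loaded into the model. Set `load_ckpt=False` to build the
    architecture without loading checkpoint weights.
    """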
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)
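
    # 'openai' pretrained tags go through open_clip's dedicated OpenAI loader; every other
    # model is built from a model config and, optionally, a downloaded or local checkpoint.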
    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        if is_hf_model:
            # load pretrained weights for HF text model IFF no CLIP weights being loaded
            model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
        if custom_text:
            if "multimodal_cfg" in model_cfg:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(dtype=dtype)
                from open_clip.transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)

                model.apply(_convert_ln)
            else:
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(dtype=dtype)
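
        # Resolve the checkpoint to load; `load_ckpt=False` skips the actual weight loading
        # while still resolving (and, for pretrained tags, downloading) the checkpoint path.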
        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                if load_ckpt:
                    load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            if load_ckpt:
                logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
                load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # set image preprocessing configuration in model attributes for convenience
    if getattr(model.visual, 'image_size', None) is not None:
        # use image_size set on model creation (via config or force_image_size arg)
        force_preprocess_cfg['size'] = model.visual.image_size

    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model


def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        return_transform: bool = True,
        cache_dir: Optional[str] = None,
        load_ckpt: Optional[bool] = True,
        **model_kwargs,
):
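    """Create a model via `create_model` (with `require_pretrained=True`) and, by default,
    also return the matching inference image transform built from the model's preprocess cfg.
    """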
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        cache_dir=cache_dir,
        require_pretrained=True,
        load_ckpt=load_ckpt,
        **model_kwargs,
    )

    if not return_transform:
        return model

    preprocess = image_transform_v2(
        PreprocessCfg(**model.visual.preprocess_cfg),
        is_train=False,
    )

    return model, preprocess
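

# Minimal usage sketch. Assumptions: the 'ViT-B-32' config and the 'laion2b_s34b_b79k'
# pretrained tag are available in the installed open_clip; substitute whichever model
# and tag this repo actually uses.
if __name__ == "__main__":
    model, preprocess = create_model_from_pretrained(
        'ViT-B-32',
        pretrained='laion2b_s34b_b79k',
        precision='fp32',
        device='cpu',
        load_ckpt=True,  # set False to build the architecture without loading checkpoint weights
    )
    print(model.visual.image_size)
    print(preprocess)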