| | import os, torch |
| | from typing import List, Tuple, Optional, Union, Dict |
| |
|
| | from .ebc import _ebc, EBC |
| | from .clip_ebc import _clip_ebc, CLIP_EBC |
| |
|
| |
|
def get_model(
    model_info_path: str,
    model_name: Optional[str] = None,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    clip_weight_name: Optional[str] = None,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    adapter: bool = False,
    adapter_reduction: Optional[int] = None,
    lora: bool = False,
    lora_rank: Optional[int] = None,
    lora_alpha: Optional[int] = None,
    lora_dropout: Optional[float] = None,
    norm: str = "none",
    act: str = "none",
    text_prompts: Optional[List[str]] = None,
) -> Union[EBC, CLIP_EBC]:
    """Build an ``EBC`` / ``CLIP_EBC`` model, restoring it from ``model_info_path`` if present.

    When ``model_info_path`` exists it is read with ``torch.load``. Every
    configuration argument the caller passed is then replaced by the value
    stored under ``model_info["config"]``. The stored ``model_info["weights"]``
    are loaded into the newly built model when they are not ``None``.

    When the file does not exist, the configuration comes from the arguments
    and no weights are loaded. A ``{"config": ..., "weights": None}`` record is
    then written to ``model_info_path``.

    Model names starting with ``"CLIP_"`` or ``"CLIP-"`` are built by
    ``_clip_ebc`` with that prefix removed. All other names are built by
    ``_ebc``.

    Args:
        model_info_path: Path of the saved info file. It is read if it exists
            and written otherwise.
        model_name: Architecture name. Required when no info file exists.
        block_size: Forwarded to the builder. Required when no info file exists.
        bins: Forwarded to the builder. Required when no info file exists.
        bin_centers: Forwarded to the builder. Required when no info file exists.
        zero_inflated: Forwarded to the builder.
        clip_weight_name: Required for CLIP models and passed on as
            ``weight_name``. Ignored for other models.
        num_vpt: Required when ``model_name`` contains ``"ViT"``.
        vpt_drop: Required when ``model_name`` contains ``"ViT"``.
        input_size: Forwarded to the builder.
        adapter: Adapter option for CLIP models. Must be falsy otherwise.
        adapter_reduction: Forwarded only to the CLIP builder.
        lora: LoRA option for CLIP models. Must be falsy otherwise.
        lora_rank: Forwarded only to the CLIP builder.
        lora_alpha: Forwarded only to the CLIP builder.
        lora_dropout: Forwarded only to the CLIP builder.
        norm: Forwarded to the builder.
        act: Forwarded to the builder.
        text_prompts: Forwarded only to the CLIP builder.

    Returns:
        The constructed model, with its configuration dict attached as
        ``model.config``.

    Raises:
        AssertionError: A required argument is missing, or adapter/LoRA is
            requested for a non-CLIP model (only while assertions are enabled).
        KeyError: A loaded info file lacks ``"config"``, ``"weights"`` or one
            of the mandatory config entries.
    """
    if os.path.exists(model_info_path):
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load info files that come from a trusted source.
        model_info = torch.load(model_info_path, map_location="cpu", weights_only=False)
        stored = model_info["config"]

        # Mandatory entries: indexing raises KeyError when one is absent.
        model_name = stored["model_name"]
        block_size = stored["block_size"]
        bins = stored["bins"]
        bin_centers = stored["bin_centers"]
        zero_inflated = stored["zero_inflated"]

        # Optional entries fall back to the signature defaults,
        # not to the arguments the caller passed.
        clip_weight_name = stored.get("clip_weight_name", None)
        num_vpt = stored.get("num_vpt", None)
        vpt_drop = stored.get("vpt_drop", None)
        adapter = stored.get("adapter", False)
        adapter_reduction = stored.get("adapter_reduction", None)
        lora = stored.get("lora", False)
        lora_rank = stored.get("lora_rank", None)
        lora_alpha = stored.get("lora_alpha", None)
        lora_dropout = stored.get("lora_dropout", None)
        input_size = stored.get("input_size", None)
        text_prompts = stored.get("text_prompts", None)
        norm = stored.get("norm", "none")
        act = stored.get("act", "none")

        weights = model_info["weights"]
    else:
        assert model_name is not None, "model_name should be provided if model_info_path is not provided"
        assert block_size is not None, "block_size should be provided"
        assert bins is not None, "bins should be provided"
        assert bin_centers is not None, "bin_centers should be provided"
        weights = None

    if "ViT" in model_name:
        assert num_vpt is not None, f"num_vpt should be provided for ViT models, got {num_vpt}"
        assert vpt_drop is not None, f"vpt_drop should be provided for ViT models, got {vpt_drop}"

    # Keyword arguments accepted by both _ebc and _clip_ebc.
    shared_kwargs = dict(
        block_size=block_size,
        bins=bins,
        bin_centers=bin_centers,
        zero_inflated=zero_inflated,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        input_size=input_size,
        norm=norm,
        act=act,
    )
    # Leading config entries common to both families. Keys are spliced in
    # first so the stored dict keeps its established key order.
    config_head = {
        "model_name": model_name,
        "block_size": block_size,
        "bins": bins,
        "bin_centers": bin_centers,
        "zero_inflated": zero_inflated,
    }

    if model_name.startswith(("CLIP_", "CLIP-")):
        assert clip_weight_name is not None, f"clip_weight_name should be provided for CLIP models, got {clip_weight_name}"
        model = _clip_ebc(
            model_name=model_name[5:],  # drop the five-character "CLIP_" / "CLIP-" prefix
            weight_name=clip_weight_name,
            adapter=adapter,
            adapter_reduction=adapter_reduction,
            lora=lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            text_prompts=text_prompts,
            **shared_kwargs,
        )
        model_config = {
            **config_head,
            "clip_weight_name": clip_weight_name,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "adapter": adapter,
            "adapter_reduction": adapter_reduction,
            "lora": lora,
            "lora_rank": lora_rank,
            "lora_alpha": lora_alpha,
            "lora_dropout": lora_dropout,
            # Taken from the built model rather than the argument. Presumably
            # the builder resolves defaults when None is passed; confirm in clip_ebc.
            "text_prompts": model.text_prompts,
            "norm": norm,
            "act": act,
        }
    else:
        assert not adapter, "adapter for non-CLIP models is not implemented yet"
        assert not lora, "lora for non-CLIP models is not implemented yet"
        model = _ebc(model_name=model_name, **shared_kwargs)
        model_config = {
            **config_head,
            "num_vpt": num_vpt,
            "vpt_drop": vpt_drop,
            "input_size": input_size,
            "norm": norm,
            "act": act,
        }

    model.config = model_config
    model_info = {"config": model_config, "weights": weights}

    if weights is not None:
        model.load_state_dict(weights)

    # The filesystem is checked again here: a record is written only when no
    # info file is present, so an existing file is never overwritten. On that
    # path `weights` is normally None, so only the configuration is persisted.
    if not os.path.exists(model_info_path):
        torch.save(model_info, model_info_path)

    return model
| |
|
| |
|
# Public API of this module: only the model factory is exported by `import *`.
__all__ = [
    "get_model",
]
| |
|