import functools
import typing
from typing import Any, Callable, Optional, Union

import click
from click.core import Context, Parameter
from pydantic import BaseModel

from mergekit.common import parse_kmb


class MergeOptions(BaseModel):
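    """Options controlling how a merge is executed; exposed as CLI flags by add_merge_options."""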
    allow_crimes: bool = False
    transformers_cache: Optional[str] = None
    lora_merge_cache: Optional[str] = None
    cuda: bool = False
    low_cpu_memory: bool = False
    out_shard_size: int = parse_kmb("5B")
    copy_tokenizer: bool = True
    clone_tensors: bool = False
    trust_remote_code: bool = False
    random_seed: Optional[int] = None
    lazy_unpickle: bool = False
    write_model_card: bool = True
    safe_serialization: bool = True


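# Help text for each MergeOptions field, shown in the generated --help output.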
OPTION_HELP = {
    "allow_crimes": "Allow mixing architectures",
    "transformers_cache": "Override storage path for downloaded models",
    "lora_merge_cache": "Path to store merged LORA models",
    "cuda": "Perform matrix arithmetic on GPU",
    "low_cpu_memory": "Store results and intermediate values on GPU. Useful if VRAM > RAM",
    "out_shard_size": "Number of parameters per output shard [default: 5B]",
    "copy_tokenizer": "Copy a tokenizer to the output",
    "clone_tensors": "Clone tensors before saving, to allow multiple occurrences of the same layer",
    "trust_remote_code": "Trust remote code from huggingface repos (danger)",
    "random_seed": "Seed for reproducible use of randomized merge methods",
    "lazy_unpickle": "Experimental lazy unpickler for lower memory usage",
    "write_model_card": "Output README.md containing details of the merge",
    "safe_serialization": "Save output in safetensors. Do this, don't poison the world with more pickled models.",
}


class ShardSizeParamType(click.ParamType):
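    """click parameter type for shard sizes given as parameter counts (e.g. "5B"), parsed by parse_kmb."""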
    name = "size"

    def convert(
        self, value: Any, param: Optional[Parameter], ctx: Optional[Context]
    ) -> int:
        return parse_kmb(value)


def add_merge_options(f: Callable) -> Callable:
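    """Decorate a click command with one option per MergeOptions field.

    The collected values are bundled into a MergeOptions instance and passed to
    the wrapped function as a single `merge_options` keyword argument.
    """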

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
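        # Pull the values for MergeOptions fields out of kwargs and bundle them
        # into a single MergeOptions instance.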
        arg_dict = {}
        for field_name in MergeOptions.model_fields:
            if field_name in kwargs:
                arg_dict[field_name] = kwargs.pop(field_name)

        kwargs["merge_options"] = MergeOptions(**arg_dict)
        return f(*args, **kwargs)

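    # Register a click option for each MergeOptions field. Iterating in reverse
    # makes the options appear in field-declaration order in --help.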
    for field_name, info in reversed(MergeOptions.model_fields.items()):
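        # Unwrap Optional[X] annotations so click receives the underlying type.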
        origin = typing.get_origin(info.annotation)
        if origin is Union:
            ty, prob_none = typing.get_args(info.annotation)
            assert prob_none is type(None)
            field_type = ty
        else:
            field_type = info.annotation

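        # out_shard_size is special-cased to accept human-readable parameter counts like "5B".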
        if field_name == "out_shard_size":
            field_type = ShardSizeParamType()

        arg_name = field_name.replace("_", "-")
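        # Boolean fields become paired --flag/--no-flag switches.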
        if field_type == bool:
            arg_str = f"--{arg_name}/--no-{arg_name}"
        else:
            arg_str = f"--{arg_name}"

        help_str = OPTION_HELP.get(field_name, None)
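        # Each click.option decorator wraps the previous value of `wrapper`,
        # stacking one option per field onto the command.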
        wrapper = click.option(
            arg_str,
            type=field_type,
            default=info.default,
            help=help_str,
            show_default=field_name != "out_shard_size",
        )(wrapper)

    return wrapper