|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from typing import Generator, List, Optional |
|
|
|
import huggingface_hub |
|
import yaml |
|
from huggingface_hub.utils import HFValidationError |
|
from yaml.nodes import SequenceNode as SequenceNode |
|
|
|
from mergekit.config import MergeConfiguration, ModelReference |
|
|
|
CARD_TEMPLATE = """--- |
|
{metadata} |
|
--- |
|
# {name} |
|
|
|
This is a merge of pre-trained language models created using [mergekit](https://github.com/cg123/mergekit). |
|
|
|
## Merge Details |
|
### Merge Method |
|
|
|
This model was merged using the {merge_method} merge method{base_text}. |
|
|
|
### Models Merged |
|
|
|
The following models were included in the merge: |
|
{model_list} |
|
|
|
### Configuration |
|
|
|
The following YAML configuration was used to produce this model: |
|
|
|
```yaml |
|
{config_yaml} |
|
``` |
|
""" |
|
|
|
|
|
def is_hf(path: str) -> bool: |
|
""" |
|
Determines if the given path is a Hugging Face model repository. |
|
|
|
Args: |
|
path: A string path to check. |
|
""" |
|
if path[0] in "/~" or path.count("/") > 1: |
|
return False |
|
if not os.path.exists(path): |
|
return True |
|
try: |
|
return huggingface_hub.repo_exists(path, repo_type="model", token=False) |
|
except HFValidationError: |
|
return False |
|
|
|
|
|
def extract_hf_paths(models: List[ModelReference]) -> Generator[str, None, None]: |
|
""" |
|
Yields all valid Hugging Face paths from a list of ModelReference objects. |
|
|
|
Args: |
|
models: A list of ModelReference objects. |
|
""" |
|
for model in models: |
|
if is_hf(model.model.path): |
|
yield model.model.path |
|
|
|
if model.lora and is_hf(model.lora.path): |
|
yield model.lora.path |
|
|
|
|
|
def method_md(merge_method: str) -> str: |
|
""" |
|
Returns a markdown string for the given merge method. |
|
|
|
Args: |
|
merge_method: A string indicating the merge method used. |
|
""" |
|
methods = { |
|
"linear": "[linear](https://arxiv.org/abs/2203.05482)", |
|
"ties": "[TIES](https://arxiv.org/abs/2306.01708)", |
|
"slerp": "SLERP", |
|
"task_arithmetic": "[task arithmetic](https://arxiv.org/abs/2212.04089)", |
|
"dare_ties": "[DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)", |
|
"dare_linear": "linear [DARE](https://arxiv.org/abs/2311.03099)", |
|
"model_stock": "[Model Stock](https://arxiv.org/abs/2403.19522)", |
|
} |
|
return methods.get(merge_method, merge_method) |
|
|
|
|
|
def maybe_link_hf(path: str) -> str: |
|
""" |
|
Convert a path to a clickable link if it's a Hugging Face model path. |
|
|
|
Args: |
|
path: A string path to possibly convert to a link. |
|
""" |
|
if is_hf(path): |
|
return f"[{path}](https://huggingface.co/{path})" |
|
return path |
|
|
|
|
|
def modelref_md(model: ModelReference) -> str: |
|
""" |
|
Generates markdown description for a ModelReference object. |
|
|
|
Args: |
|
model: A ModelReference object. |
|
|
|
Returns: |
|
A markdown formatted string describing the model reference. |
|
""" |
|
text = maybe_link_hf(model.model.path) |
|
if model.lora: |
|
text += " + " + maybe_link_hf(model.lora.path) |
|
return text |
|
|
|
|
|
def generate_card( |
|
config: MergeConfiguration, |
|
config_yaml: str, |
|
name: Optional[str] = None, |
|
) -> str: |
|
""" |
|
Generates a markdown card for a merged model configuration. |
|
|
|
Args: |
|
config: A MergeConfiguration object. |
|
config_yaml: YAML source text of the config. |
|
name: An optional name for the model. |
|
""" |
|
if not name: |
|
name = "Untitled Model (1)" |
|
|
|
hf_bases = list(extract_hf_paths(config.referenced_models())) |
|
tags = ["mergekit", "merge"] |
|
|
|
actual_base = config.base_model |
|
if config.merge_method == "slerp": |
|
|
|
actual_base = None |
|
|
|
base_text = "" |
|
if actual_base: |
|
base_text = f" using {modelref_md(actual_base)} as a base" |
|
|
|
model_bullets = [] |
|
for model in config.referenced_models(): |
|
if model == actual_base: |
|
|
|
continue |
|
|
|
model_bullets.append("* " + modelref_md(model)) |
|
|
|
return CARD_TEMPLATE.format( |
|
metadata=yaml.dump( |
|
{"base_model": hf_bases, "tags": tags, "library_name": "transformers"} |
|
), |
|
model_list="\n".join(model_bullets), |
|
base_text=base_text, |
|
merge_method=method_md(config.merge_method), |
|
name=name, |
|
config_yaml=config_yaml, |
|
) |
|
|