import argparse
import os
import shutil
import sys
from pathlib import Path

import gradio as gr
import requests
import yaml
from tqdm import tqdm

import merge
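# Sequentially merges a list of Hugging Face repositories (or direct checkpoint URLs /
# local .safetensors files) into a single model via the local `merge` module.
# Supports transformers-style repos, single-file checkpoints, and diffusers pipelines,
# with an optional dry-run mode and a Gradio progress bar.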
def parse_arguments():
    parser = argparse.ArgumentParser(description="Merge HuggingFace models")
    parser.add_argument('repo_list', type=str, help='File containing list of repositories to merge, supports mergekit yaml or txt')
    parser.add_argument('output_dir', type=str, help='Directory for the merged models')
    parser.add_argument('-base_model', type=str, default='staging/base_model', help='Base model directory')
    parser.add_argument('-staging_model', type=str, default='staging/merge_model', help='Staging model directory')
    parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
    parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor for the weight delta')
    parser.add_argument('--dry', action='store_true', help='Run in dry mode without making any changes')
    return parser.parse_args()
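# Example invocation (illustrative only; the script name and paths are hypothetical):
#   python merge_repos.py repo_list.yaml merged/output -p 0.3 -lambda 3.0 --dry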
def repo_list_generator(file_path, default_p, default_lambda_val):
    _, file_extension = os.path.splitext(file_path)
    # Branching based on file extension
    if file_extension.lower() in ('.yaml', '.yml'):
        with open(file_path, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
        for model_info in data['models']:
            model_name = model_info['model']
            p = model_info.get('parameters', {}).get('weight', default_p)
            lambda_val = 1 / model_info.get('parameters', {}).get('density', default_lambda_val)
            yield model_name, p, lambda_val
    else:  # Defaulting to txt file processing
        with open(file_path, "r", encoding='utf-8') as file:
            repos_to_process = file.readlines()
        for repo in repos_to_process:
            yield repo.strip(), default_p, default_lambda_val
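# Illustrative inputs for repo_list_generator (repository names are hypothetical):
#
# Plain text, one repository per line:
#   org/base-model
#   org/other-model
#
# mergekit-style YAML, read as a list under 'models':
#   models:
#     - model: org/base-model
#       parameters:
#         weight: 0.5     # used as p
#         density: 2.0    # lambda_val = 1 / density
#     - model: org/other-model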
def reset_directories(directories, dry_run):
    for directory in directories:
        if os.path.exists(directory):
            if dry_run:
                print(f"[DRY RUN] Would delete directory {directory}")
            else:
                shutil.rmtree(directory)
                print(f"Directory {directory} deleted successfully.")
def do_merge(tensor_map, staging_path, p, lambda_val, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_folder(tensor_map, staging_path, p, lambda_val)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map
def do_merge_files(base_path, staging_path, output_path, p, lambda_val, dry_run=False):
    # Initialize so dry runs and failed merges return None instead of raising UnboundLocalError
    tensor_map = None
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_files(base_path, staging_path, output_path, p, lambda_val)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map
def do_merge_diffusers(tensor_map, staging_path, p, lambda_val, skip_dirs, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_folder_diffusers(tensor_map, staging_path, p, lambda_val, skip_dirs)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map
def download_repo(repo_name, path, dry_run=False):
    from huggingface_hub import snapshot_download
    if dry_run:
        print(f"[DRY RUN] Would download repository {repo_name} to {path}")
    else:
        print(f"Cloning repository {repo_name}.")
        try:
            snapshot_download(repo_id=repo_name, local_dir=path)
        except Exception as e:
            print(e)
            return
        print(f"Repository {repo_name} cloned successfully.")
def download_thing(directory, url, progress=gr.Progress(track_tqdm=True)):
    civitai_api_key = os.environ.get("CIVITAI_API_KEY")
    url = url.strip()
    if "drive.google.com" in url:
        original_dir = os.getcwd()
        os.chdir(directory)
        os.system(f"gdown --fuzzy {url}")
        os.chdir(original_dir)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
        else:
            os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
        else:
            print("You need an API key to download Civitai models.")
    else:
        os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
def get_local_model_list(dir_path):
    model_list = []
    valid_extensions = ('.safetensors',)  # one-element tuple, not a bare string
    for file in Path(dir_path).glob("*"):
        if file.suffix in valid_extensions:
            file_path = str(Path(f"{dir_path}/{file.name}"))
            model_list.append(file_path)
    return model_list
def list_sub(a, b):
    return [e for e in a if e not in b]
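# get_download_file resolves a URL or local path to a file on disk: it reuses an
# existing local file when possible, otherwise it snapshots the directory listing
# before and after download_thing and treats the newly appeared file as the result.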
def get_download_file(temp_dir, url):
    new_file = None
    if "http" not in url and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
        print(f"File to download already exists: {url}")
        new_file = f"{temp_dir}/{url.split('/')[-1]}"
    else:
        print(f"Start downloading: {url}")
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url.strip())
        except Exception:
            print(f"Download failed: {url}")
            return None
        after = get_local_model_list(temp_dir)
        new_file = list_sub(after, before)[0] if list_sub(after, before) else None
    if new_file is None:
        print(f"Download failed: {url}")
        return None
    print(f"Download completed: {url}")
    return new_file
def download_file(url, path, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would download file {url} to {path}")
    else:
        print(f"Cloning file {url}.")
        try:
            path = get_download_file(path, url)
        except Exception as e:
            print(e)
            return None
        print(f"File {url} cloned successfully.")
    return path
def is_repo_name(s):
    import re
    return re.fullmatch(r'^[^/,\s]+?/[^/,\s]+?$', s)
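# is_repo_name matches Hugging Face style "namespace/name" identifiers, e.g.
# "org/model" (illustrative), and rejects strings containing extra slashes,
# commas, or whitespace.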
def should_create_symlink(repo_name):
    if os.path.exists(repo_name):
        return True, os.path.isfile(repo_name)
    return False, False
def download_or_link_repo(repo_name, path, dry_run=False):
    symlink, is_file = should_create_symlink(repo_name)
    if symlink and is_file:
        os.makedirs(path, exist_ok=True)
        symlink_path = os.path.join(path, os.path.basename(repo_name))
        os.symlink(repo_name, symlink_path)
    elif symlink:
        os.symlink(repo_name, path)
    elif "http" in repo_name:
        return download_file(repo_name, path, dry_run)
    elif is_repo_name(repo_name):
        download_repo(repo_name, path, dry_run)
    return None
def delete_repo(path, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would delete repository at {path}")
    else:
        try:
            shutil.rmtree(path)
            print(f"Repository at {path} deleted successfully.")
        except Exception as e:
            print(f"Error deleting repository at {path}: {e}")
def get_max_vocab_size(repo_list):
    max_vocab_size = 0
    repo_with_max_vocab = None
    for repo in repo_list:
        repo_name = repo[0].strip()
        url = f"https://huggingface.co/{repo_name}/raw/main/config.json"
        try:
            response = requests.get(url)
            response.raise_for_status()
            config = response.json()
            vocab_size = config.get("vocab_size", 0)
            if vocab_size > max_vocab_size:
                max_vocab_size = vocab_size
                repo_with_max_vocab = repo_name
        except requests.RequestException as e:
            print(f"Error fetching data from {url}: {e}")
    return max_vocab_size, repo_with_max_vocab
def download_json_files(repo_name, file_paths, output_dir):
    base_url = f"https://huggingface.co/{repo_name}/raw/main/"
    for file_path in file_paths:
        url = base_url + file_path
        response = requests.get(url)
        if response.status_code == 200:
            with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
                file.write(response.content)
        else:
            print(f"Failed to download {file_path}")
def get_merged_path(filename, output_dir):
    from datetime import datetime, timezone, timedelta
    dt_now = datetime.now(timezone(timedelta(hours=9)))
    basename = dt_now.strftime('Merged_%Y%m%d_%H%M')
    ext = Path(filename).suffix
    return str(Path(output_dir, basename + ext)), str(Path(output_dir, basename + ".yaml"))
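# Merged artifacts are named with a UTC+9 (JST) timestamp, e.g. (illustrative)
# "Merged_20240131_0945.safetensors" alongside "Merged_20240131_0945.yaml".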
def repo_list_to_yaml(repo_list_path, repo_list, output_yaml_path):
    if Path(repo_list_path).suffix.lower() in (".yaml", ".yml"):
        shutil.copy(repo_list_path, output_yaml_path)
    else:
        repos = list(repo_list)
        yaml_dict = {}
        yaml_dict.setdefault('models', {})
        for repo in repos:
            model, weight, density = repo
            model_info = {}
            model_info['model'] = str(model)
            model_info.setdefault('parameters', {})
            model_info['parameters']['weight'] = float(weight)
            model_info['parameters']['density'] = float(density)
            yaml_dict['models'][str(model.split("/")[-1])] = model_info
        with open(output_yaml_path, mode='w', encoding='utf-8') as file:
            yaml.dump(yaml_dict, file, default_flow_style=False, allow_unicode=True)
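# process_repos drives the whole merge: the first entry in the repo list becomes the
# base model, and every following entry is downloaded into the staging directory and
# merged into it. The detected repo_type selects the strategy:
#   "Files"     - single .safetensors/.sft checkpoints merged via merge.merge_files
#   "Diffusers" - diffusers pipelines (model_index.json present); directories in
#                 skip_dirs are copied over unmerged
#   "Default"   - transformers-style repos, with tokenizer/config files taken from
#                 the repo that has the largest vocab_size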
def process_repos(output_dir, base_model, staging_model, repo_list_file, p, lambda_val, skip_dirs, dry_run=False, progress=gr.Progress(track_tqdm=True)):
    repo_type = "Default"  # ("Default", "Files", "Diffusers")
    # Check if output_dir exists
    if os.path.exists(output_dir):
        sys.exit(f"Output directory '{output_dir}' already exists. Exiting to prevent data loss.")
    # Reset base and staging directories
    reset_directories([base_model, staging_model], dry_run)
    # Make sure staging and output directories exist
    os.makedirs(base_model, exist_ok=True)
    os.makedirs(staging_model, exist_ok=True)
    repo_list_gen = repo_list_generator(repo_list_file, p, lambda_val)
    repos_to_process = list(repo_list_gen)
    # Initial download for 'base_model'
    path = download_or_link_repo(repos_to_process[0][0].strip(), base_model, dry_run)
    if path is not None and (".safetensors" in path or ".sft" in path):
        repo_type = "Files"
    elif Path(base_model, "model_index.json").exists():
        repo_type = "Diffusers"
    if repo_type == "Files":
        os.makedirs(output_dir, exist_ok=True)
        output_file_path, output_yaml_path = get_merged_path(path, output_dir)
        # Pass the materialized list; the generator above is already exhausted
        repo_list_to_yaml(repo_list_file, repos_to_process, output_yaml_path)
        for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Files')):
            repo_name = repo[0].strip()
            repo_p = repo[1]
            repo_lambda = repo[2]
            delete_repo(staging_model, dry_run)
            staging_path = download_or_link_repo(repo_name, staging_model, dry_run)
            do_merge_files(path, staging_path, output_file_path, repo_p, repo_lambda, dry_run)
        reset_directories([base_model, staging_model], dry_run)
        return output_file_path, output_yaml_path
    elif repo_type == "Diffusers":
        merge.copy_dirs(base_model, output_dir)
        tensor_map = merge.map_tensors_to_files_diffusers(base_model, skip_dirs)
        for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Repos')):
            repo_name = repo[0].strip()
            repo_p = repo[1]
            repo_lambda = repo[2]
            delete_repo(staging_model, dry_run)
            download_or_link_repo(repo_name, staging_model, dry_run)
            tensor_map = do_merge_diffusers(tensor_map, staging_model, repo_p, repo_lambda, skip_dirs, dry_run)
        os.makedirs(output_dir, exist_ok=True)
        merge.copy_skipped_dirs(base_model, output_dir, skip_dirs)
        merge.copy_nontensor_files(base_model, output_dir)
        merge.save_tensor_map(tensor_map, output_dir)
        reset_directories([base_model, staging_model], dry_run)
        return None, None
    elif repo_type == "Default":
        merge.copy_dirs(base_model, output_dir)
        tensor_map = merge.map_tensors_to_files(base_model)
        for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Repos')):
            repo_name = repo[0].strip()
            repo_p = repo[1]
            repo_lambda = repo[2]
            delete_repo(staging_model, dry_run)
            download_or_link_repo(repo_name, staging_model, dry_run)
            tensor_map = do_merge(tensor_map, staging_model, repo_p, repo_lambda, dry_run)
        os.makedirs(output_dir, exist_ok=True)
        merge.copy_nontensor_files(base_model, output_dir)
        # Handle LLMs that add tokens by taking the config/tokenizer from the repo with the largest vocab
        if os.path.exists(os.path.join(output_dir, 'config.json')):
            max_vocab_size, repo_name = get_max_vocab_size(repos_to_process)
            if max_vocab_size > 0:
                file_paths = ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']
                download_json_files(repo_name, file_paths, output_dir)
        reset_directories([base_model, staging_model], dry_run)
        merge.save_tensor_map(tensor_map, output_dir)
        return None, None
if __name__ == "__main__":
    args = parse_arguments()
    skip_dirs = ['vae', 'text_encoder']
    process_repos(args.output_dir, args.base_model, args.staging_model, args.repo_list, args.p, args.lambda_val, skip_dirs, args.dry)
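# Minimal sketch of calling process_repos programmatically (e.g. from a Gradio event
# handler); all paths and values here are hypothetical:
#
#   merged_file, merged_yaml = process_repos(
#       output_dir="merged/output",
#       base_model="staging/base_model",
#       staging_model="staging/merge_model",
#       repo_list_file="repo_list.yaml",
#       p=0.5,
#       lambda_val=1.0,
#       skip_dirs=['vae', 'text_encoder'],
#       dry_run=True,
#   )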