import argparse import numpy as np import os import shutil import torch import torch.nn.functional as F from safetensors.torch import safe_open, save_file import glob from pathlib import Path def merge_tensors(tensor1, tensor2, p): # Calculate the delta of the weights delta = tensor2 - tensor1 # Generate the mask m^t from Bernoulli distribution m = torch.from_numpy(np.random.binomial(1, p, delta.shape)).to(tensor1.dtype) # Apply the mask to the delta to get δ̃^t delta_tilde = m * delta # Scale the masked delta by the dropout rate to get δ̂^t delta_hat = delta_tilde / (1 - p) return delta_hat def merge_safetensors(file_path1, file_path2, p, lambda_val): merged_tensors = {} with safe_open(file_path1, framework="pt", device="cpu") as f1, safe_open(file_path2, framework="pt", device="cpu") as f2: keys1 = set(f1.keys()) keys2 = set(f2.keys()) common_keys = keys1.intersection(keys2) for key in common_keys: tensor1 = f1.get_tensor(key) tensor2 = f2.get_tensor(key) tensor1, tensor2 = resize_tensors(tensor1, tensor2) merged_tensors[key] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p) print("merging", key) return merged_tensors class BinDataHandler(): def __init__(self, data): self.data = data def get_tensor(self, key): return self.data[key] def read_tensors(file_path, ext): if ext == ".safetensors" and (file_path.endswith(".safetensors") or file_path.endswith(".sft")): print(f"Reading tensors from {file_path} in {ext} format.") f = safe_open(file_path, framework="pt", device="cpu") return f, set(f.keys()) if ext == ".bin" and file_path.endswith(".bin"): print(f"Reading tensors from {file_path} in {ext} format.") data = torch.load(file_path, map_location=torch.device('cpu')) f = BinDataHandler(data) return f, set(data.keys()) return None, None def resize_tensors(tensor1, tensor2): if len(tensor1.shape) not in [1, 2]: return tensor1, tensor2 if len(tensor1.shape) == 1 and len(tensor2.shape) == 1: if tensor1.shape[-1] < tensor2.shape[-1]: padding_size = tensor2.shape[-1] - tensor1.shape[-1] pad = torch.nn.ConstantPad1d((padding_size, 0), 0) tensor1 = pad(tensor1) elif tensor2.shape[-1] < tensor1.shape[-1]: padding_size = tensor1.shape[-1] - tensor2.shape[-1] pad = torch.nn.ConstantPad1d((padding_size, 0), 0) tensor2 = pad(tensor2) else: # Pad along the last dimension (width) if tensor1.shape[-1] < tensor2.shape[-1]: padding_size = tensor2.shape[-1] - tensor1.shape[-1] tensor1 = F.pad(tensor1, (0, padding_size, 0, 0)) elif tensor2.shape[-1] < tensor1.shape[-1]: padding_size = tensor1.shape[-1] - tensor2.shape[-1] tensor2 = F.pad(tensor2, (0, padding_size, 0, 0)) # Pad along the first dimension (height) if tensor1.shape[0] < tensor2.shape[0]: padding_size = tensor2.shape[0] - tensor1.shape[0] tensor1 = F.pad(tensor1, (0, 0, 0, padding_size)) elif tensor2.shape[0] < tensor1.shape[0]: padding_size = tensor1.shape[0] - tensor2.shape[0] tensor2 = F.pad(tensor2, (0, 0, 0, padding_size)) return tensor1, tensor2 def merge_folder(tensor_map, directory_path, p, lambda_val): keys1 = set(tensor_map.keys()) # Some repos have both bin and safetensors, choose safetensors if so ext = None for filename in glob.glob(f'{directory_path}/**', recursive=True): filename = os.path.normpath(filename) # Default to safetensors if filename.endswith(".safetensors") or filename.endswith(".sft"): ext = ".safetensors" if filename.endswith(".bin") and ext is None: ext = ".bin" if ext is None: raise "Could not find model files" for filename in glob.glob(f'{directory_path}/**', recursive=True): filename = os.path.normpath(filename) f2, keys2 = read_tensors(filename, ext) if keys2: common_keys = keys1.intersection(keys2) for key in common_keys: if "block_sparse_moe.gate" in key: tensor1 = tensor_map[key]['tensor'] tensor2 = f2.get_tensor(key) tensor_map[key]['tensor'] = (tensor1 + tensor2) /2.0 print("merging", key) continue tensor1 = tensor_map[key]['tensor'] tensor2 = f2.get_tensor(key) tensor1, tensor2 = resize_tensors(tensor1, tensor2) tensor_map[key]['tensor'] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p) print("merging", key) return tensor_map def merge_folder_diffusers(tensor_map, directory_path, p, lambda_val, skip_dirs): keys1 = set(tensor_map.keys()) # Some repos have both bin and safetensors, choose safetensors if so ext = None for filename in [p for p in glob.glob(f'{directory_path}/*', recursive=False) if ".fp16." not in p]: filename = os.path.normpath(filename) # Default to safetensors if filename.endswith(".safetensors") or filename.endswith(".sft"): ext = ".safetensors" if filename.endswith(".bin") and ext is None: ext = ".bin" if ext is None: raise "Could not find model files" for dirname in glob.glob(f'{directory_path}/*/', recursive=False): if Path(dirname).stem in skip_dirs: continue for filename in [p for p in glob.glob(f'{dirname}/*', recursive=False) if ".fp16." not in p]: filename = os.path.normpath(filename) f2, keys2 = read_tensors(filename, ext) if keys2: common_keys = keys1.intersection(keys2) for key in common_keys: if "block_sparse_moe.gate" in key: tensor1 = tensor_map[key]['tensor'] tensor2 = f2.get_tensor(key) tensor_map[key]['tensor'] = (tensor1 + tensor2) /2.0 print("merging", key) continue tensor1 = tensor_map[key]['tensor'] tensor2 = f2.get_tensor(key) tensor1, tensor2 = resize_tensors(tensor1, tensor2) tensor_map[key]['tensor'] = tensor1 + lambda_val * merge_tensors(tensor1, tensor2, p) print("merging", key) return tensor_map def merge_files(base_model, second_model, output_model, p, lambda_val): merged = merge_safetensors(base_model, second_model, p, lambda_val) save_file(merged, output_model) def map_tensors_to_files(directory_path): tensor_map = {} for filename in glob.glob(f'{directory_path}/**', recursive=True): filename = os.path.normpath(filename) f, keys = read_tensors(filename, '.safetensors') if keys: for key in keys: tensor = f.get_tensor(key) tensor_map[key] = {'filename':filename, 'shape':tensor.shape, 'tensor': tensor} return tensor_map def map_tensors_to_files_diffusers(directory_path, skip_dirs): tensor_map = {} for dirname in glob.glob(f'{directory_path}/*/', recursive=False): if Path(dirname).stem in skip_dirs: continue for filename in [p for p in glob.glob(f'{dirname}/*', recursive=False) if ".fp16." not in p]: filename = os.path.normpath(filename) f, keys = read_tensors(filename, '.safetensors') if keys: for key in keys: tensor = f.get_tensor(key) tensor_map[key] = {'filename':filename, 'shape':tensor.shape, 'tensor': tensor} return tensor_map def copy_nontensor_files(from_path, to_path): print(f"Copying non-tensor files {from_path} to {to_path}") shutil.copytree(from_path, to_path, ignore=shutil.ignore_patterns("*.safetensors", "*.bin", "*.sft", ".*", "README*"), dirs_exist_ok=True) def copy_skipped_dirs(from_path, to_path, skip_dirs): for dirname in glob.glob(f'{from_path}/*/', recursive=False): if Path(dirname).stem in skip_dirs: dirname = os.path.normpath(dirname) print(f"Copying skipped files {dirname} to {to_path}") shutil.copytree(Path(dirname).resolve(), Path(to_path, Path(dirname).stem).resolve(), ignore=shutil.ignore_patterns(".*", "README*"), dirs_exist_ok=True) def save_tensor_map(tensor_map, output_folder): metadata = {'format': 'pt'} by_filename = {} for key, value in tensor_map.items(): filename = value["filename"] tensor = value["tensor"] filename = os.path.normpath(filename) if filename not in by_filename: by_filename[filename] = {} by_filename[filename][key] = tensor for filename in sorted(by_filename.keys()): filename = os.path.normpath(filename) if Path(output_folder, Path(filename).parent.name).exists(): output_file = str(Path(output_folder, Path(filename).parent.name, Path(filename).name)) else: output_file = str(Path(output_folder, Path(filename).name)) print("Saving:", output_file) save_file(by_filename[filename], output_file, metadata=metadata) def copy_dirs(src: str, dst: str): shutil.copytree(src, dst, ignore=shutil.ignore_patterns("*.*"), dirs_exist_ok=True) def main(): # Parse command-line arguments parser = argparse.ArgumentParser(description='Merge two safetensor model files.') parser.add_argument('base_model', type=str, help='The base model safetensor file') parser.add_argument('second_model', type=str, help='The second model safetensor file') parser.add_argument('output_model', type=str, help='The output merged model safetensor file') parser.add_argument('-p', type=float, default=0.5, help='Dropout probability') parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor for the weight delta') args = parser.parse_args() skip_dirs = ['vae', 'text_encoder'] if os.path.isdir(args.base_model): if not os.path.exists(args.output_model): os.makedirs(args.output_model) if os.path.exists(args.base_model + "/model_index.json"): # assume Diffusers Repo copy_dirs(args.base_model, args.output_model) tensor_map = map_tensors_to_files_diffusers(args.base_model, skip_dirs) tensor_map = merge_folder_diffusers(tensor_map, args.second_model, args.p, args.lambda_val, skip_dirs) copy_skipped_dirs(args.base_model, args.output_model, skip_dirs) copy_nontensor_files(args.base_model, args.output_model) save_tensor_map(tensor_map, args.output_model) else: copy_dirs(args.base_model, args.output_model) tensor_map = map_tensors_to_files(args.base_model) tensor_map = merge_folder(tensor_map, args.second_model, args.p, args.lambda_val) copy_nontensor_files(args.base_model, args.output_model) save_tensor_map(tensor_map, args.output_model) else: merged = merge_safetensors(args.base_model, args.second_model, args.p, args.lambda_val) save_file(merged, args.output_model) if __name__ == '__main__': main()