---
base_model:
  - black-forest-labs/FLUX.1-dev
  - black-forest-labs/FLUX.1-schnell
language:
  - en
license: other
license_name: flux-1-dev-non-commercial-license
license_link: LICENSE.md
tags:
  - merge
  - flux
---

# Aryanne/flux_swap

This model is a merge of [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell).

Unlike other merge methods, the tensor values are not averaged here; instead, values from FLUX.1-dev are replaced in a checkerboard pattern with the corresponding values from FLUX.1-schnell, so roughly 50% of each model is present (if my code is right).
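For intuition, here is a toy sketch of the pattern (illustration only, not the merge code): with an offset of 2, every position whose row and column indices sum to an even number takes schnell's value, and the rest keep dev's.

```python
import torch

dev = torch.zeros(4, 4)      # stand-in for a FLUX.1-dev weight tensor
schnell = torch.ones(4, 4)   # stand-in for the matching FLUX.1-schnell tensor

rows = torch.arange(4).view(-1, 1)
cols = torch.arange(4).view(1, -1)
mask = (rows + cols) % 2 == 0  # checkerboard: True on "even" cells

merged = torch.where(mask, schnell, dev)
print(merged)
# tensor([[1., 0., 1., 0.],
#         [0., 1., 0., 1.],
#         [1., 0., 1., 0.],
#         [0., 1., 0., 1.]])
```

The full merge script follows.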

```python
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch
import gc


# Build an empty (meta-device) copy of the dev transformer; its weights are
# filled in at the end with load_model_dict_into_meta.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

# Download only the transformer folders from both repos.
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

# The script assumes both repos shard the transformer in the same order.
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))

def swapping_method(base, x, parameters):
    def swap_values(shape, n, base, x):
        # Checkerboard/diagonal pattern: positions where the index sum is a
        # multiple of n take x's value; everything else keeps base's value.
        if x.dim() == 2:
            rows, cols = shape
            rows_range = torch.arange(rows).view(-1, 1)
            cols_range = torch.arange(cols).view(1, -1)
            mask = ((rows_range + cols_range) % n == 0).to(base.device)
            x = torch.where(mask, x, base)
        else:
            # 1-D tensors (biases, norm scales): every n-th element comes from x.
            rows_range = torch.arange(shape[0])
            mask = (rows_range % n == 0).to(base.device)
            x = torch.where(mask, x, base)
        return x

    def rand_mask(base, x, percent, seed=None):
        # Take a random `percent` fraction of x's values, optionally seeded for
        # reproducibility; the global RNG state is saved and restored around it.
        rng_state = torch.random.get_rng_state()
        if seed is not None:
            torch.manual_seed(seed)
        mask = (torch.rand(base.shape) <= percent).to(base.device)
        torch.random.set_rng_state(rng_state)
        return torch.where(mask, x, base)

    # On CPU, cast to bfloat16 up front to halve memory use while merging.
    if x.device.type == "cpu":
        x = x.to(torch.bfloat16)
        base = base.to(torch.bfloat16)

    diagonal_offset = parameters.get('diagonal_offset')
    random_mask = parameters.get('random_mask')
    random_mask_seed = parameters.get('random_mask_seed')
    random_mask_seed = int(random_mask_seed) if random_mask_seed is not None else None

    assert (diagonal_offset is not None) and (diagonal_offset % 1 == 0) and (diagonal_offset >= 2), \
        "The diagonal_offset must be an integer greater than or equal to 2."

    if random_mask:
        # Random-mask mode: take a random fraction of x instead of a fixed pattern.
        assert 0.0 < random_mask < 1.0, \
            "The random_mask parameter must be a number strictly between 0 and 1."
        assert random_mask_seed is None or isinstance(random_mask_seed, int), \
            "The random_mask_seed parameter must be None (random seed) or an integer."
        x = rand_mask(base, x, random_mask, random_mask_seed)
    else:
        # Checkerboard mode; invert_offset flips which model fills the pattern.
        if parameters.get('invert_offset'):
            x = swap_values(x.shape, diagonal_offset, x, base)
        else:
            x = swap_values(x.shape, diagonal_offset, base, x)

    del base
    return x

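# Hypothetical quick check of the default checkerboard path (not in the
# original script): swapping_method(torch.zeros(2, 2), torch.ones(2, 2),
# {'diagonal_offset': 2}) returns [[1, 0], [0, 1]] in bfloat16 on CPU,
# i.e. half of the values from each input.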
# Settings used for this checkpoint: a plain checkerboard, no random mask.
parameters = {
    'diagonal_offset': 2,
    'random_mask': False,
    'invert_offset': False,
    # 'random_mask_seed': "899557"
}
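# Alternative settings (hypothetical, not what was used for this checkpoint):
# instead of the checkerboard, take a random ~30% of schnell's values, with a
# fixed seed so the result is reproducible:
# parameters = {
#     'diagonal_offset': 2,   # still validated by the assert above
#     'random_mask': 0.3,
#     'invert_offset': False,
#     'random_mask_seed': 899557,
# }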
merged_state_dict = {}
guidance_state_dict = {}  # guidance tensors exist only in dev and are kept as-is

# Merge shard by shard so only one pair of shards sits in memory at a time.
for i in range(len(dev_shards)):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shards[i])
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shards[i])

    keys = list(state_dict_dev_temp.keys())
    for k in keys:
        if "guidance" not in k:
            # Swap ~50% of dev's values for schnell's (checkerboard by default).
            merged_state_dict[k] = swapping_method(state_dict_dev_temp.pop(k), state_dict_schnell_temp.pop(k), parameters)
        else:
            # schnell has no guidance embedder, so dev's guidance weights are kept.
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    if len(state_dict_dev_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
    if len(state_dict_schnell_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

    del state_dict_dev_temp, state_dict_schnell_temp
    gc.collect()  # release the per-shard tensors before loading the next pair

# Restore dev's guidance weights and materialize the meta-device model.
merged_state_dict.update(guidance_state_dict)
load_model_dict_into_meta(model, merged_state_dict)

model.to(torch.bfloat16).save_pretrained("merged-flux")
```
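
To try the merge, the saved transformer can be dropped into a standard Flux pipeline. A minimal sketch, assuming the script above was run and `merged-flux` is in the working directory (the prompt and step count are just examples):

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained("merged-flux", torch_dtype=torch.bfloat16)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # supplies the text encoders, VAE, and scheduler
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # optional: trades speed for lower VRAM use

image = pipe(
    "a watercolor fox in a snowy forest",
    num_inference_steps=20,  # untested guess; the schnell half may tolerate fewer steps
    guidance_scale=3.5,
).images[0]
image.save("flux_swap_sample.png")
```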

A piece of this code was adapted from mergekit.

Thanks to SayakPaul for the code that helped me make this merge.