# Safetensors
# FLUX.1-dev-fp8-flumina / lora_loading.py
# Last commit by aredden (56c313c): Fix issue where cublas linear not installed causing TypeError
import re
from typing import Optional, OrderedDict, Tuple, TypeAlias, Union
import torch
from loguru import logger
from safetensors.torch import load_file
from tqdm import tqdm
from torch import nn
try:
    from cublas_ops import CublasLinear
except Exception:
    # cublas_ops is an optional extension. Binding the name to NoneType when it is
    # missing (or fails to import) keeps the annotations and isinstance() checks
    # that mention CublasLinear valid; no layer is ever an instance of NoneType.
    CublasLinear = type(None)
from float8_quantize import F8Linear
from modules.flux_model import Flux
path_regex = re.compile(r"/|\\")
StateDict: TypeAlias = OrderedDict[str, torch.Tensor]
class LoraWeights:
    """Container pairing a LoRA state dict with its source path, a name and a scale.

    Args:
        weights: LoRA tensors keyed by parameter name.
        path: Path (or identifier) the weights came from.
        name: Display name. When falsy, the last component of ``path``
            (split on "/" or "\\") is used instead.
        scale: Scale stored alongside the weights.
    """

    def __init__(
        self,
        weights: StateDict,
        path: str,
        name: Optional[str] = None,
        scale: float = 1.0,
    ) -> None:
        self.path = path
        self.weights = weights
        # Fall back to the file name portion of the path for display purposes.
        self.name = name if name else path_regex.split(path)[-1]
        self.scale = scale

    def __repr__(self) -> str:
        # Used when instances are interpolated into log messages; deliberately
        # omits the tensors themselves.
        return (
            f"{type(self).__name__}(name={self.name!r}, path={self.path!r}, "
            f"scale={self.scale!r})"
        )
def swap_scale_shift(weight):
    """Swap the two halves of ``weight`` along dim 0.

    ``weight`` is split with ``chunk(2, dim=0)`` into a leading "scale" part and a
    trailing "shift" part, and the result is ``cat([shift, scale], dim=0)``.
    """
    scale_part, shift_part = weight.chunk(2, dim=0)
    return torch.cat((shift_part, scale_part), dim=0)
def check_if_lora_exists(state_dict, lora_name):
    """Return the module sub-key of ``lora_name`` if any key in ``state_dict`` contains it.

    The sub-key is ``lora_name`` truncated at the first ".lora_A", then ".lora_B",
    then ".weight". Matching is a plain substring test against every key.
    Returns ``False`` when no key matches.
    """
    subkey = lora_name
    for marker in (".lora_A", ".lora_B", ".weight"):
        subkey = subkey.split(marker)[0]
    if any(subkey in existing for existing in state_dict):
        return subkey
    return False
def convert_if_lora_exists(new_state_dict, state_dict, lora_name, flux_layer_name):
    """Move every entry belonging to ``lora_name`` into ``new_state_dict`` under a Flux name.

    The module sub-key is found with ``check_if_lora_exists``. Each key of
    ``state_dict`` containing it is popped, the sub-key is replaced by
    ``flux_layer_name`` (minus any ".weight"), and the entry is stored in
    ``new_state_dict``. Both dicts are mutated in place and returned.
    """
    subkey = check_if_lora_exists(state_dict, lora_name)
    if subkey is False:
        return new_state_dict, state_dict
    flux_subkey = flux_layer_name.replace(".weight", "")
    # Snapshot the matching keys first: popping while iterating the dict is unsafe.
    matching = [existing for existing in state_dict.keys() if subkey in existing]
    for existing in matching:
        new_state_dict[existing.replace(subkey, flux_subkey)] = state_dict.pop(existing)
    return new_state_dict, state_dict
def _pop_lora_pair(state_dict, base_key):
    """Pop ``base_key``'s lora_A/lora_B weights, or return None if either is missing.

    When only one half is present nothing is popped, so that lone half is still
    reported as an unexpected key at the end of the conversion.
    """
    a_key = f"{base_key}.lora_A.weight"
    b_key = f"{base_key}.lora_B.weight"
    if a_key in state_dict and b_key in state_dict:
        return state_dict.pop(a_key), state_dict.pop(b_key)
    return None


def _concat_lora_pairs(pairs, zero_b_rows=None):
    """Concatenate (lora_A, lora_B) pairs along dim 0 into one fused pair.

    Missing entries (``None``) become zero tensors with the dtype/device of the
    first present pair, so they contribute nothing. ``zero_b_rows[j]`` sets the
    row count of the zero lora_B used for missing entry ``j``. Without it, the
    first present lora_B's shape is reused. Returns None if every entry is missing.
    """
    reference = next((pair for pair in pairs if pair is not None), None)
    if reference is None:
        return None
    ref_a, ref_b = reference
    parts_a, parts_b = [], []
    for j, pair in enumerate(pairs):
        if pair is not None:
            lora_a, lora_b = pair
        else:
            lora_a = torch.zeros_like(ref_a)
            if zero_b_rows is None:
                lora_b = torch.zeros_like(ref_b)
            else:
                lora_b = ref_b.new_zeros((zero_b_rows[j], *ref_b.shape[1:]))
        parts_a.append(lora_a)
        parts_b.append(lora_b)
    return torch.cat(parts_a, dim=0), torch.cat(parts_b, dim=0)


def _convert_layers(original_state_dict, diffusers_state_dict, mapping):
    """Apply ``convert_if_lora_exists`` for each (diffusers_name, flux_name), in order."""
    for diffusers_name, flux_name in mapping:
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict, diffusers_state_dict, diffusers_name, flux_name
        )
    return original_state_dict, diffusers_state_dict


def _convert_double_block_attention(
    original_state_dict, diffusers_state_dict, prefix, block_prefix, i
):
    """Fuse a double block's separate Q/K/V LoRAs into the img/txt ``qkv`` entries."""
    found = {}
    for component in ("to_q", "to_k", "to_v", "add_q_proj", "add_k_proj", "add_v_proj"):
        base_key = f"{prefix}{block_prefix}attn.{component}"
        found[component] = _pop_lora_pair(diffusers_state_dict, base_key)
        if found[component] is None:
            logger.info(
                f"Skipping layer {i} since no LoRA weight is available for {base_key}.lora_A.weight"
            )
    groups = (
        (("to_q", "to_k", "to_v"), f"double_blocks.{i}.img_attn.qkv"),
        (("add_q_proj", "add_k_proj", "add_v_proj"), f"double_blocks.{i}.txt_attn.qkv"),
    )
    for components, target in groups:
        fused = _concat_lora_pairs([found[c] for c in components])
        if fused is not None:
            original_state_dict[f"{target}.lora_A.weight"] = fused[0]
            original_state_dict[f"{target}.lora_B.weight"] = fused[1]


def _convert_single_block_linear1(
    original_state_dict, diffusers_state_dict, prefix, block_prefix, i, mlp_ratio
):
    """Fuse a single block's Q/K/V and proj_mlp LoRAs into ``single_blocks.{i}.linear1``."""
    pairs = [
        _pop_lora_pair(diffusers_state_dict, f"{prefix}{block_prefix}attn.to_q"),
        _pop_lora_pair(diffusers_state_dict, f"{prefix}{block_prefix}attn.to_k"),
        _pop_lora_pair(diffusers_state_dict, f"{prefix}{block_prefix}attn.to_v"),
        _pop_lora_pair(diffusers_state_dict, f"{prefix}{block_prefix}proj_mlp"),
    ]
    reference = next((pair for pair in pairs if pair is not None), None)
    if reference is None:
        return
    # lora_A is (rank, in_features) and all four projections read the same input,
    # so in_features is the hidden size. Assumes (as in Flux) that Q/K/V map
    # hidden -> hidden and the MLP maps hidden -> hidden * mlp_ratio. These row
    # counts only size zero padding for components missing from the LoRA.
    hidden_size = reference[0].shape[1]
    zero_b_rows = [hidden_size] * 3 + [int(hidden_size * mlp_ratio)]
    lora_a, lora_b = _concat_lora_pairs(pairs, zero_b_rows)
    original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = lora_a
    original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = lora_b


def convert_diffusers_to_flux_transformer_checkpoint(
    diffusers_state_dict,
    num_layers,
    num_single_layers,
    has_guidance=True,
    prefix="",
    mlp_ratio=4.0,
):
    """Rename a diffusers-format Flux LoRA state dict to the Flux module layout.

    Converted entries are popped from ``diffusers_state_dict``, which is consumed
    in place. Anything left over at the end is logged as unexpected.

    Args:
        diffusers_state_dict: LoRA tensors keyed by diffusers module names.
        num_layers: Number of double blocks (``transformer_blocks``).
        num_single_layers: Number of ``single_transformer_blocks``.
        has_guidance: Whether to convert the guidance embedder.
        prefix: Prefix carried by every diffusers key (e.g. ``"transformer."``).
        mlp_ratio: Single-block MLP width divided by hidden size. Only used to
            size zero padding when ``proj_mlp`` has no LoRA but Q/K/V do. 4.0 is
            assumed to match Flux; TODO confirm for other variants.

    Returns:
        A new dict keyed by Flux module names, e.g.
        ``double_blocks.0.img_attn.qkv.lora_A.weight``.
    """
    original_state_dict = {}

    # time_text_embed.* -> time_in / vector_in / guidance_in,
    # context_embedder -> txt_in, x_embedder -> img_in.
    embedder_mapping = [
        (f"{prefix}time_text_embed.timestep_embedder.linear_1.weight", "time_in.in_layer.weight"),
        (f"{prefix}time_text_embed.timestep_embedder.linear_2.weight", "time_in.out_layer.weight"),
        (f"{prefix}time_text_embed.text_embedder.linear_1.weight", "vector_in.in_layer.weight"),
        (f"{prefix}time_text_embed.text_embedder.linear_2.weight", "vector_in.out_layer.weight"),
    ]
    if has_guidance:
        embedder_mapping += [
            (f"{prefix}time_text_embed.guidance_embedder.linear_1.weight", "guidance_in.in_layer.weight"),
            (f"{prefix}time_text_embed.guidance_embedder.linear_2.weight", "guidance_in.out_layer.weight"),
        ]
    embedder_mapping += [
        (f"{prefix}context_embedder.weight", "txt_in.weight"),
        (f"{prefix}x_embedder.weight", "img_in.weight"),
    ]
    original_state_dict, diffusers_state_dict = _convert_layers(
        original_state_dict, diffusers_state_dict, embedder_mapping
    )

    # Double transformer blocks.
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        src = f"{prefix}{block_prefix}"
        dst = f"double_blocks.{i}."
        original_state_dict, diffusers_state_dict = _convert_layers(
            original_state_dict,
            diffusers_state_dict,
            [
                (f"{src}norm1.linear.weight", f"{dst}img_mod.lin.weight"),
                (f"{src}norm1_context.linear.weight", f"{dst}txt_mod.lin.weight"),
            ],
        )
        _convert_double_block_attention(
            original_state_dict, diffusers_state_dict, prefix, block_prefix, i
        )
        original_state_dict, diffusers_state_dict = _convert_layers(
            original_state_dict,
            diffusers_state_dict,
            [
                # qk_norm
                (f"{src}attn.norm_q.weight", f"{dst}img_attn.norm.query_norm.scale"),
                (f"{src}attn.norm_k.weight", f"{dst}img_attn.norm.key_norm.scale"),
                (f"{src}attn.norm_added_q.weight", f"{dst}txt_attn.norm.query_norm.scale"),
                (f"{src}attn.norm_added_k.weight", f"{dst}txt_attn.norm.key_norm.scale"),
                # feed-forward
                (f"{src}ff.net.0.proj.weight", f"{dst}img_mlp.0.weight"),
                (f"{src}ff.net.2.weight", f"{dst}img_mlp.2.weight"),
                (f"{src}ff_context.net.0.proj.weight", f"{dst}txt_mlp.0.weight"),
                (f"{src}ff_context.net.2.weight", f"{dst}txt_mlp.2.weight"),
                # output projections
                (f"{src}attn.to_out.0.weight", f"{dst}img_attn.proj.weight"),
                (f"{src}attn.to_add_out.weight", f"{dst}txt_attn.proj.weight"),
            ],
        )

    # Single transformer blocks.
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        src = f"{prefix}{block_prefix}"
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{src}norm.linear.weight",
            f"single_blocks.{i}.modulation.lin.weight",
        )
        _convert_single_block_linear1(
            original_state_dict, diffusers_state_dict, prefix, block_prefix, i, mlp_ratio
        )
        original_state_dict, diffusers_state_dict = convert_if_lora_exists(
            original_state_dict,
            diffusers_state_dict,
            f"{src}proj_out.weight",
            f"single_blocks.{i}.linear2.weight",
        )

    # Final layer. NOTE(review): swap_scale_shift exists in this module but is not
    # applied to norm_out.linear's LoRA. If diffusers and Flux order that layer's
    # (scale, shift) outputs differently, lora_B's rows would need swapping; confirm.
    original_state_dict, diffusers_state_dict = _convert_layers(
        original_state_dict,
        diffusers_state_dict,
        [
            (f"{prefix}proj_out.weight", "final_layer.linear.weight"),
            (f"{prefix}proj_out.bias", "final_layer.linear.bias"),
            (f"{prefix}norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight"),
        ],
    )

    if diffusers_state_dict:
        # loguru formats positional args into "{}" placeholders.
        logger.warning("Unexpected keys: {}", list(diffusers_state_dict.keys()))
    return original_state_dict
def convert_from_original_flux_checkpoint(original_state_dict: StateDict) -> StateDict:
    """
    Rename LoRA keys from the original (underscore-separated) Flux format to dotted module paths.

    Only keys containing "lora" are kept. Each kept key has the rewrites below
    applied in order, e.g. "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight"
    becomes "double_blocks.0.img_attn.qkv.lora_B.weight".

    NOTE(review): keys for layers the table does not cover keep their
    underscores after "lora_unet_" is stripped. Confirm such keys are not expected.

    Args:
        original_state_dict (Dict[str, torch.Tensor]): The original Flux LoRA state dict.
    Returns:
        Dict[str, torch.Tensor]: A new dict with renamed keys and the same tensors.
    """
    rewrites = (
        ("lora_unet_", ""),
        ("double_blocks_", "double_blocks."),
        ("single_blocks_", "single_blocks."),
        ("_img_attn_", ".img_attn."),
        ("_txt_attn_", ".txt_attn."),
        ("_img_mod_", ".img_mod."),
        ("_txt_mod_", ".txt_mod."),
        ("_img_mlp_", ".img_mlp."),
        ("_txt_mlp_", ".txt_mlp."),
        ("_linear1", ".linear1"),
        ("_linear2", ".linear2"),
        ("_modulation_", ".modulation."),
        ("lora_up", "lora_B"),
        ("lora_down", "lora_A"),
    )
    converted = {}
    for key, tensor in original_state_dict.items():
        if "lora" not in key:
            continue
        new_key = key
        for old, new in rewrites:
            new_key = new_key.replace(old, new)
        converted[new_key] = tensor
    return converted
def get_module_for_key(
    key: str, model: Flux
) -> F8Linear | torch.nn.Linear | CublasLinear:
    """Resolve a dotted module path such as ``"double_blocks.0.img_attn.qkv"`` on ``model``.

    Each dot-separated segment is looked up with ``getattr`` starting from ``model``.

    Raises:
        AttributeError: if a segment does not exist. The message names the full
            key and the failing segment (e.g. a LoRA key that was never mapped to
            a Flux module name).
    """
    module = model
    for part in key.split("."):
        try:
            module = getattr(module, part)
        except AttributeError as err:
            raise AttributeError(
                f"Cannot resolve LoRA key {key!r}: "
                f"{type(module).__name__} has no attribute {part!r}"
            ) from err
    return module
def get_lora_for_key(
    key: str, lora_weights: dict
) -> Optional[Tuple[torch.Tensor, torch.Tensor, Optional[float]]]:
    """Look up the LoRA matrices (and optional alpha) for a module key.

    Everything from the first ".lora" in ``key`` onward is dropped to get the
    module path. Then ``<path>.lora_A.weight``, ``<path>.lora_B.weight`` and
    ``<path>.alpha`` are read from ``lora_weights``.

    Returns:
        ``(lora_A, lora_B, alpha)``, with ``alpha`` set to None when absent, or
        None if either matrix is missing.
    """
    module_path = key.partition(".lora")[0]
    lora_A = lora_weights.get(module_path + ".lora_A.weight")
    lora_B = lora_weights.get(module_path + ".lora_B.weight")
    alpha = lora_weights.get(module_path + ".alpha")
    if lora_A is not None and lora_B is not None:
        return lora_A, lora_B, alpha
    return None
# An identical second definition of get_module_for_key stood here and silently
# rebound the name defined earlier in this module; it was removed as redundant.
def calculate_lora_weight(
    lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float]],
    rank: Optional[int] = None,
    lora_scale: float = 1.0,
    device: Optional[Union[torch.device, int, str]] = None,
):
    """Build the dense float32 LoRA delta ``lora_scale * lora_B @ lora_A``.

    Args:
        lora_weights: ``(lora_A, lora_B, alpha)``. ``alpha`` may be None.
        rank: Rank used for ``alpha / rank`` scaling. Defaults to ``lora_B.shape[1]``.
        lora_scale: Extra multiplier applied to the product.
        device: Device for the computation. Defaults to ``lora_A``'s device.

    If ``alpha`` (which defaults to ``rank``) differs from ``rank``, ``lora_A`` is
    scaled by ``alpha / rank``. If ``lora_A`` has more rows than ``lora_B`` has
    columns, ``lora_A`` is split into ``int(lora_A.shape[0] / lora_B.shape[1])``
    row chunks and ``lora_B @ chunk`` is summed over the chunks. That case arises
    for the fused entries built by concatenating several LoRAs.
    """
    mat_a, mat_b, alpha = lora_weights
    target_device = mat_a.device if device is None else device
    a_rows = mat_a.shape[0]
    b_cols = mat_b.shape[1]
    is_even = a_rows == b_cols
    num_chunks = int(a_rows / b_cols)
    if rank is None:
        rank = b_cols
    if alpha is None:
        alpha = rank
    a32 = mat_a.to(dtype=torch.float32, device=target_device)
    b32 = mat_b.to(dtype=torch.float32, device=target_device)
    if alpha != rank:
        a32 = a32 * alpha / rank
    if is_even:
        return lora_scale * torch.mm(b32, a32)
    # NOTE(review): summing B @ A_j over chunks equals B @ (A_1 + ... + A_k), so
    # every row block of B sees every chunk of A. When each chunk of A belongs to
    # its own row block of B (as with concatenated q/k/v LoRAs), the exact product
    # would be block-diagonal (B_j @ A_j). Confirm this is intended.
    fused = torch.zeros(
        (mat_b.shape[0], mat_a.shape[1]), device=target_device, dtype=torch.float32
    )
    for chunk in a32.chunk(num_chunks, dim=0):
        fused = fused + (lora_scale * torch.mm(b32, chunk))
    return fused
@torch.inference_mode()
def unfuse_lora_weight_from_module(
    fused_weight: torch.Tensor,
    lora_weights: dict,
    rank: Optional[int] = None,
    lora_scale: float = 1.0,
):
    """Subtract a LoRA delta from a weight that previously had it fused in.

    The subtraction runs in float32 on ``fused_weight``'s device. The result is
    cast back to ``fused_weight``'s original dtype.

    Args:
        fused_weight: Weight tensor that contains the fused LoRA.
        lora_weights: ``(lora_A, lora_B, alpha)`` as unpacked by ``calculate_lora_weight``.
        rank: Forwarded to ``calculate_lora_weight``.
        lora_scale: Forwarded to ``calculate_lora_weight``.
    """
    source_dtype = fused_weight.dtype
    target_device = fused_weight.device
    as_float = fused_weight.to(dtype=torch.float32, device=target_device)
    delta = calculate_lora_weight(lora_weights, rank, lora_scale, device=target_device)
    restored = as_float - delta
    return restored.to(dtype=source_dtype, device=target_device)
@torch.inference_mode()
def apply_lora_weight_to_module(
    module_weight: torch.Tensor,
    lora_weights: dict,
    rank: int = None,
    lora_scale: float = 1.0,
):
    """Add a LoRA delta to ``module_weight`` and return the merged weight.

    The addition runs in float32. The result is cast back to ``module_weight``'s
    dtype on its device.

    Args:
        module_weight: Base weight tensor.
        lora_weights: ``(lora_A, lora_B, alpha)`` as unpacked by ``calculate_lora_weight``.
        rank: Forwarded to ``calculate_lora_weight``.
        lora_scale: Forwarded to ``calculate_lora_weight``.
    """
    source_dtype = module_weight.dtype
    target_device = module_weight.device
    delta = calculate_lora_weight(lora_weights, rank, lora_scale, device=target_device)
    merged = module_weight.to(dtype=torch.float32) + delta
    return merged.to(dtype=source_dtype, device=target_device)
def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
    """Convert a freshly loaded LoRA state dict to Flux names and list its module keys.

    Dicts with any key starting with "transformer." go through the diffusers
    converter (19 double blocks, 38 single blocks). All others go through
    ``convert_from_original_flux_checkpoint``.

    Returns:
        ``(keys_without_ab, lora_weights)``. ``keys_without_ab`` holds the unique
        keys with the lora_A/lora_B/alpha suffixes stripped.
    """
    if any(k.startswith("transformer.") for k in lora_weights.keys()):
        lora_weights = convert_diffusers_to_flux_transformer_checkpoint(
            lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer."
        )
    else:
        lora_weights = convert_from_original_flux_checkpoint(lora_weights)
    logger.info("LoRA weights loaded")
    logger.debug("Extracting keys")
    # Stripped in this order so ".lora_A.weight" is removed before bare ".lora_A".
    suffixes = (".lora_A.weight", ".lora_B.weight", ".lora_A", ".lora_B", ".alpha")
    stripped = []
    for raw_key in lora_weights.keys():
        for suffix in suffixes:
            raw_key = raw_key.replace(suffix, "")
        stripped.append(raw_key)
    keys_without_ab = list(set(stripped))
    logger.debug("Keys extracted")
    return keys_without_ab, lora_weights
def get_lora_weights(lora_path: str | StateDict):
    """Return ``(weights, already_loaded)`` for a LoRA path or an in-memory LoRA.

    A ``dict`` or ``LoraWeights`` is passed through unchanged with
    ``already_loaded=True``. Anything else is loaded as a safetensors file onto
    the CPU with ``already_loaded=False``.
    """
    if not isinstance(lora_path, (dict, LoraWeights)):
        return load_file(lora_path, "cpu"), False
    return lora_path, True
def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear]):
    """Return a detached float32 copy of a linear layer's weight.

    Returns:
        ``(weight, weight_is_f8, dtype)``. ``dtype`` is the dtype of
        ``linear.weight`` so callers can cast the updated weight back.
        ``weight_is_f8`` is True only for ``F8Linear``.

    Raises:
        TypeError: if ``linear`` is not an ``F8Linear``, a ``torch.nn.Linear``, or
            (when ``cublas_ops`` imported successfully) a ``CublasLinear``.
    """
    if isinstance(linear, F8Linear):
        dtype = linear.weight.dtype
        # Rebuild a float32 weight from the float8 payload by multiplying with
        # scale_reciprocal (presumably the inverse quantization scale), placed on
        # linear.weight's device.
        weight = (
            linear.float8_data.clone()
            .detach()
            .float()
            .mul(linear.scale_reciprocal)
            .to(linear.weight.device)
        )
        return weight, True, dtype
    # CublasLinear is NoneType when cublas_ops is unavailable; never match on it then.
    is_cublas = CublasLinear is not type(None) and isinstance(linear, CublasLinear)
    if isinstance(linear, torch.nn.Linear) or is_cublas:
        dtype = linear.weight.dtype
        return linear.weight.clone().detach().float(), False, dtype
    raise TypeError(
        f"Cannot extract a LoRA target weight from {type(linear).__name__}; "
        "expected F8Linear, torch.nn.Linear or CublasLinear"
    )
@torch.inference_mode()
def apply_lora_to_model(
    model: Flux,
    lora_path: str | StateDict,
    lora_scale: float = 1.0,
    return_lora_resolved: bool = False,
) -> Flux:
    """Fuse a LoRA into ``model``'s linear weights in place.

    ``lora_path`` may be a safetensors path (loaded and converted to Flux names),
    a dict used as-is, or a ``LoraWeights`` whose ``weights`` are used as-is.
    Returns ``model``, or ``(model, lora_weights)`` when ``return_lora_resolved``.

    NOTE(review): unlike remove_lora_from_module, a LoraWeights argument's
    ``scale`` is not used here; ``lora_scale`` always applies. Confirm callers
    keep the two in sync, otherwise unfusing will not exactly undo this.
    """
    has_guidance = model.params.guidance_embed
    logger.info(f"Loading LoRA weights for {lora_path}")
    lora_weights, already_loaded = get_lora_weights(lora_path)
    if already_loaded:
        if isinstance(lora_weights, LoraWeights):
            lora_weights = lora_weights.weights
        suffixes = (".lora_A.weight", ".lora_B.weight", ".lora_A", ".lora_B", ".alpha")
        stripped = []
        for raw_key in lora_weights.keys():
            for suffix in suffixes:
                raw_key = raw_key.replace(suffix, "")
            stripped.append(raw_key)
        keys_without_ab = list(set(stripped))
    else:
        keys_without_ab, lora_weights = resolve_lora_state_dict(
            lora_weights, has_guidance
        )

    for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
        module = get_module_for_key(key, model)
        weight, is_f8, dtype = extract_weight_from_linear(module)
        lora_sd = get_lora_for_key(key, lora_weights)
        if lora_sd is None:
            # No complete lora_A/lora_B pair for this module.
            continue
        merged = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale).type(dtype)
        if is_f8:
            module.set_weight_tensor(merged)
        else:
            module.weight.data = merged
    logger.success("Lora applied")
    if return_lora_resolved:
        return model, lora_weights
    return model
def remove_lora_from_module(
    model: Flux,
    lora_path: str | StateDict,
    lora_scale: float = 1.0,
):
    """Unfuse a previously applied LoRA from ``model``'s linear weights in place.

    ``lora_path`` may be a safetensors path (loaded and converted to Flux names),
    a dict used as-is, or a ``LoraWeights``. For ``LoraWeights``, its ``scale``
    replaces the ``lora_scale`` argument. Returns ``model``.
    """
    has_guidance = model.params.guidance_embed
    logger.info(f"Loading LoRA weights for {lora_path}")
    lora_weights, already_loaded = get_lora_weights(lora_path)
    if already_loaded:
        if isinstance(lora_weights, LoraWeights):
            lora_scale = lora_weights.scale
            lora_weights = lora_weights.weights
        suffixes = (".lora_A.weight", ".lora_B.weight", ".lora_A", ".lora_B", ".alpha")
        stripped = []
        for raw_key in lora_weights.keys():
            for suffix in suffixes:
                raw_key = raw_key.replace(suffix, "")
            stripped.append(raw_key)
        keys_without_ab = list(set(stripped))
    else:
        keys_without_ab, lora_weights = resolve_lora_state_dict(
            lora_weights, has_guidance
        )

    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
        module = get_module_for_key(key, model)
        weight, is_f8, dtype = extract_weight_from_linear(module)
        lora_sd = get_lora_for_key(key, lora_weights)
        if lora_sd is None:
            # No complete lora_A/lora_B pair for this module.
            continue
        restored = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale).type(dtype)
        if is_f8:
            module.set_weight_tensor(restored)
        else:
            module.weight.data = restored
    logger.success("Lora unfused")
    return model