---
license: mit
library_name: diffusers
---
# flux-uncensored-nf4
## Summary
Flux base model merged with an uncensored LoRA and quantized to NF4. This model is not for those looking for "safe" or watered-down outputs. It’s optimized for real-world use with fewer constraints and lower VRAM requirements, thanks to NF4 quantization.
## Specs
* Model: Flux base
* LoRA: Uncensored version, merged directly
* Quantization: NF4 format for speed and VRAM efficiency (see the short sketch below)
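
For intuition about what NF4 buys you, here is a minimal, illustrative sketch (assuming `bitsandbytes` and a CUDA GPU; this is not the exact pipeline used to quantize this checkpoint): a standard linear layer is swapped for a bitsandbytes `Linear4bit` layer with `quant_type="nf4"`, and its weights get packed into 4-bit blocks when the layer is moved to the GPU.

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

# Illustrative only: a full-precision linear layer vs. its NF4-quantized counterpart.
fp_linear = nn.Linear(64, 64, bias=False)
nf4_linear = bnb.nn.Linear4bit(64, 64, bias=False, compute_dtype=torch.bfloat16, quant_type="nf4")

# Copy the full-precision weights; quantization happens on the move to the GPU.
nf4_linear.load_state_dict(fp_linear.state_dict())
nf4_linear = nf4_linear.to("cuda")

x = torch.randn(1, 64, dtype=torch.bfloat16, device="cuda")
print(nf4_linear(x).shape)  # torch.Size([1, 64]), computed in bf16 from 4-bit weights
```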
## Usage
Not quite a plug-and-play model, but still pretty straightforward (script adapted from sayak: https://github.com/huggingface/diffusers/issues/9165#issue-2462431761).
Install bitsandbytes first: `pip install -U bitsandbytes`.
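
As a quick, optional sanity check before running the full script (not part of the original script; it assumes a CUDA GPU, since the script places quantized parameters on device 0):

```python
# Optional environment check (illustrative, not part of the original script).
import torch
import bitsandbytes as bnb

print(bnb.__version__)            # bitsandbytes is installed and importable
print(torch.cuda.is_available())  # a CUDA device is needed for device=0 in the script below
```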
```python
"""
Some bits are from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py
"""
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
from diffusers import FluxTransformer2DModel, FluxPipeline
import safetensors.torch
import gc
import torch
dtype = torch.bfloat16
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
ckpt_path = hf_hub_download("shauray/flux.1-dev-uncensored-nf4", filename="diffusion_pytorch_model.safetensors")
original_state_dict = safetensors.torch.load_file(ckpt_path)
with init_empty_weights():
config = FluxTransformer2DModel.load_config("shauray/flux.1-dev-uncensored-nf4")
model = FluxTransformer2DModel.from_config(config).to(dtype)
expected_state_dict_keys = list(model.state_dict().keys())
_replace_with_bnb_linear(model, "nf4")
for param_name, param in original_state_dict.items():
if param_name not in expected_state_dict_keys:
continue
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
param = param.to(dtype)
if not check_quantized_param(model, param_name):
set_module_tensor_to_device(model, param_name, device=0, value=param)
else:
create_quantized_param(
model, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True
)
del original_state_dict
gc.collect()
print(compute_module_sizes(model)[""] / 1024 / 1204)
pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
prompt = "A mystic cat with a sign that says hello world!"
image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
image.save("flux-nf4-dev-loaded.png")
```
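
If you want a rough idea of how much VRAM the run actually used, the snippet below (a small addition, assuming the `pipe` from the script above has just finished generating on a CUDA device) reads PyTorch's peak-allocation counter:

```python
import torch

# Rough peak-VRAM readout after generation (illustrative, not part of the original script).
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak VRAM allocated: {peak_gb:.2f} GB")
```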
This README has everything you need. The model is a merge with the [Uncensored LoRA on CivitAI](https://civitai.com/models/875879/flux-lustlyai-uncensored-v1-nsfw-lora-with-male-and-female-nudity).