File size: 3,938 Bytes
3304f7d
ddc8a59
3304f7d
 
ddc8a59
 
 
b64b5e6
ddc8a59
 
 
0554219
ddc8a59
 
2b2693f
 
 
 
 
 
 
ddc8a59
 
 
b64b5e6
ddc8a59
 
3304f7d
ddc8a59
 
3304f7d
ddc8a59
3304f7d
6c04d23
33871e9
 
 
ddc8a59
 
 
a1008bd
 
 
 
 
 
3304f7d
 
 
a1008bd
3304f7d
ddc8a59
 
 
 
3304f7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr

from convert import run_conversion
from hub_utils import push_to_hub, save_model_card

# Base checkpoint the conversion targets (the only one supported, as the
# description below states); ``run`` passes it as ``base_model`` when it
# writes the model card.
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"

# Markdown rendered at the top of the Gradio UI (passed as ``description``).
DESCRIPTION = """
This Space lets you convert KerasCV Stable Diffusion weights to a format compatible with [Diffusers](https://github.com/huggingface/diffusers) 🧨. This allows users to fine-tune using KerasCV and use the fine-tuned weights in Diffusers, taking advantage of its nifty features (like [schedulers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers), [fast attention](https://huggingface.co/docs/diffusers/optimization/fp16), etc.). Specifically, the Keras weights are first converted to PyTorch and then they are wrapped into a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview). This pipeline is then pushed to the Hugging Face Hub, provided you have supplied `your_hf_token`.

## Notes (important)

* The Space downloads a couple of pre-trained weights and runs a dummy inference. Depending on the machine type, the entire process can take anywhere between 2 and 5 minutes.
* Only Stable Diffusion (v1) is supported as of now. In particular, this checkpoint: [`"CompVis/stable-diffusion-v1-4"`](https://huggingface.co/CompVis/stable-diffusion-v1-4).
* [This Colab Notebook](https://colab.research.google.com/drive/1RYY077IQbAJldg8FkK8HSEpNILKHEwLb?usp=sharing) was used to develop the conversion utilities initially.
* Whether you need to provide `text_encoder_weights`, `unet_weights`, or both depends on the fine-tuning task. Here are some _typical_ scenarios:
    
    * [DreamBooth](https://dreambooth.github.io/): Both text encoder and UNet
    * [Textual Inversion](https://textual-inversion.github.io/): Text encoder
    * [Traditional text2image fine-tuning](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image): UNet
    
    **If neither `text_encoder_weights` nor `unet_weights` is provided, nothing will be done.**
* When providing the weights' links, ensure they're directly downloadable. Internally, the Space uses [`tf.keras.utils.get_file()`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/get_file) to retrieve the weights locally.
* If you don't provide `your_hf_token`, the converted pipeline won't be pushed.

Check [here](https://github.com/huggingface/diffusers/blob/31be42209ddfdb69d9640a777b32e9b5c6259bf0/examples/dreambooth/train_dreambooth_lora.py#L975) for an example of how you can change the scheduler of an already initialized `StableDiffusionPipeline`.
"""


def run(hf_token, text_encoder_weights, unet_weights, repo_prefix):
    if text_encoder_weights == "":
        text_encoder_weights = None
    if unet_weights == "":
        unet_weights = None

    if text_encoder_weights is None and unet_weights is None:
        return "❌ No fine-tuned weights provided, nothing to do."

    pipeline = run_conversion(text_encoder_weights, unet_weights)
    output_path = "kerascv_sd_diffusers_pipeline"
    pipeline.save_pretrained(output_path)

    weight_paths = []
    if text_encoder_weights is not None:
        weight_paths.append(text_encoder_weights)
    if unet_weights is not None:
        weight_paths.append(unet_weights)
    save_model_card(
        base_model=PRETRAINED_CKPT,
        repo_folder=output_path,
        weight_paths=weight_paths,
    )
    push_str = push_to_hub(hf_token, output_path, repo_prefix)
    return push_str


# Single-page UI. Gradio hands the textbox values to ``run`` positionally,
# so this label order must mirror run's parameters:
# (hf_token, text_encoder_weights, unet_weights, repo_prefix).
demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Text(max_lines=1, label=field)
        for field in (
            "your_hf_token",
            "text_encoder_weights",
            "unet_weights",
            "output_repo_prefix",
        )
    ],
    outputs=[gr.Markdown(label="output")],
    title="KerasCV Stable Diffusion to Diffusers Stable Diffusion Pipelines 🧨🤗",
    description=DESCRIPTION,
    allow_flagging="never",  # disable Gradio's flagging feature
)

# Launched unconditionally at module level: importing this file starts the
# server as well.
demo.launch()