File size: 3,340 Bytes
2ec2ebd
 
e2d57c7
2ec2ebd
d415ad5
fe9c201
d415ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2d57c7
d415ad5
740a0ae
c3eb335
d415ad5
 
 
 
 
740a0ae
2ec2ebd
e18f8f6
d415ad5
 
 
 
 
 
 
ba74c9e
d415ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df0a063
 
d415ad5
304f5cd
 
 
 
 
 
 
 
 
 
 
 
 
c3eb335
d415ad5
c3eb335
3e50201
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import random

import gradio as gr
import numpy as np

from diffusion_lens import get_images

# Upper bound for seeds: used as the seed slider's maximum and as the top of
# the range when a random seed is drawn (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max

# Page heading, passed to gr.Markdown when the UI is built.
title = r"""
<h1 align="center">Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</h1>
"""

# One-line blurb linking to the paper, passed to gr.Markdown below the title.
# The link text is the paper's own title and the <b> tag is closed so the bold
# styling cannot leak into whatever is rendered after it.
description = r"""
<b>Based on the paper <a href='https://arxiv.org/abs/2403.05846' target='_blank'>Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</a>.</b><br>
"""

# Footer Markdown (citation + contact), passed to gr.Markdown at the end of
# the page. The BibTeX entry must keep balanced braces so it can be pasted
# into a .bib file as-is.
article = r"""
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{toker2024diffusion,
  title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines},
  author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan},
  journal={arXiv preprint arXiv:2403.05846},
  year={2024}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue or directly reach us out at <b>tok@cs.technion.ac.il</b>.
"""


# Maps each model name offered in the UI to the layer count that
# generate_images sweeps over (it iterates skip_layers from this value to 0).
model_num_of_layers = {
    "Stable Diffusion 1.4": 12,
    "Stable Diffusion 2.1": 22,
}

def generate_images(prompt, model, seed):
    """Yield a growing gallery of images, one per layer depth of ``model``.

    Calls ``get_images`` once for every ``skip_layers`` value from the
    model's layer count down to 0. After each call the list of
    ``(image, caption)`` pairs collected so far is yielded, so the gallery
    can fill in incrementally.

    Args:
        prompt: Text prompt forwarded to ``get_images``.
        model: Key of ``model_num_of_layers`` naming the model to use.
        seed: Seed forwarded to ``get_images``; ``-1`` draws a random seed
            in ``[0, MAX_SEED]``.

    Yields:
        The same list object on every step, extended by one
        ``(image, 'layer_<k>')`` tuple, with ``k`` counting up from 0 to
        the model's layer count.

    Raises:
        KeyError: If ``model`` is not a key of ``model_num_of_layers``.
    """
    # NOTE(review): the slider may deliver the seed as a float; presumably
    # get_images expects an integer seed, so normalise it here - confirm.
    seed = random.randint(0, MAX_SEED) if seed == -1 else int(seed)
    print('calling diffusion lens with model:', model, 'and seed:', seed)
    gr.Info('Generating images from intermediate layers..')

    all_images = []  # accumulated (image, caption) pairs
    num_layers = model_num_of_layers[model]
    for skip_layers in range(num_layers, -1, -1):
        images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
        # Caption relative to this model's own layer count (not a fixed 12),
        # so labels run layer_0..layer_N for every model.
        all_images.append((images[0], f'layer_{num_layers - skip_layers}'))
        yield all_images

# Build the demo page: header, results gallery, input controls, generate
# button and citation footer. Only components that are actually wired to
# generate_images are created, so no inert controls are shown to the user.
with gr.Blocks() as demo:

    gr.Markdown(title)
    gr.Markdown(description)

    # Receives the list of (image, caption) pairs yielded by generate_images.
    gallery = gr.Gallery(label="Generated Images", columns=6, rows=2, object_fit="contain", height="auto")

    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            value="a cat, masterpiece, best quality, high quality",
        )

    # The choice strings are used by generate_images as keys into
    # model_num_of_layers, so they must match that dict exactly.
    model = gr.Radio(
        [
            "Stable Diffusion 1.4",
            "Stable Diffusion 2.1",
        ],
        value="Stable Diffusion 1.4",
        label="Model",
    )

    # -1 is the sentinel generate_images treats as "draw a random seed".
    seed = gr.Slider(
        minimum=-1,
        maximum=MAX_SEED,
        value=-1,
        step=1,
        label="Seed Value",
    )

    inputs = [
        prompt,
        model,
        seed,
    ]
    outputs = [gallery]

    generate_button = gr.Button("Generate Image")

    # Regenerate on prompt submit, button click, or a change to seed/model.
    gr.on(
        triggers=[
            prompt.submit,
            generate_button.click,
            seed.input,
            model.input
        ],
        fn=generate_images,
        inputs=inputs,
        outputs=outputs,
        show_progress="full",
        show_api=False,
        trigger_mode="always_last",
        )

    gr.Markdown(article)

demo.queue(api_open=False)
demo.launch(show_api=False)