File size: 3,893 Bytes
47a3cb0
 
 
 
12a81fb
 
8f04382
47a3cb0
 
 
 
 
 
bcb5a3c
 
 
 
 
 
47a3cb0
0b329f4
47a3cb0
 
 
 
 
 
 
 
 
 
 
b29e176
47a3cb0
 
 
 
b29e176
 
47a3cb0
 
 
 
 
 
 
 
 
 
 
86fc5bd
47a3cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
12a81fb
47a3cb0
bcb5a3c
47a3cb0
 
 
 
 
 
 
 
bcb5a3c
 
 
47a3cb0
bcb5a3c
47a3cb0
 
4ecf6e4
 
 
 
 
 
 
86fc5bd
4ecf6e4
 
 
 
86fc5bd
4ecf6e4
 
 
47a3cb0
0b329f4
47a3cb0
 
a818142
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
app.py
An interactive demo for text-guided panorama generation.
"""
import os
from os.path import join
from PIL import Image
import torch
import gradio as gr

from syncdiffusion.syncdiffusion_model import SyncDiffusion
from syncdiffusion.utils import seed_everything

# set device
# The demo is pinned to CUDA; there is no CPU fallback in this file.
device = torch.device("cuda")

# load SyncDiffusion model
# Built once at import time and reused by every run_inference() call.
# sd_version presumably selects the Stable Diffusion checkpoint -- confirm in SyncDiffusion.
syncdiffusion = SyncDiffusion(device, sd_version="2.0")

def run_inference(
        prompt: str,  
        width: int = 2048,
        sync_weight: float = 20.0,
        sync_thres: int = 5,
        seed: int = 0
    ) -> list:
    """Generate one panorama for ``prompt`` and return it as a one-element list.

    Args:
        prompt: Text prompt describing the panorama.
        width: Output width in pixels; the height is fixed at 512.
        sync_weight: Weight forwarded to the sampler as ``sync_weight``.
        sync_thres: Per the UI label, SyncDiffusion is applied for the first
            ``sync_thres`` steps.
        seed: Random seed passed to ``seed_everything`` before sampling.

    Returns:
        ``[img]`` where ``img`` is whatever ``sample_syncdiffusion`` returns
        (presumably an image -- confirm in SyncDiffusion). It is wrapped in a
        list because the click handler feeds a ``gr.Gallery`` output.
    """
    # Gradio number/slider inputs can arrive as floats (e.g. 0.0) even though
    # these parameters are integral; normalise them before seeding/sampling.
    seed = int(seed)
    width = int(width)
    sync_thres = int(sync_thres)

    # set random seed
    seed_everything(seed)

    # Everything not exposed in the UI is fixed here: 512 px height, 50 steps,
    # guidance 7.5, sync decay 0.99, sync every step, stride 16, no negative prompt.
    img = syncdiffusion.sample_syncdiffusion(
        prompts = prompt,
        negative_prompts = "",
        height = 512,
        width = width,
        num_inference_steps = 50,
        guidance_scale = 7.5,
        sync_weight = sync_weight,
        sync_decay_rate = 0.99,
        sync_freq = 1,
        sync_thres = sync_thres,
        stride = 16
        )
    return [img]

if __name__ == "__main__":
    # Build the Gradio Blocks UI for the panorama demo and launch it.
    title = "SyncDiffusion: Text-Guided Panorama Generation"

    description_text = '''
    This demo features text-guided panorama generation from our work <a href="https://arxiv.org/abs/2306.05178">SyncDiffusion: Coherent Montage via Synchronized Joint Diffusions, NeurIPS 2023</a>.  
    Please refer to our <a href="https://syncdiffusion.github.io/">project page</a> for details.
    (Note: Generation takes about 100 seconds for the default setting below, with the current A10G GPU.)
    '''

    # (prompt, seed, pre-rendered result file under assets/) for each sample.
    # One list drives both the clickable examples and the static sample images,
    # so the two cannot drift apart.
    samples = [
        ('a cinematic view of a castle in the sunset', 1, "result_castle_seed_1.png"),
        ('natural landscape in anime style illustration', 2, "result_natural_seed_2.png"),
        ('a photo of a lake under the northern lights', 6, "result_northern_seed_6.png"),
    ]

    # create UI
    with gr.Blocks(title=title) as demo:

        # description of demo
        gr.Markdown(description_text)

        # inputs
        with gr.Row():
            with gr.Column():
                # A Button's caption is its `value`; `label` is not a Button parameter.
                run_button = gr.Button(value="Generate")

                prompt = gr.Textbox(label="Text Prompt", value='a cinematic view of a castle in the sunset')
                width = gr.Slider(label="Width", minimum=512, maximum=3072, value=2048, step=128)
                sync_weight = gr.Slider(label="Sync Weight", minimum=0.0, maximum=30.0, value=20.0, step=5.0)
                sync_thres = gr.Slider(label="Sync Threshold (If N, apply SyncDiffusion for the first N steps)", minimum=0, maximum=15, value=5, step=1)
                # precision=0 makes Gradio hand the seed to the callback as an int.
                seed = gr.Number(label="Seed", value=0, precision=0)

            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')

        # Input components in the positional order run_inference expects.
        ips = [prompt, width, sync_weight, sync_thres, seed]

        # display examples (default width / sync weight / threshold, per-sample seed)
        gr.Examples(
            examples=[[text, 2048, 20.0, 5, sample_seed] for text, sample_seed, _ in samples],
            inputs=ips,
        )

        # display example images, each followed by the prompt that produced it
        for idx, (text, _, filename) in enumerate(samples, start=1):
            with gr.Row():
                gr.Image(Image.open(join("assets", filename)), label=f"Sample output {idx}")
            with gr.Row():
                gr.Textbox(label="Prompt", type="text", value=text)

        run_button.click(fn=run_inference, inputs=ips, outputs=[result_gallery])

    # Enable request queuing (max_size=30) before starting the server.
    demo.queue(max_size=30)
    demo.launch()