File size: 2,156 Bytes
997984a
 
 
 
fbaddc2
 
 
 
 
 
 
 
997984a
 
 
 
 
 
 
fbaddc2
997984a
11b4f29
 
 
49c1726
fbaddc2
 
 
 
 
 
997984a
fbaddc2
997984a
 
 
 
 
 
 
fbaddc2
 
997984a
fbaddc2
997984a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbaddc2
997984a
fbaddc2
 
997984a
fbaddc2
997984a
97e8bbe
997984a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.diffusion import (
    DiffusersGenerationAlgorithm,
    DDPMGenerator,
    DDIMGenerator,
    ScoreSdeGenerator,
    LDMTextToImageGenerator,
    LDMGenerator,
    StableDiffusionGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


# Selectable model names mapped to their generator configuration classes.
# An explicit lookup replaces the former ``eval`` of an f-string that embedded
# the user-typed prompt: that allowed arbitrary code execution from the text
# box and also failed on any prompt containing a double quote.
_GENERATOR_CLASSES = {
    "DDPMGenerator": DDPMGenerator,
    "DDIMGenerator": DDIMGenerator,
    "ScoreSdeGenerator": ScoreSdeGenerator,
    "LDMTextToImageGenerator": LDMTextToImageGenerator,
    "LDMGenerator": LDMGenerator,
    "StableDiffusionGenerator": StableDiffusionGenerator,
}


def run_inference(model_type: str, prompt: str):
    """Generate one image with the selected diffusion model.

    Args:
        model_type: name of one of the generator configuration classes
            imported by this module (e.g. ``"StableDiffusionGenerator"``).
        prompt: text prompt; an empty string (or None) requests
            unconditional generation.

    Returns:
        The first item yielded by ``DiffusersGenerationAlgorithm.sample(1)``
        (rendered by the Gradio UI as a PIL image).

    Raises:
        ValueError: if ``model_type`` is not a known generator, or if a
            non-empty prompt is given to a model whose ``modality`` is not
            ``"token2image"``.
    """
    try:
        generator_cls = _GENERATOR_CLASSES[model_type]
    except KeyError:
        raise ValueError(f"Unknown diffusion model: {model_type}") from None

    # Pass the prompt as a real keyword argument, never as source text.
    config = generator_cls(prompt=prompt) if prompt else generator_cls()
    if config.modality != "token2image" and prompt:
        raise ValueError(
            f"{model_type} is an unconditional generative model, please remove prompt (not={prompt})"
        )
    model = DiffusersGenerationAlgorithm(config)
    # Only one sample is requested; take it without materializing a list.
    image = next(iter(model.sample(1)))

    return image


if __name__ == "__main__":

    # Preparation: retrieve all available algorithms and keep the applications
    # of the diffusion-based ones. GeoDiff is excluded (presumably because it
    # does not produce images — confirm).
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        entry["algorithm_application"]
        for entry in all_algos
        if "Diff" in entry["algorithm_name"]
        and "GeoDiff" not in entry["algorithm_application"]
    ]

    # Load metadata (examples, article, description) from the model_cards
    # directory next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # The CSV has no header row; blank cells become "" so an empty prompt
    # column is passed to run_inference as an empty prompt.
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )

    # Decode explicitly as UTF-8 so the markdown reads the same on every
    # platform, regardless of the locale's default encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="Diffusion-based image generators",
        inputs=[
            gr.Dropdown(
                algos, label="Diffusion model", value="StableDiffusionGenerator"
            ),
            gr.Textbox(label="Text prompt", placeholder="A blue tree", lines=1),
        ],
        outputs=gr.Image(type="pil"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)