"""Gradio demo for GT4SD diffusion-based image generators."""
import logging
import pathlib

import gradio as gr
import pandas as pd

from gt4sd.algorithms.generation.diffusion import (
    DiffusersGenerationAlgorithm,
    DDPMGenerator,
    DDIMGenerator,
    ScoreSdeGenerator,
    LDMTextToImageGenerator,
    LDMGenerator,
    StableDiffusionGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Whitelist of generator configuration classes selectable from the UI, keyed by
# class name. Resolving the class through this mapping (instead of ``eval``-ing a
# string assembled from user input) prevents arbitrary code execution via the
# model dropdown or the prompt textbox, and lets prompts contain quotes or
# backslashes verbatim.
_GENERATORS = {
    generator_cls.__name__: generator_cls
    for generator_cls in (
        DDPMGenerator,
        DDIMGenerator,
        ScoreSdeGenerator,
        LDMTextToImageGenerator,
        LDMGenerator,
        StableDiffusionGenerator,
    )
}


def run_inference(model_type: str, prompt: str):
    """Generate a single image with the selected diffusion generator.

    Args:
        model_type: class name of the generator configuration; must be one of
            the keys of ``_GENERATORS``.
        prompt: text prompt. Empty (or ``None``) means unconditional generation.

    Returns:
        The first item yielded by ``DiffusersGenerationAlgorithm.sample(1)``
        (presumably a PIL image, given the ``gr.Image(type="pil")`` output).

    Raises:
        ValueError: if ``model_type`` is not a known generator, or if a prompt
            is supplied for a generator whose modality is not ``token2image``.
    """
    try:
        generator_cls = _GENERATORS[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown diffusion model {model_type!r}; "
            f"expected one of {sorted(_GENERATORS)}"
        ) from None

    # Only forward the prompt when one was given so unconditional generators
    # are built with their default configuration.
    config = generator_cls(prompt=prompt) if prompt else generator_cls()

    if config.modality != "token2image" and prompt:
        raise ValueError(
            f"{model_type} is an unconditional generative model, "
            f"please remove prompt (got={prompt})"
        )

    model = DiffusersGenerationAlgorithm(config)
    image = list(model.sample(1))[0]
    return image


if __name__ == "__main__":
    # Preparation: retrieve all registered algorithms and keep the diffusion
    # applications (excluding GeoDiff) that this app knows how to instantiate.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_application"]
        for algo in all_algos
        if "Diff" in algo["algorithm_name"]
        and "GeoDiff" not in algo["algorithm_application"]
        and algo["algorithm_application"] in _GENERATORS
    ]

    # Load metadata shipped next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = pd.read_csv(metadata_root.joinpath("examples.csv"), header=None).fillna(
        ""
    )
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=run_inference,
        title="Diffusion-based image generators",
        inputs=[
            gr.Dropdown(
                algos, label="Diffusion model", value="StableDiffusionGenerator"
            ),
            gr.Textbox(label="Text prompt", placeholder="A blue tree", lines=1),
        ],
        outputs=gr.Image(type="pil"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)