import gradio as gr from PIL import Image import torch from muse import PipelineMuse from diffusers import AutoPipelineForText2Image, UniPCMultistepScheduler muse_512 = PipelineMuse.from_pretrained("openMUSE/muse-512").to("cuda", dtype=torch.float16) muse_512.transformer.enable_xformers_memory_efficient_attention() muse_512_fine = PipelineMuse.from_pretrained("openMUSE/muse-512-finetuned").to("cuda", dtype=torch.float16) muse_512_fine.transformer.enable_xformers_memory_efficient_attention() sdv1_5 = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", variant="fp16").to("cuda", dtype=torch.float16) sdv1_5.scheduler = UniPCMultistepScheduler.from_config(sdv1_5.scheduler.config) sdv1_5.enable_xformers_memory_efficient_attention() def infer(prompt, negative): print("Generating:") muse_512_image = muse_512( prompt, timesteps=16, guidance_scale=10, transformer_seq_len=1024, use_fp16=True, temperature=(2, 0), )[0] muse_512_fine_image = muse_512_fine( prompt, timesteps=16, guidance_scale=10, transformer_seq_len=1024, use_fp16=True, temperature=(2, 0), )[0] sdv1_5_image = sdv1_5(prompt, num_inference_steps=25).images[0] images = [muse_512_image, muse_512_fine_image, sdv1_5_image] return images examples = [ [ 'A high tech solarpunk utopia in the Amazon rainforest', 'low quality', 10, ], [ 'A pikachu fine dining with a view to the Eiffel Tower', 'low quality', 10, ], [ 'A mecha robot in a favela in expressionist style', 'low quality, 3d, photorealistic', 10, ], [ 'an insect robot preparing a delicious meal', 'low quality, illustration', 10, ], [ "A small cabin on top of a snowy mountain in the style of Disney, artstation", 'low quality, ugly', 10, ], ] css = """ h1 { text-align: center; } #component-0 { max-width: 730px; margin: auto; } """ block = gr.Blocks(css=css) with block: gr.Markdown("MUSE is an upcoming fast text2image model.") with gr.Group(): with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True): with gr.Column(): text = gr.Textbox( label="Enter your prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False, ) btn = gr.Button("Generate image", scale=0) gallery = gr.Gallery( label="Generated images", show_label=False, ).style() with gr.Accordion("Advanced settings", open=False): guidance_scale = gr.Slider( label="Guidance Scale", minimum=0, maximum=20, value=10, step=0.1 ) ex = gr.Examples(examples=examples, fn=infer, inputs=[text, negative, guidance_scale], outputs=gallery, cache_examples=False) ex.dataset.headers = [""] text.submit(infer, inputs=[text, negative, guidance_scale], outputs=gallery) negative.submit(infer, inputs=[text, negative, guidance_scale], outputs=gallery) btn.click(infer, inputs=[text, negative, guidance_scale], outputs=gallery) block.launch()