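# Scraped Hugging Face file-page header (kept as a comment so the file parses):
#   Space: FreeU / C — uploaded by ChenyangSi ("Upload C"), commit db2d287
#   (page controls: raw · history · blame; scan status: no virus; size: 3.3 kB)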
import gradio as gr
from PIL import Image
import torch
from muse import PipelineMuse
from diffusers import AutoPipelineForText2Image, UniPCMultistepScheduler
def _load_muse(repo_id):
    """Load a PipelineMuse checkpoint on the GPU in fp16 with xformers attention enabled.

    Args:
        repo_id: Hub repository id of the MUSE checkpoint.

    Returns:
        The loaded pipeline.
    """
    pipe = PipelineMuse.from_pretrained(repo_id).to("cuda", dtype=torch.float16)
    pipe.transformer.enable_xformers_memory_efficient_attention()
    return pipe


# All models are loaded once, at import time, and reused by every request.
muse_512 = _load_muse("openMUSE/muse-512")
muse_512_fine = _load_muse("openMUSE/muse-512-finetuned")

# Stable Diffusion v1.5 baseline. With only variant="fp16", diffusers builds the model in
# its default fp32 dtype and upcasts the half-precision weights while loading; passing
# torch_dtype keeps them in fp16 from the start (less host memory, faster load).
sdv1_5 = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda", dtype=torch.float16)
sdv1_5.scheduler = UniPCMultistepScheduler.from_config(sdv1_5.scheduler.config)
sdv1_5.enable_xformers_memory_efficient_attention()
def infer(prompt, negative, guidance_scale=None):
    """Generate one image per model for ``prompt`` and return them for the gallery.

    Args:
        prompt: Text prompt shared by all three pipelines.
        negative: Negative prompt from the UI. It is forwarded to Stable Diffusion as
            ``negative_prompt`` (an empty string is treated as "no negative prompt").
        guidance_scale: Classifier-free guidance scale from the "Guidance Scale" slider.
            When None (the legacy two-argument call), MUSE uses 10 and Stable Diffusion
            keeps its pipeline default, exactly as before this parameter existed.

    Returns:
        ``[muse_512 image, muse_512_fine image, sdv1_5 image]``.
    """
    print("Generating:")

    # Both MUSE checkpoints share the same sampling settings.
    # NOTE(review): PipelineMuse may accept a negative prompt kwarg (e.g. negative_text);
    # confirm against the installed open-muse version before forwarding `negative` here.
    muse_kwargs = dict(
        timesteps=16,
        guidance_scale=10 if guidance_scale is None else guidance_scale,
        transformer_seq_len=1024,
        use_fp16=True,
        temperature=(2, 0),
    )
    muse_512_image = muse_512(prompt, **muse_kwargs)[0]
    muse_512_fine_image = muse_512_fine(prompt, **muse_kwargs)[0]

    sd_kwargs = {"num_inference_steps": 25, "negative_prompt": negative or None}
    if guidance_scale is not None:
        # Only override SD's own default when the caller actually supplied a scale.
        sd_kwargs["guidance_scale"] = guidance_scale
    sdv1_5_image = sdv1_5(prompt, **sd_kwargs).images[0]

    return [muse_512_image, muse_512_fine_image, sdv1_5_image]
# Sample rows for gr.Examples, one per example: [prompt, negative prompt, guidance scale].
# Every sample uses the same guidance scale of 10.
examples = [
    [prompt_text, negative_text, 10]
    for prompt_text, negative_text in (
        ("A high tech solarpunk utopia in the Amazon rainforest", "low quality"),
        ("A pikachu fine dining with a view to the Eiffel Tower", "low quality"),
        ("A mecha robot in a favela in expressionist style", "low quality, 3d, photorealistic"),
        ("an insect robot preparing a delicious meal", "low quality, illustration"),
        ("A small cabin on top of a snowy mountain in the style of Disney, artstation", "low quality, ugly"),
    )
]

# Page CSS: centre <h1> headings, and cap the width of #component-0 (presumably the
# outermost Gradio layout element) while centring it horizontally.
css = (
    "\n"
    "h1 {\n"
    "text-align: center;\n"
    "}\n"
    "#component-0 {\n"
    "max-width: 730px;\n"
    "margin: auto;\n"
    "}\n"
)
block = gr.Blocks(css=css)

with block:
    gr.Markdown("MUSE is an upcoming fast text2image model.")
    with gr.Group():
        # equal_height is a Row constructor argument; the old .style(...) API is
        # deprecated (removed in Gradio 4) and mobile_collapse is no longer supported.
        with gr.Row(elem_id="prompt-container", equal_height=True):
            with gr.Column():
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                # The examples and all three event handlers pass a negative prompt, so
                # this component must exist (it was referenced but never created).
                negative = gr.Textbox(
                    label="Enter your negative prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter a negative prompt",
                    container=False,
                )
            btn = gr.Button("Generate image", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False)

    with gr.Accordion("Advanced settings", open=False):
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=20, value=10, step=0.1
        )

    # Same input order as each row of `examples`: prompt, negative prompt, guidance scale.
    infer_inputs = [text, negative, guidance_scale]

    ex = gr.Examples(
        examples=examples,
        fn=infer,
        inputs=infer_inputs,
        outputs=gallery,
        cache_examples=False,
    )
    ex.dataset.headers = [""]

    text.submit(infer, inputs=infer_inputs, outputs=gallery)
    negative.submit(infer, inputs=infer_inputs, outputs=gallery)
    btn.click(infer, inputs=infer_inputs, outputs=gallery)

block.launch()