# Hugging Face Space (status at capture time: running on A10G).
import uuid
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import DiffusionPipeline
from PIL import Image
def save_image(img: "Image.Image") -> str:
    """Save *img* as a PNG under a freshly generated UUID file name.

    The file is written relative to the current working directory because
    the name carries no directory component.

    Args:
        img: Any object exposing ``save(path)``; in this app a PIL image.

    Returns:
        The generated file name (``"<uuid4>.png"``) the image was saved to.
    """
    unique_name = f"{uuid.uuid4()}.png"
    img.save(unique_name)
    return unique_name
def save_images(image_array):
    """Save every image in *image_array* concurrently and return the paths.

    Each image is handed to ``save_image`` on a worker thread of a
    ``ThreadPoolExecutor``.

    Args:
        image_array: Iterable of images accepted by ``save_image``.

    Returns:
        A list of the saved file names, in the same order as the input
        (``executor.map`` yields results in input order).
    """
    # Leaving the ``with`` block waits for all workers to finish.
    with ThreadPoolExecutor() as executor:
        return list(executor.map(save_image, image_array))
# Use the GPU when one is visible to torch; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The fp16 weight variant is downloaded in both cases, but half precision is
# only kept on the GPU: on the CPU fallback the weights are upcast to float32,
# since many CPU kernels lack float16 support.
pipe = DiffusionPipeline.from_pretrained(
    "amused/amused-512",
    variant="fp16",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# Prompt encoder built on the pipeline's own tokenizer and text encoder.
# It is configured to also return pooled embeddings and to not truncate
# long prompts.
compel = Compel(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=True,
    truncate_long_prompts=False,
)
def infer(prompt, negative="", scale=10, progress=gr.Progress(track_tqdm=True)):
    """Generate a batch of four images for *prompt* with the aMUSEd pipeline.

    Both prompts are encoded with the module-level ``compel`` instance, the
    two conditioning tensors are padded to a common length, and everything is
    passed to the module-level ``pipe``.

    Args:
        prompt: Text prompt describing the desired image.
        negative: Negative prompt; defaults to the empty string.
        scale: Value forwarded to the pipeline as ``guidance_scale``.
        progress: Gradio progress handle. It is not used in the body; it is
            declared in the signature with ``track_tqdm=True`` (presumably so
            Gradio surfaces the pipeline's tqdm progress in the UI).

    Returns:
        The ``.images`` attribute of the pipeline output (four images are
        requested via ``num_images_per_prompt``).
    """
    print("Generating:")

    # A missing text input is treated as an empty prompt so the encoder is
    # never handed ``None``.
    prompt = prompt or ""
    negative = negative or ""

    conditioning, pooled = compel(prompt)
    negative_conditioning, negative_pooled = compel(negative)

    # ``compel`` is created with truncate_long_prompts=False, so the positive
    # and negative conditionings may differ in length; pad them to match.
    conditioning, negative_conditioning = compel.pad_conditioning_tensors_to_same_length(
        [conditioning, negative_conditioning]
    )

    images = pipe(
        prompt_embeds=pooled,
        encoder_hidden_states=conditioning,
        negative_prompt_embeds=negative_pooled,
        negative_encoder_hidden_states=negative_conditioning,
        guidance_scale=scale,
        num_images_per_prompt=4,
        temperature=(3, 1),
    ).images

    print("Done Generating!")
    print("Num Images:", len(images))
    return images
# Example rows for ``gr.Examples``. Each row lines up with the inputs
# ``[text, negative, guidance_scale]``; only the prompt column is populated.
# NOTE(review): the ``None`` entries presumably leave the negative prompt and
# guidance scale untouched when an example is selected — confirm against
# Gradio's Examples handling.
examples = [
    [
        'A serious capybara at work, wearing a suit',
        None,
        None,
    ],
    [
        'A pikachu fine dining with a view to the Eiffel Tower',
        None,
        None,
    ],
    [
        'A mecha robot in a favela in expressionist style',
        None,
        None,
    ],
    [
        'an insect robot preparing a delicious meal',
        None,
        None,
    ],
    [
        "A small cabin on top of a snowy mountain in the style of Disney, artstation",
        None,
        None,
    ],
]
# Custom stylesheet passed to ``gr.Blocks``: centres <h1> headings and caps
# the width of the element with id ``component-0``, centring it horizontally.
css = """
h1 {
    text-align: center;
}
#component-0 {
    max-width: 730px;
    margin: auto;
}
"""
# Build the Gradio UI: header, prompt inputs, result gallery, advanced
# settings and example prompts, then wire the events to ``infer`` and launch.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file whose indentation was lost — confirm it matches
# the intended layout.
block = gr.Blocks(css=css)

with block:
    # Static header: logo, title and a one-line description with paper link.
    gr.HTML(
        """
        <div style="text-align: center; margin: 0 auto;">
          <div
            style="
              display: inline-flex;
              align-items: center;
              gap: 0.8rem;
              font-size: 1.75rem;
            "
          >
            <svg
              width="0.65em"
              height="0.65em"
              viewBox="0 0 115 115"
              fill="none"
              xmlns="http://www.w3.org/2000/svg"
            >
              <rect width="23" height="23" fill="white"></rect>
              <rect y="69" width="23" height="23" fill="white"></rect>
              <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="46" width="23" height="23" fill="white"></rect>
              <rect x="46" y="69" width="23" height="23" fill="white"></rect>
              <rect x="69" width="23" height="23" fill="black"></rect>
              <rect x="69" y="69" width="23" height="23" fill="black"></rect>
              <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="115" y="46" width="23" height="23" fill="white"></rect>
              <rect x="115" y="115" width="23" height="23" fill="white"></rect>
              <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="92" y="69" width="23" height="23" fill="white"></rect>
              <rect x="69" y="46" width="23" height="23" fill="white"></rect>
              <rect x="69" y="115" width="23" height="23" fill="white"></rect>
              <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="46" y="46" width="23" height="23" fill="black"></rect>
              <rect x="46" y="115" width="23" height="23" fill="black"></rect>
              <rect x="46" y="69" width="23" height="23" fill="black"></rect>
              <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
              <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
              <rect x="23" y="69" width="23" height="23" fill="black"></rect>
            </svg>
            <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
              aMUSEd: Efficient Text-to-Image Model
            </h1>
          </div>
          <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
            <a style="text-decoration: underline;" href="https://arxiv.org/abs/2401.01808" target="_blank">aMUSEd</a> is an open-source, lightweight masked image model for text-to-image generation based on MUSE focused on fast image generation.
          </p>
        </div>
        """
    )

    with gr.Group():
        # Layout options go to the constructor; the chained ``.style()``
        # helper is deprecated in late Gradio 3.x and removed in 4.x.
        with gr.Row(elem_id="prompt-container", equal_height=True):
            with gr.Column():
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                negative = gr.Textbox(
                    label="Enter your negative prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your negative prompt",
                    container=False,
                    value="low quality, ugly, deformed",
                )
            btn = gr.Button("Generate image", scale=0)

        # ``columns=2`` replaces the former ``.style(grid=[2])``.
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            columns=2,
        )

    with gr.Accordion("Advanced settings", open=False):
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=20, value=10, step=0.1
        )

    ex = gr.Examples(
        examples=examples,
        fn=infer,
        inputs=[text, negative, guidance_scale],
        outputs=gallery,
        cache_examples=False,
    )
    # Blank out the column header of the examples table.
    ex.dataset.headers = [""]

    # Pressing Enter in either textbox or clicking the button runs ``infer``
    # with the same inputs and shows the result in the gallery.
    text.submit(infer, inputs=[text, negative, guidance_scale], outputs=gallery)
    negative.submit(infer, inputs=[text, negative, guidance_scale], outputs=gallery)
    btn.click(infer, inputs=[text, negative, guidance_scale], outputs=gallery)

block.launch()