amused / app.py
valhalla's picture
Update app.py
9dfb86d
raw
history blame
No virus
7.02 kB
from concurrent.futures import ThreadPoolExecutor
import uuid
import gradio as gr
from PIL import Image
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import DiffusionPipeline
def save_image(img):
    """Persist *img* to a newly named PNG file in the current working directory.

    The filename is a random UUID4 string with a ``.png`` suffix, so
    concurrent callers do not overwrite each other's files.

    Args:
        img: Object exposing ``save(path)`` (e.g. a PIL image).

    Returns:
        The relative filename that was written.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def save_images(image_array):
    """Save every image in *image_array* concurrently using ``save_image``.

    Args:
        image_array: Iterable of objects accepted by ``save_image`` (anything
            exposing ``save(path)``, e.g. PIL images).

    Returns:
        List of the filenames written, in the same order as *image_array*.
    """
    # Executor.map yields results in input order, and list() drains it before
    # the with-block's shutdown, so every write has finished on return.
    with ThreadPoolExecutor() as executor:
        return list(executor.map(save_image, image_array))
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only dependable on CUDA. Many PyTorch CPU kernels have no
# float16 implementation ("not implemented for 'Half'") or are very slow, so
# the fp16 weights are upcast to float32 when running on the CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = DiffusionPipeline.from_pretrained(
    "amused/amused-512",
    variant="fp16",  # fetch the fp16 weight files from the model repo
    torch_dtype=dtype,
).to(device)
# Prompt encoder used by infer(). requires_pooled=True makes compel(text)
# return a (conditioning, pooled) pair. truncate_long_prompts=False keeps
# prompts longer than the tokenizer limit, which is why infer() pads the
# positive/negative conditioning tensors to a common length.
compel = Compel(
    tokenizer=pipe.tokenizer,
    text_encoder=pipe.text_encoder,
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=True,
    truncate_long_prompts=False,
)
def infer(prompt, negative="", scale=10, progress=gr.Progress(track_tqdm=True)):
    """Generate four images for *prompt* with the aMUSEd pipeline.

    Args:
        prompt: Positive text prompt, encoded with the module-level ``compel``.
        negative: Negative text prompt. ``None`` is treated as ``""``; the
            example rows in this app supply ``None`` for this column.
        scale: Guidance scale forwarded to ``pipe`` as ``guidance_scale``.
            ``None`` falls back to the default of 10.
        progress: Gradio progress tracker. ``track_tqdm=True`` asks Gradio to
            mirror tqdm progress bars (such as the pipeline's) in the UI.

    Returns:
        The ``images`` list produced by ``pipe`` (four images per call).
    """
    # compel() needs a string and the pipeline compares guidance_scale
    # numerically, so normalise the None values example rows can deliver.
    if negative is None:
        negative = ""
    if scale is None:
        scale = 10
    print("Generating:")
    conditioning, pooled = compel(prompt)
    negative_conditioning, negative_pooled = compel(negative)
    # Prompts are not truncated, so the positive and negative conditioning
    # tensors can differ in length; pad them to match before guidance.
    conditioning, negative_conditioning = compel.pad_conditioning_tensors_to_same_length(
        [conditioning, negative_conditioning]
    )
    images = pipe(
        prompt_embeds=pooled,
        encoder_hidden_states=conditioning,
        negative_prompt_embeds=negative_pooled,
        negative_encoder_hidden_states=negative_conditioning,
        guidance_scale=scale,
        num_images_per_prompt=4,
        # NOTE(review): presumably a (start, end) sampling-temperature
        # schedule for the scheduler; confirm against AmusedPipeline docs.
        temperature=(3, 1),
    ).images
    print("Done Generating!")
    print("Num Images:", len(images))
    return images
# Example rows for gr.Examples. Each row is
# [prompt, negative prompt, guidance scale], matching the inputs
# [text, negative, guidance_scale]; the two None entries supply no value
# for the negative prompt and guidance scale columns.
_EXAMPLE_PROMPTS = (
    "A serious capybara at work, wearing a suit",
    "A pikachu fine dining with a view to the Eiffel Tower",
    "A mecha robot in a favela in expressionist style",
    "an insect robot preparing a delicious meal",
    "A small cabin on top of a snowy mountain in the style of Disney, artstation",
)
examples = [[example_prompt, None, None] for example_prompt in _EXAMPLE_PROMPTS]
# Custom stylesheet for the gr.Blocks page. It centres <h1> headings and caps
# the element with id "component-0" at 730px wide, centred horizontally.
# NOTE(review): "component-0" is an auto-assigned Gradio element id; it
# presumably targets the outermost container. Verify against the rendered DOM
# if the layout changes.
css = """
h1 {
text-align: center;
}
#component-0 {
max-width: 730px;
margin: auto;
}
"""
# Build the demo UI: header, prompt inputs, gallery, advanced settings and
# examples, all wired to infer().
block = gr.Blocks(css=css)

with block:
    # Static header: logo SVG, title and a link to the aMUSEd paper.
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<svg
width="0.65em"
height="0.65em"
viewBox="0 0 115 115"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<rect width="23" height="23" fill="white"></rect>
<rect y="69" width="23" height="23" fill="white"></rect>
<rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="46" width="23" height="23" fill="white"></rect>
<rect x="46" y="69" width="23" height="23" fill="white"></rect>
<rect x="69" width="23" height="23" fill="black"></rect>
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="115" y="46" width="23" height="23" fill="white"></rect>
<rect x="115" y="115" width="23" height="23" fill="white"></rect>
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
<rect x="69" y="46" width="23" height="23" fill="white"></rect>
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="46" y="46" width="23" height="23" fill="black"></rect>
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
<rect x="46" y="69" width="23" height="23" fill="black"></rect>
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
</svg>
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
aMUSEd: Efficient Text-to-Image Model
</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
<a style="text-decoration: underline;" href="https://arxiv.org/abs/2401.01808", target="_blank">aMUSEd</a> is an open-source, lightweight masked image model for text-to-image generation based on MUSE focused on fast image generation.
</p>
</div>
"""
    )
    with gr.Group():
        # Layout options are passed to the constructors: Component.style() was
        # deprecated in Gradio 3.x and removed in Gradio 4. The old
        # mobile_collapse=False flag has no constructor equivalent and is dropped.
        with gr.Row(elem_id="prompt-container", equal_height=True):
            with gr.Column():
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                negative = gr.Textbox(
                    label="Enter your negative prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your negative prompt",
                    container=False,
                    value="low quality, ugly, deformed",
                )
            # scale=0 keeps the button at its natural width beside the inputs.
            btn = gr.Button("Generate image", scale=0)
        # Two-column grid for the four images infer() returns
        # (replaces the removed .style(grid=[2])).
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            columns=2,
        )
    with gr.Accordion("Advanced settings", open=False):
        guidance_scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=20, value=10, step=0.1
        )

    # Every trigger feeds infer() the same inputs and writes to the gallery.
    infer_inputs = [text, negative, guidance_scale]
    ex = gr.Examples(
        examples=examples,
        fn=infer,
        inputs=infer_inputs,
        outputs=gallery,
        cache_examples=False,
    )
    ex.dataset.headers = [""]
    for trigger in (text.submit, negative.submit, btn.click):
        trigger(infer, inputs=infer_inputs, outputs=gallery)

# Enable the request queue. Gradio 3.x requires it for gr.Progress tracking, and
# it keeps long GPU generations from hitting HTTP timeouts. queue() returns the
# Blocks instance, so launch() can be chained.
block.queue().launch()