# Hugging Face Space status banner (page residue): Running on Zero
import gradio as gr | |
import torch | |
from src.transformer import SymmetricTransformer2DModel | |
from src.pipeline import UnifiedPipeline | |
from src.scheduler import Scheduler | |
from torchvision import transforms | |
from transformers import CLIPTextModelWithProjection, CLIPTokenizer | |
from diffusers import VQModel | |
import os | |
from PIL import Image | |
import numpy as np | |
import spaces | |
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_models(model_path="MeissonFlow/Meissonic",
                transformer_path="MeissonFlow/Muddit"):
    """Load every pipeline component and assemble a ``UnifiedPipeline``.

    Args:
        model_path: Repository (or local directory) holding the ``vqvae``,
            ``text_encoder``, ``tokenizer`` and ``scheduler`` subfolders.
        transformer_path: Repository (or local directory) holding the
            ``1024/transformer`` subfolder.

    Returns:
        The assembled ``UnifiedPipeline``. It is not moved to any device here.
    """
    # The transformer comes from its own repository; everything else is
    # loaded from ``model_path``.
    transformer = SymmetricTransformer2DModel.from_pretrained(
        transformer_path,
        subfolder="1024/transformer",
    )
    vqvae = VQModel.from_pretrained(model_path, subfolder="vqvae")
    clip_text_encoder = CLIPTextModelWithProjection.from_pretrained(
        model_path, subfolder="text_encoder"
    )
    clip_tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
    noise_scheduler = Scheduler.from_pretrained(model_path, subfolder="scheduler")

    return UnifiedPipeline(
        vqvae=vqvae,
        tokenizer=clip_tokenizer,
        text_encoder=clip_text_encoder,
        transformer=transformer,
        scheduler=noise_scheduler,
    )
# Load models (global variable to avoid reloading) | |
# Built once at import time so that every request handler below shares the
# same pipeline instance instead of reloading the weights.
pipe = load_models()
pipe.to(device)
# Common transform | |
def get_transform(resolution):
    """Return a transform that resizes to a square and converts to a tensor.

    Args:
        resolution: Side length used for both height and width of the resize.

    Returns:
        A ``transforms.Compose`` applying ``Resize`` followed by ``ToTensor``.
    """
    square_size = (resolution, resolution)
    stages = [
        transforms.Resize(square_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(stages)
# Image-to-Text Function | |
def image_to_text(image, prompt, seed=42, steps=64, cfg=9.0):
    """Generate text (caption or answer) for an image with the global pipeline.

    Args:
        image: Input picture as a PIL image or a NumPy array, or ``None`` to
            run the pipeline without an image.
        prompt: Question or instruction. When empty, a default prompt is used.
        seed: Seed handed to ``torch.manual_seed``.
        steps: Number of inference steps.
        cfg: Guidance scale.

    Returns:
        The first entry of ``output.prompts``, or the string
        ``"Error: <message>"`` if anything raised.
    """
    # NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
    # decorator is visible on this handler; on ZeroGPU hardware GPU work is
    # expected to run inside a decorated function -- confirm.
    try:
        resolution = 1024
        if image is not None:
            if isinstance(image, np.ndarray):
                pil_image = Image.fromarray(image.astype('uint8'))
            else:
                pil_image = image
            # Uploaded images may be RGBA, palette or grayscale. The array
            # path was already forced to RGB, so normalise every input the
            # same way before it is turned into a tensor.
            pil_image = pil_image.convert('RGB')
            transform = get_transform(resolution)
            images = torch.stack([transform(pil_image)])
            questions = [prompt] if prompt else ["Please describe this image."]
        else:
            images = None
            questions = [prompt] if prompt else ["Please generate an image description."]

        output = pipe(
            prompt=questions,
            image=images,
            height=resolution,
            width=resolution,
            guidance_scale=cfg,
            # UI sliders may deliver floats; the step count and seed are integers.
            num_inference_steps=int(steps),
            mask_token_embedding="./mask_token_embedding.pth",
            generator=torch.manual_seed(int(seed)),
        )
        return output.prompts[0]
    except Exception as e:
        # Best-effort handler: report the failure in the output textbox.
        return f"Error: {str(e)}"
# Text-to-Image Function | |
def text_to_image(prompt, negative_prompt, num_images=1, seed=42, steps=64, cfg=9.0):
    """Generate images from a text prompt with the global pipeline.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what to avoid. When empty, a default
            negative prompt is used.
        num_images: How many images to generate; the prompt is repeated this
            many times in the batch.
        seed: Seed handed to ``torch.manual_seed``.
        steps: Number of inference steps.
        cfg: Guidance scale.

    Returns:
        ``output.images`` from the pipeline, or ``None`` if anything raised
        (the error is printed).
    """
    default_negative_prompt = "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"
    try:
        resolution = 1024
        # List repetition needs an int; a slider may deliver a float such as 2.0.
        batch_size = int(num_images)
        negative_prompt = negative_prompt or default_negative_prompt
        output = pipe(
            prompt=[prompt] * batch_size,
            negative_prompt=[negative_prompt] * batch_size,
            height=resolution,
            width=resolution,
            guidance_scale=cfg,
            num_inference_steps=int(steps),
            mask_token_embedding="./mask_token_embedding.pth",
            generator=torch.manual_seed(int(seed)),
        )
        return output.images
    except Exception as e:
        # Best-effort handler: log the failure and leave the gallery empty.
        print(f"Error: {str(e)}")
        return None
# Create Gradio interface with Soft theme | |
# Three-tab demo: image captioning, visual question answering, and
# text-to-image generation. Each tab owns its own seed / steps / guidance
# sliders; the seed sliders have distinct names so that every click handler
# reads the slider from its own tab (a shared `seed` name would be rebound by
# each tab and leave only the last slider wired up).
# NOTE(review): the source indentation was lost; the container nesting below
# is reconstructed -- confirm the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), title="Muddit Unifined Model") as demo:
    gr.Markdown("# π Muddit: Liberating Generation Beyond Text-to-Image with a Unified Discrete Diffusion Model.")
    gr.Markdown(" Muddit is a unified discrete diffusion transformer that enables fast and parallel generation across both text and image modalities.")

    with gr.Tab("Image to Text"):
        with gr.Row():
            with gr.Column():
                i2t_image_input = gr.Image(label="Upload Image", type="pil")
                i2t_prompt_input = gr.Textbox(label="Prompt", value="Please describe this image.", placeholder="Enter your prompt here...")
                with gr.Accordion("Advanced Settings", open=False):
                    i2t_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, step=1, value=42)
                    i2t_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
                    i2t_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
                i2t_submit_btn = gr.Button("Generate Description", variant="primary")
            with gr.Column():
                i2t_output_text = gr.Textbox(label="Generated Description", interactive=False)
        i2t_examples = gr.Examples(
            examples=[
                ["assets/man.jpg"],
                ["assets/tennis.jpg"],
                ["assets/pizza2.jpg"],
                ["assets/plane.jpg"],
                ["assets/zebra.jpg"],
                ["assets/building.jpg"],
                ["assets/flower.jpg"],
            ],
            inputs=[i2t_image_input],
            label="Example Inputs"
        )

    with gr.Tab("VQA"):
        with gr.Row():
            with gr.Column():
                vqa_image_input = gr.Image(label="Upload Image", type="pil")
                vqa_prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your question here...")
                with gr.Accordion("Advanced Settings", open=False):
                    vqa_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, step=1, value=42)
                    vqa_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
                    vqa_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
                vqa_submit_btn = gr.Button("Generate Answer", variant="primary")
            with gr.Column():
                vqa_output_text = gr.Textbox(label="Generated Answer", interactive=False)
        vqa_examples = gr.Examples(
            examples=[
                ["assets/kid.jpg", "What color is the kid's hair?"],
                ["assets/street.jpg", "Can someone legally walk across the street right now?"],
                ["assets/dog.jpg", "Where is the dog laying?"],
                ["assets/dog2.jpg", "What color is the toy the dog is holding?"],
                ["assets/pizza.jpg", "What food item is shown?"],
                ["assets/sheep.jpg", "How many sheep are pictured?"],
                ["assets/car.jpg", "Where are the cars?"],
            ],
            inputs=[vqa_image_input, vqa_prompt_input],
            label="Example Inputs"
        )

    with gr.Tab("Text to Image"):
        with gr.Row():
            with gr.Column():
                t2i_prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate...")
                t2i_negative_prompt = gr.Textbox(label="Negative Prompt",
                                                 value="worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark",
                                                 placeholder="What you don't want in the image...",
                                                 lines=5)
                t2i_num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, value=1, step=1)
                with gr.Accordion("Advanced Settings", open=False):
                    t2i_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, step=1, value=42)
                    t2i_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
                    t2i_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
                t2i_submit_btn = gr.Button("Generate Images", variant="primary")
            with gr.Column():
                t2i_gallery = gr.Gallery(label="Generated Images")
        t2i_examples = gr.Examples(
            examples=[
                ["A line art portrait showcasing a human figure with flowing, textured strokes"],
                ["A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head, standing amidst lush, bioluminescent foliage in a vibrant futuristic forest"],
                ["A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear and floral accents, his Mandalorian armor gleaming under the backlighting"],
                ["A translucent, minimalist Porsche 911 GT3RS built from sleek carbon fiber, its aerodynamic body designed in the spirit of '60s Braun and modern Apple minimalism"],
                ["A realistic photograph of a ramadan tent shaped like a crescent moon under a velvety back sky studded with the milky way"],
                ["A portrait of John Lennon, captured in the gritty detail of line art"],
                ["In a world plunged into an unending darkness, remnants of fading starlight pierce through a heavy, smog-filled sky"]
            ],
            inputs=[t2i_prompt_input],
            label="Example Prompts"
        )

    # Event handlers: each button passes the controls of its own tab.
    i2t_submit_btn.click(
        fn=image_to_text,
        inputs=[i2t_image_input, i2t_prompt_input, i2t_seed, i2t_steps, i2t_cfg],
        outputs=i2t_output_text
    )
    vqa_submit_btn.click(
        fn=image_to_text,
        inputs=[vqa_image_input, vqa_prompt_input, vqa_seed, vqa_steps, vqa_cfg],
        outputs=vqa_output_text
    )
    t2i_submit_btn.click(
        fn=text_to_image,
        inputs=[t2i_prompt_input, t2i_negative_prompt, t2i_num_images, t2i_seed, t2i_steps, t2i_cfg],
        outputs=t2i_gallery
    )

demo.launch()