Spaces:
Running
Running
import torch | |
import gradio as gr | |
from diffusers import DiffusionPipeline | |
from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler | |
from huggingface_hub import hf_hub_download | |
# ---------------------------------------------------------------------------
# Model setup -- executed once, when the module is imported.
# ---------------------------------------------------------------------------

# Base Stable Diffusion XL pipeline: fp16 weight variant, float16 tensors,
# then moved onto the CUDA device.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Apply the LoRA weights from the "fofr/sdxl-emoji" repository.
pipe.load_lora_weights("fofr/sdxl-emoji", weight_name="lora.safetensors")

# The pipeline's two text encoders and their tokenizers, listed in the same
# order so that index i of one list corresponds to index i of the other.
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]

# Download the embeddings file from the same repository and load it through
# the handler that wraps the encoders/tokenizers above.
embedding_path = hf_hub_download(
    repo_id="fofr/sdxl-emoji",
    filename="embeddings.pti",
    repo_type="model",
)
embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
embhandler.load_embeddings(embedding_path)
# Gradio generation function
def generate_emoji(prompt, scale):
    """Generate one emoji image for a short subject description.

    Args:
        prompt: Subject of the emoji as typed by the user (e.g. "man").
            Leading and trailing whitespace is ignored.
        scale: Value forwarded to the pipeline as the ``"scale"`` entry of
            ``cross_attention_kwargs``.

    Returns:
        The first image from the pipeline result's ``images`` list.

    Raises:
        gr.Error: If ``prompt`` is empty or contains only whitespace.
    """
    subject = (prompt or "").strip()
    if not subject:
        # Without a subject the template below would end in a dangling
        # "... emoji of a "; report it in the UI instead of running the
        # pipeline on a meaningless prompt.
        raise gr.Error("Please enter a description of the emoji.")

    # <s0><s1> are the placeholder tokens used by this prompt template
    # (presumably the ones provided by the loaded embeddings -- confirm).
    full_prompt = f"A <s0><s1> emoji of a {subject}"

    # Generate the image using the diffusion pipeline.
    result = pipe(full_prompt, cross_attention_kwargs={"scale": scale})

    # Only the first generated image is returned.
    return result.images[0]
# Gradio Interface: one text box for the emoji subject and one slider for the
# scale value, both passed positionally to generate_emoji; the result is
# rendered as an image.
description_input = gr.Textbox(
    label="Description of the emoji (e.g., 'man', 'woman')",
    placeholder="Type here...",
)
scale_input = gr.Slider(
    minimum=0.0,
    maximum=1.0,
    step=0.1,
    value=0.8,
    label="Cross Attention Scale",
)

interface = gr.Interface(
    title="Emoji Generator",
    description="Generate custom emojis using Stable Diffusion XL with LoRA weights.",
    fn=generate_emoji,
    inputs=[description_input, scale_input],
    outputs="image",
)

# Launch the interface only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()