ehristoforu's picture
Update app.py
22deffa
raw
history blame
3.24 kB
import gradio as gr
import requests
import io
import re
import random
import os
from PIL import Image
from datasets import load_dataset
from huggingface_hub import login
# Read the Hub token once: it is used to log in (so the use_auth_token=True
# dataset download below is authenticated) and as the Inference API bearer token.
API_TOKEN = os.getenv("HF_READ_TOKEN")  # it is free
login(token=API_TOKEN)
API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Prompt blocklist, read from the "text" column of the dataset's train split
# (presumably one term per line of en.txt).
word_list_dataset = load_dataset("openskyml/bad-words-prompt-list", data_files="en.txt", use_auth_token=True)
# Strip stray whitespace and drop blank lines: a blank entry would turn the
# word-boundary regex built in query() into one that matches nearly any prompt.
word_list = [term.strip() for term in word_list_dataset["train"]["text"] if term and term.strip()]
def contains_banned_word(prompt, words):
    """Return True if *prompt* contains any entry of *words* as a standalone term.

    Each entry is stripped and regex-escaped, so terms containing regex
    metacharacters (e.g. "c++") are matched literally instead of raising
    re.error. Matching ignores case. A term counts as standalone when it is
    not directly preceded or followed by a word character; for purely
    alphanumeric terms this is the usual word-boundary match, and it also
    works for terms that begin or end with punctuation. Blank or None
    entries are skipped, because an empty pattern would match almost anything.
    """
    for word in words:
        term = (word or "").strip()
        if not term:
            continue
        if re.search(rf"(?<!\w){re.escape(term)}(?!\w)", prompt, flags=re.IGNORECASE):
            return True
    return False


def query(prompt, is_negative=False, steps=7, cfg_scale=7, seed=None, num_images=4, timeout=120):
    """Generate ``num_images`` images for ``prompt`` through the Inference API.

    Args:
        prompt: Text prompt; ", 8k" is appended before it is sent.
        is_negative: Sent as the "is_negative" payload field. The UI passes the
            negative-prompt textbox value here (a string), not a boolean.
        steps: Sent as the "steps" payload field.
        cfg_scale: Sent as the "cfg_scale" payload field.
        seed: Seed sent with every request. When None, each image gets its own
            random seed; when set, every request uses the same seed.
        num_images: Number of sequential requests, i.e. images returned.
        timeout: Per-request timeout in seconds passed to ``requests.post``.

    Returns:
        List of PIL images in request order.

    Raises:
        gr.Error: If the prompt contains a term from ``word_list``, if a
            request fails or times out, or if the API answers with an error
            status.
    """
    if contains_banned_word(prompt, word_list):
        raise gr.Error("Unsafe content found. Please try again with different prompts.")

    images = []
    for _ in range(num_images):
        # NOTE(review): the HF Inference API documents generation options under a
        # "parameters" object; confirm this endpoint honors these top-level keys.
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        except requests.RequestException as err:
            raise gr.Error(f"Image generation request failed: {err}") from err
        if not response.ok:
            # An error status (4xx/5xx) carries an error body, not image bytes;
            # report it instead of letting PIL fail to decode it.
            raise gr.Error(f"Image generation failed (HTTP {response.status_code}): {response.text[:200]}")
        images.append(Image.open(io.BytesIO(response.content)))
    return images
# Stylesheet handed to gr.Blocks: sets the page font and sizes/rounds the
# element with elem_id "gallery" that displays the generated images.
# The leading and trailing "" entries give the string its surrounding newlines.
css = "\n".join(
    (
        "",
        ".gradio-container {",
        "font-family: 'IBM Plex Sans', sans-serif;",
        "}",
        "#gallery {",
        "min-height: 22rem;",
        "margin-bottom: 15px;",
        "margin-left: auto;",
        "margin-right: auto;",
        "border-bottom-right-radius: .5rem !important;",
        "border-bottom-left-radius: .5rem !important;",
        "}",
        "#gallery>div>.h-full {",
        "min-height: 20rem;",
        "}",
        "",
    )
)
# Build the UI: a title banner, a two-column image gallery, and the prompt
# inputs with a Generate button wired to query().
with gr.Blocks(css=css) as demo:
    # Centered page title.
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
Open Diffusion 1.0 Demo
</h1>
</div>
</div>
"""
    )
    with gr.Group():
        with gr.Box():
            # Output gallery, shown with two images per row.
            with gr.Row():
                with gr.Column():
                    gallery_output = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
            # Prompt and negative prompt stacked in one column, Generate button beside them.
            with gr.Row(elem_id="prompt-container"):
                with gr.Column():
                    text_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=1, elem_id="prompt-text-input")
                    negative_prompt = gr.Textbox(show_label=False, placeholder="Enter a negative prompt", max_lines=1, elem_id="negative-prompt-text-input")
                text_button = gr.Button("Generate").style(margin=False, rounded=(False, True, True, False), full_width=False)
    # query() receives the two textbox values positionally: the prompt, then the
    # negative-prompt text (which lands in its `is_negative` parameter). All other
    # query() parameters keep their defaults.
    text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)

# Start the server with the API documentation page hidden.
demo.launch(show_api=False)