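# Page-header residue from the web file viewer this source was copied from:
# author avatar "ehristoforu", last commit "Update app.py" (1b5b3a3),
# viewer links "raw" / "history" / "blame", file size 6.71 kB.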
import gradio as gr
import requests
import io
import re
import random
import os
from PIL import Image
from datasets import load_dataset
from huggingface_hub import login
# Hosted inference endpoint that `query` posts to.
API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"

# Read the Hugging Face read token from the environment (it is free to create)
# and reuse the same value for the Hub login and the request headers.
API_TOKEN = os.getenv("HF_READ_TOKEN")
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Log in before loading the dataset: the load below asks for the stored
# credentials via `use_auth_token=True`.
# NOTE(review): `use_auth_token` is reported as deprecated in newer `datasets`
# releases in favour of `token` -- confirm against the pinned version.
login(token=API_TOKEN)
word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list",
    data_files="en.txt",
    use_auth_token=True,
)

# The "text" column of the "train" split, iterated by `query` as the blocked words.
word_list = word_list_dataset["train"]['text']
def query(prompt, is_negative=False, steps=7, cfg_scale=7, seed=None, num_images=4):
    """Generate images for *prompt* through the hosted inference endpoint.

    The prompt is first screened against the module-level ``word_list``. If it
    passes, one POST request per image is sent to ``API_URL`` and each
    response body is decoded as an image.

    Args:
        prompt: Text prompt; ", 8k" is appended before it is sent.
        is_negative: Forwarded unchanged as the ``is_negative`` payload field.
            The UI passes the negative-prompt textbox value here.
        steps: Forwarded as the ``steps`` payload field.
        cfg_scale: Forwarded as the ``cfg_scale`` payload field.
        seed: Seed sent with every request. When ``None``, a new random seed
            is drawn for each image.
        num_images: Number of requests to issue, i.e. images to return.

    Returns:
        A list with one ``PIL.Image.Image`` per request.

    Raises:
        gr.Error: If the prompt contains a blocked word, a request fails or
            returns an HTTP error status, or a response body is not an image.
    """
    for blocked_word in word_list:
        # Skip blank entries: an empty word would produce r"\b\b", which
        # matches at any word boundary and would reject every prompt.
        word = (blocked_word or "").strip()
        if not word:
            continue
        # Escape the word so regex metacharacters in the list are matched
        # literally, and ignore case so "WORD" cannot bypass the filter.
        if re.search(rf"\b{re.escape(word)}\b", prompt, flags=re.IGNORECASE):
            raise gr.Error("Unsafe content found. Please try again with different prompts.")

    images = []
    for _ in range(num_images):
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }
        try:
            # Bound the wait so a stalled connection cannot hang the request
            # forever, and surface HTTP errors instead of trying to decode
            # an error body as an image.
            response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
        except requests.RequestException as err:
            raise gr.Error(f"Image generation request failed: {err}") from err

        try:
            image = Image.open(io.BytesIO(response.content))
        except OSError as err:
            # Pillow raises an OSError subclass when the bytes are not a
            # recognisable image.
            raise gr.Error("The inference API did not return a valid image.") from err
        images.append(image)
    return images
# Custom stylesheet handed to `gr.Blocks(css=css)`.
# The `--tw-ring-shadow` declaration uses `calc(3px + var(...))`: without the
# `+` operator the calc() expression is invalid CSS, which would drop the
# focus ring shadow on `.gr-button:focus`.
# NOTE(review): `#component-16` targets what looks like an auto-generated
# element id -- confirm it still matches the intended component.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
.container {
max-width: 730px;
margin: auto;
padding-top: 1.5rem;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
.details:hover {
text-decoration: underline;
}
.gr-button {
white-space: nowrap;
}
.gr-button:focus {
border-color: rgb(147 197 253 / var(--tw-border-opacity));
outline: none;
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
--tw-border-opacity: 1;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
--tw-ring-opacity: .5;
}
#advanced-btn {
font-size: .7rem !important;
line-height: 19px;
margin-top: 12px;
margin-bottom: 12px;
padding: 2px 8px;
border-radius: 14px !important;
}
#advanced-options {
display: none;
margin-bottom: 20px;
}
.footer {
margin-bottom: 45px;
margin-top: 35px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
.acknowledgments h4{
margin: 1.25em 0 .25em 0;
font-weight: bold;
font-size: 115%;
}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
margin-top: 10px;
margin-left: auto;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
.gr-form{
flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
}
#prompt-container{
gap: 0;
}
#prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
#component-16{border-top-width: 1px!important;margin-top: 1em}
.image_duplication{position: absolute; width: 100px; left: 50px}
"""
# Build the demo UI: a title banner, a gallery for the generated images, a
# prompt textbox, a negative-prompt textbox and a "Generate" button.
# NOTE(review): leading indentation was lost in this copy of the file, so the
# nesting of the Group/Row/Column contexts below cannot be determined from
# here -- confirm against the original before re-indenting.
with gr.Blocks(css=css) as demo:
gr.HTML(
"""
<div style="text-align: center; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
Open Diffusion 1.0 Demo
</h1>
</div>
</div>
"""
)
with gr.Group():
with gr.Row(elem_id="prompt-container"):
with gr.Column():
# NOTE(review): `.style(...)` is reported as deprecated in later Gradio
# releases -- confirm against the pinned Gradio version.
gallery_output = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[4], height="auto")
with gr.Row():
with gr.Column():
text_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=1)
negative_prompt = gr.Textbox(show_label=False, placeholder="Enter a negative", max_lines=1)
text_button = gr.Button("Generate")
# Clicking the button calls `query` with the two textbox values in order, so
# the negative-prompt text arrives as `query`'s second parameter
# (`is_negative`); the returned list of images fills the gallery.
text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)
# Start the app, passing show_api=False to launch().
demo.launch(show_api=False)