# CatHeiHei_v1 / app.py
# Next7years
# add default negative
# c880a7e
import base64
import functools
import io
import os
import re
import threading

# import requests
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from share_btn import community_icon_html, loading_icon_html, share_js
model_id = "Next7years/stable-diffusion-v1-5-CatHeiHei-v1"
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print("info: running device type: " + device.type )
#word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
#word_list = word_list_dataset["train"]['text']
default_negative_prompt="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face, blurry, draft, grainy"
example_dir = "prompt_examples"
is_gpu_busy = False
# Disabled legacy implementation of `infer`, kept for reference only. The
# triple-quoted string below is an expression statement whose value is
# discarded, so none of this code runs. It POSTed the prompt to the URL in the
# JAX_BACKEND_URL environment variable and needs `requests`, whose import is
# commented out at the top of the file.
'''
def infer(prompt):
global is_gpu_busy
samples = 4
steps = 50
scale = 7.5
#for filter in word_list:
# if re.search(rf"\b{filter}\b", prompt):
# raise gr.Error("Unsafe content found. Please try again with different prompts.")
images = []
url = os.getenv('JAX_BACKEND_URL')
payload = {'prompt': prompt}
images_request = requests.post(url, json = payload)
for image in images_request.json()["images"]:
image_b64 = (f"data:image/jpeg;base64,{image}")
images.append(image_b64)
return images
'''
# Guards both the lazy model load and every generation call. The queue set up
# at the end of this file (concurrency_count=40) lets many workers call `infer`
# at once, and they all share the single cached pipeline object.
_PIPE_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Load the pipeline for ``model_id`` once and move it to ``device``.

    float16 weights are requested on CUDA/MPS; on CPU the library default
    dtype is kept. The result is cached, so the weights are read on the first
    request only instead of on every call.
    """
    if device.type in ("cuda", "mps"):
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(model_id)
    return pipe.to(device)


def _to_data_uri(image):
    """Return a PIL *image* JPEG-encoded as a ``data:image/jpeg;base64,...`` string."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    payload = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{payload}"


def infer(prompt, samples=4, steps=50, scale=7.5):
    """Generate images for *prompt* with the CatHeiHei Stable Diffusion model.

    ``default_negative_prompt`` is always sent as the negative prompt.

    Args:
        prompt: Text prompt (the UI passes the textbox value).
        samples: Images to generate; passed as ``num_images_per_prompt``.
        steps: Denoising steps; passed as ``num_inference_steps``.
        scale: Passed as ``guidance_scale``.

    Returns:
        A list of ``data:image/jpeg;base64,...`` strings, one per image.
    """
    # Loading and generation both happen under the lock. Concurrent first
    # requests therefore cannot load the model twice, and the shared pipeline
    # is never driven by two threads at once.
    with _PIPE_LOCK:
        pipe = _load_pipeline()
        results = pipe(
            prompt,
            negative_prompt=default_negative_prompt,
            num_images_per_prompt=samples,
            num_inference_steps=steps,
            guidance_scale=scale,
        ).images
    # Encoding happens outside the lock; it does not touch the pipeline.
    return [_to_data_uri(image) for image in results]
# Stylesheet handed to gr.Blocks. In the .gr-button:focus rule, --tw-ring-shadow
# must use `calc(3px + var(--tw-ring-offset-width))`. calc() requires an
# operator between its terms; without one, the box-shadow that substitutes
# --tw-ring-shadow becomes invalid and the focus ring is not drawn.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
color: white;
border-color: black;
background: black;
}
input[type='range'] {
accent-color: black;
}
.dark input[type='range'] {
accent-color: #dfdfdf;
}
.container {
max-width: 730px;
margin: auto;
padding-top: 1.5rem;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;
}
.details:hover {
text-decoration: underline;
}
.gr-button {
white-space: nowrap;
}
.gr-button:focus {
border-color: rgb(147 197 253 / var(--tw-border-opacity));
outline: none;
box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
--tw-border-opacity: 1;
--tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
--tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color);
--tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
--tw-ring-opacity: .5;
}
#advanced-btn {
font-size: .7rem !important;
line-height: 19px;
margin-top: 12px;
margin-bottom: 12px;
padding: 2px 8px;
border-radius: 14px !important;
}
#advanced-options {
display: none;
margin-bottom: 20px;
}
.footer {
margin-bottom: 45px;
margin-top: 35px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
.acknowledgments h4{
margin: 1.25em 0 .25em 0;
font-weight: bold;
font-size: 115%;
}
#container-advanced-btns{
display: flex;
flex-wrap: wrap;
justify-content: space-between;
align-items: center;
}
.animate-spin {
animation: spin 1s linear infinite;
}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
.gr-form{
flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
}
#prompt-container{
gap: 0;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
"""
# Top-level Blocks container for the whole page, styled with the `css` string above.
block = gr.Blocks(css=css)
def read_files_from_directory(directory):
    """Collect the text of every ``.txt`` file in *directory* as example rows.

    Each matching file becomes a single-element row ``[content]``. Behavior
    details:

    * Files are visited in sorted filename order, so the examples keep a
      stable order across runs (``os.listdir`` order is arbitrary).
    * Contents are decoded as UTF-8 regardless of the platform locale.
    * Entries whose name ends in ``.txt`` but which are not regular files
      (e.g. directories) are skipped.

    Args:
        directory: Directory to scan (non-recursive).

    Returns:
        A list of ``[content]`` lists, one per ``.txt`` file.
    """
    file_contents = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".txt"):
            continue
        file_path = os.path.join(directory, filename)
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as f:
            file_contents.append([f.read()])
    return file_contents
# Example prompts: one [content] row per .txt file found in `example_dir`.
examples = read_files_from_directory(example_dir)

# Static descriptors for a positive and a negative example input.
# NOTE(review): `metadata` is not read anywhere else in this file as seen.
metadata = [
    dict(
        title="Positive Example",
        description="A positive example input.",
        thumbnail="https://example.com/images/positive.jpg",
        label="Positive",
    ),
    dict(
        title="Negative Example",
        description="A negative example input.",
        thumbnail="https://example.com/images/negative.jpg",
        label="Negative",
    ),
]
# Build the page inside the Blocks container, wire up its event handlers, then launch.
with block:
    # Header: title, project description and Instagram link.
    gr.HTML(
        """
<div style="text-align: center; max-width: 650px; margin: 0 auto; padding-top: 7px;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 12px;">
Welcome to CatHeiHei v1 Model
</h1>
</div>
<p style="margin-bottom: 10px">
We're excited to open-source this unique AI model, designed specifically to generate images of the world-famous Cat HeiHei.
Our goal is to foster creativity and collaboration within the community, and we can't wait to see the amazing artwork you'll create!
<p>Follow us on Instagram:
<a href="https://www.instagram.com/cat_heihei/" style="display: inline;">
<img src="https://www.instagram.com/static/images/ico/favicon-192.png/68d99ba29cc8.png" alt="Instagram Logo" style="display: inline; width: 20px; height: 20px; margin-right: 5px;">
@cat_heihei
</a>
</p>
</p>
</div>
"""
    )
    with gr.Group():
        # Prompt textbox and "Generate image" button side by side in one row.
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )
        # Output grid (2 columns) that receives the list returned by `infer`.
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        # "Advanced options" toggle plus the community share button.
        with gr.Group(elem_id="container-advanced-btns"):
            advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")
        # Advanced settings panel; the stylesheet hides #advanced-options by default.
        # NOTE: these sliders are not inputs to any event below (`infer` only
        # receives `text`), matching the "temporarily unavailable" message.
        with gr.Row(elem_id="advanced-options"):
            gr.Markdown("Advanced settings are temporarily unavailable")
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )
        # Clickable example prompts read from the .txt files in `example_dir`;
        # cache_examples=False, so no images are pre-generated.
        ex = gr.Examples(examples=examples, label="Example prompt to generate CatHeiHei", fn=infer, inputs=text, outputs=[gallery], cache_examples=False, postprocess=False)
        # Blank out the examples table's column header.
        ex.dataset.headers = [""]
        # Pressing Enter in the textbox or clicking the button runs `infer`.
        # postprocess=False skips Gradio's output postprocessing; `infer`
        # already returns data-URI strings for the gallery.
        text.submit(infer, inputs=text, outputs=[gallery], postprocess=False)
        btn.click(infer, inputs=text, outputs=[gallery], postprocess=False)
        # Client-side toggle only (fn is None): the JS flips #advanced-options
        # between "flex" and "none".
        # NOTE(review): `text` is listed as the output although the JS returns
        # nothing — confirm this never clears the prompt.
        advanced_button.click(
            None,
            [],
            text,
            _js="""
() => {
const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
}""",
        )
        # Client-side only: runs `share_js` imported from the share_btn module.
        share_button.click(
            None,
            [],
            [],
            _js=share_js,
        )
        # Footer with thanks and a feedback link.
        gr.HTML(
            """
<div class="footer" style="padding-top: 20px;">
<p>Thanks for using the <b>Cat HeiHei </b> Customized Model! We appreciate your support and creativity. Share your feedback, suggestions, and content ideas by messaging us on Instagram
<a href="https://www.instagram.com/cat_heihei/" style="display: inline;">
<img src="https://www.instagram.com/static/images/ico/favicon-192.png/68d99ba29cc8.png" alt="Instagram Logo" style="display: inline; width: 20px; height: 20px; margin-right: 5px;">
@cat_heihei
</a>.
Let's make the Cat HeiHei community a fun, creative space for all! Happy creating! 🐾💖
</p>
</div>
"""
        )
# Enable Gradio's request queue and start the server.
# NOTE(review): concurrency_count=40 lets many `infer` calls be in flight at
# once; confirm the GPU memory budget allows that.
block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)