import base64
import io
import os

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image  # noqa: F401 -- kept: may be relied on by other modules

from share_btn import community_icon_html, loading_icon_html, share_js

# Hugging Face model repo for the fine-tuned Cat HeiHei checkpoint.
model_id = "Next7years/stable-diffusion-v1-5-CatHeiHei-v1"

# Prefer CUDA, then Apple-Silicon MPS, then plain CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print("info: running device type: " + device.type)

# Quality-guard negative prompt applied to every generation.
default_negative_prompt = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face, blurry, draft, grainy"

example_dir = "prompt_examples"

is_gpu_busy = False  # legacy flag, unused by the current code path; kept for compatibility

# Lazily-built pipeline singleton. Loading the checkpoint is very expensive,
# so it must NOT be repeated on every infer() call (the previous version
# rebuilt the whole pipeline per request).
_pipeline = None


def _get_pipeline():
    """Build the StableDiffusionPipeline once and cache it for later calls."""
    global _pipeline
    if _pipeline is None:
        if device.type in ("cuda", "mps"):
            # Half precision halves memory use on accelerators.
            _pipeline = StableDiffusionPipeline.from_pretrained(
                model_id, torch_dtype=torch.float16
            )
        else:
            _pipeline = StableDiffusionPipeline.from_pretrained(model_id)
        _pipeline = _pipeline.to(device)
    return _pipeline


def infer(prompt):
    """Generate images for *prompt* and return them as JPEG data-URI strings.

    Returns a list of ``data:image/jpeg;base64,...`` strings, the form the
    gallery expects because the event handlers use ``postprocess=False``.
    """
    samples = 4   # images per prompt
    steps = 50    # denoising steps
    scale = 7.5   # classifier-free guidance scale

    pipe = _get_pipeline()
    results = pipe(
        prompt,
        negative_prompt=default_negative_prompt,
        num_images_per_prompt=samples,
        num_inference_steps=steps,
        guidance_scale=scale,
    ).images

    images = []
    for image in results:
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        images.append(f"data:image/jpeg;base64,{encoded}")
    return images


css = """
    .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }
    .gr-button { color: white; border-color: black; background: black; }
    input[type='range'] { accent-color: black; }
    .dark input[type='range'] { accent-color: #dfdfdf; }
    .container { max-width: 730px; margin: auto; padding-top: 1.5rem; }
    #gallery {
        min-height: 22rem; margin-bottom: 15px; margin-left: auto; margin-right: auto;
        border-bottom-right-radius: .5rem !important; border-bottom-left-radius: .5rem !important;
    }
    #gallery>div>.h-full { min-height: 20rem; }
    .details:hover { text-decoration: underline; }
    .gr-button { white-space: nowrap; }
    .gr-button:focus {
        border-color: rgb(147 197 253 / var(--tw-border-opacity));
        outline: none;
        box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
        --tw-border-opacity: 1;
        --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
        --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
        --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
        --tw-ring-opacity: .5;
    }
    #advanced-btn {
        font-size: .7rem !important; line-height: 19px;
        margin-top: 12px; margin-bottom: 12px;
        padding: 2px 8px; border-radius: 14px !important;
    }
    #advanced-options { display: none; margin-bottom: 20px; }
    .footer { margin-bottom: 45px; margin-top: 35px; text-align: center; border-bottom: 1px solid #e5e5e5; }
    .footer>p { font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white; }
    .dark .footer { border-color: #303030; }
    .dark .footer>p { background: #0b0f19; }
    .acknowledgments h4 { margin: 1.25em 0 .25em 0; font-weight: bold; font-size: 115%; }
    #container-advanced-btns { display: flex; flex-wrap: wrap; justify-content: space-between; align-items: center; }
    .animate-spin { animation: spin 1s linear infinite; }
    @keyframes spin {
        from { transform: rotate(0deg); }
        to { transform: rotate(360deg); }
    }
    #share-btn-container {
        display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important;
        background-color: #000000; justify-content: center; align-items: center;
        border-radius: 9999px !important; width: 13rem;
    }
    #share-btn {
        all: initial; color: #ffffff; font-weight: 600; cursor: pointer;
        font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important;
        padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
    }
    #share-btn * { all: unset; }
    .gr-form { flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0; }
    #prompt-container { gap: 0; }
    #share-btn-container div:nth-child(-n+2) { width: auto !important; min-height: 0px !important; }
"""

block = gr.Blocks(css=css)


def read_files_from_directory(directory):
    """Collect the contents of every ``.txt`` file in *directory*.

    Returns a list of one-element lists (``[[text], ...]``) — the row shape
    gr.Examples expects for a single input component. Filenames are sorted so
    the example order is deterministic across platforms, and files are read
    as UTF-8 rather than the platform default encoding.
    """
    file_contents = []
    for filename in sorted(os.listdir(directory)):
        if filename.endswith(".txt"):
            file_path = os.path.join(directory, filename)
            with open(file_path, "r", encoding="utf-8") as f:
                file_contents.append([f.read()])
    return file_contents


examples = read_files_from_directory(example_dir)

# NOTE(review): not referenced by the UI below — kept for compatibility in
# case another module imports it.
metadata = [
    {
        "title": "Positive Example",
        "description": "A positive example input.",
        "thumbnail": "https://example.com/images/positive.jpg",
        "label": "Positive",
    },
    {
        "title": "Negative Example",
        "description": "A negative example input.",
        "thumbnail": "https://example.com/images/negative.jpg",
        "label": "Negative",
    },
]

with block:
    gr.HTML(
        """

Welcome to CatHeiHei v1 Model

We're excited to open-source this unique AI model, designed specifically to generate images of the world-famous Cat HeiHei. Our goal is to foster creativity and collaboration within the community, and we can't wait to see the amazing artwork you'll create!

Follow us on Instagram: Instagram Logo @cat_heihei

"""
    )
    with gr.Group():
        with gr.Box():
            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                    elem_id="prompt-text-input",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                    full_width=False,
                )

        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        with gr.Group(elem_id="container-advanced-btns"):
            advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
            with gr.Group(elem_id="share-btn-container"):
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")

        # These sliders are displayed but intentionally not wired into infer()
        # yet — the panel is toggled by the #advanced-btn JS below.
        with gr.Row(elem_id="advanced-options"):
            gr.Markdown("Advanced settings are temporarily unavailable")
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

        ex = gr.Examples(
            examples=examples,
            label="Example prompt to generate CatHeiHei",
            fn=infer,
            inputs=text,
            outputs=[gallery],
            cache_examples=False,
            postprocess=False,
        )
        ex.dataset.headers = [""]

        # Both Enter-in-textbox and the button trigger the same generation.
        text.submit(infer, inputs=text, outputs=[gallery], postprocess=False)
        btn.click(infer, inputs=text, outputs=[gallery], postprocess=False)

        # Client-side toggle of the advanced-options panel (no server round trip).
        advanced_button.click(
            None,
            [],
            text,
            _js="""
            () => {
                const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
                options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
            }""",
        )
        share_button.click(
            None,
            [],
            [],
            _js=share_js,
        )
        gr.HTML(
            """ """
        )

block.queue(concurrency_count=40, max_size=20).launch(max_threads=150)