# Spaces: Running on Zero
import base64
import html
import json
import os
import time
from datetime import datetime

import gradio as gr
import numpy as np
import pandas as pd
import torch
from compel import Compel, ReturnedEmbeddingsType
from controlnet_aux import AnylineDetector
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, DDIMScheduler
from gradio_imageslider import ImageSlider
from PIL import Image
# Configuration
# Environment flags (presumably provided by the Spaces runtime when hosted).
IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = "SPACE_ID" in os.environ
LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"

# Pick the accelerator if one is present; precision is half regardless.
# NOTE(review): fp16 on CPU is slow/unsupported for some ops — confirm CPU use is intended.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
dtype = torch.float16

print(f"device: {device}")
print(f"dtype: {dtype}")
print(f"low memory: {LOW_MEMORY}")
# Model initialization
# SDXL base checkpoint id, used for both the scheduler and the pipeline weights.
model = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")

# MistoLine line-art ControlNet for SDXL, fp16 weights from the "refs/pr/3" revision.
# Loaded with the same `dtype` as the pipeline so the two cannot drift apart.
controlnet = ControlNetModel.from_pretrained(
    "TheMistoAI/MistoLine",
    torch_dtype=dtype,
    revision="refs/pr/3",
    variant="fp16",
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    model,
    controlnet=controlnet,
    torch_dtype=dtype,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)

# Compel encodes weighted prompts for both SDXL text encoders; only the second
# encoder contributes pooled embeddings (requires_pooled=[False, True]).
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
# Moving the pipeline moves the same encoder modules Compel holds references to.
pipe = pipe.to(device)

# Anyline edge detector (MTEED weights) that produces the ControlNet conditioning image.
anyline = AnylineDetector.from_pretrained(
    "TheMistoAI/MistoLine", filename="MTEED.pth", subfolder="Anyline"
).to(device)
# Global variables for metadata and likes cache
# In-memory table of images generated by this process, one row per saved image.
image_metadata = pd.DataFrame(
    columns=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"]
)
# JSON file holding per-image vote counts across restarts.
LIKES_CACHE_FILE = "likes_cache.json"
def load_likes_cache():
    """Load per-image vote counts from ``LIKES_CACHE_FILE``.

    Returns:
        The decoded JSON mapping ``{filename: {"likes": n, "dislikes": n, "hearts": n}}``,
        or an empty dict when the file is missing, unreadable, not valid JSON, or
        does not contain a JSON object.
    """
    try:
        with open(LIKES_CACHE_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as err:
        # ValueError covers JSONDecodeError and UnicodeDecodeError. This runs at
        # import time, so a damaged cache must not prevent the app from starting.
        print(f"Warning: could not read {LIKES_CACHE_FILE} ({err}); starting with an empty cache.")
        return {}
    if not isinstance(data, dict):
        print(f"Warning: {LIKES_CACHE_FILE} does not contain a JSON object; ignoring it.")
        return {}
    return data
def save_likes_cache(cache):
    """Persist *cache* as JSON to ``LIKES_CACHE_FILE``.

    The JSON is written to a sibling ``.tmp`` file first and then swapped into
    place with ``os.replace``. If serialisation fails part-way (e.g. a value is
    not JSON-serialisable), the previous cache file is left intact, the
    temporary file is removed, and the exception propagates.
    """
    tmp_path = f"{LIKES_CACHE_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(cache, f)
        os.replace(tmp_path, LIKES_CACHE_FILE)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
# Vote counts keyed by saved-image filename, restored from disk at startup.
likes_cache: dict = load_likes_cache()
def pad_image(image):
    """Letterbox *image* onto a black square canvas.

    A square input is returned unchanged (the same object). Otherwise a new
    image in the same mode, of side ``max(width, height)`` and filled with
    (0, 0, 0), is created, and the original is pasted centred along its
    shorter axis.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    canvas = Image.new(image.mode, (side, side), (0, 0, 0))
    # Exactly one of these offsets is non-zero: only the shorter axis is padded.
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(image, offset)
    return canvas
def create_download_link(filename):
    """Return an HTML anchor that downloads *filename* from an inline data URI.

    The file's bytes are base64-encoded into a ``data:image/png`` URL. The
    ``download`` attribute carries only the file's base name, HTML-escaped so
    quotes, ``<`` or ``&`` in the name cannot break out of the attribute.

    Raises:
        OSError: if *filename* cannot be opened.
    """
    with open(filename, "rb") as file:
        encoded_string = base64.b64encode(file.read()).decode("utf-8")
    download_name = html.escape(os.path.basename(filename), quote=True)
    return (
        f'<a href="data:image/png;base64,{encoded_string}" '
        f'download="{download_name}">Download Image</a>'
    )
def save_image(image: Image.Image, prompt: str) -> str:
    """Save a generated image in the working directory and register it.

    The filename is ``<YYYYmmdd_HHMMSS>_<sanitised prompt>.png``. The prompt
    keeps only alphanumeric characters, and every whitespace character
    (including newlines and tabs) becomes a plain space. The result is
    truncated to 50 characters. If the name is already taken on disk or in
    ``likes_cache``, ``_1``, ``_2``, ... is appended so an earlier image is
    never overwritten.

    Side effects: writes the PNG, appends a row to ``image_metadata``, adds a
    zeroed entry to ``likes_cache`` and persists it via ``save_likes_cache``.

    Returns:
        The filename the image was saved under.
    """
    global image_metadata, likes_cache
    # A single clock read keeps the filename timestamp and the 'Created' column in agreement.
    created = datetime.now()
    timestamp = created.strftime("%Y%m%d_%H%M%S")
    # Fold all whitespace to ' ' so line breaks / control whitespace never reach the filename.
    safe_prompt = "".join(
        ch if ch.isalnum() else " "
        for ch in prompt
        if ch.isalnum() or ch.isspace()
    )[:50]
    stem = f"{timestamp}_{safe_prompt}"
    filename = f"{stem}.png"
    suffix = 1
    # Same second + same prompt would otherwise clobber the earlier file and duplicate its row.
    while os.path.exists(filename) or filename in likes_cache:
        filename = f"{stem}_{suffix}.png"
        suffix += 1
    image.save(filename)
    new_row = pd.DataFrame(
        {
            "Filename": [filename],
            "Prompt": [prompt],
            "Likes": [0],
            "Dislikes": [0],
            "Hearts": [0],
            "Created": [created],
        }
    )
    # Concatenating onto an empty frame triggers a pandas FutureWarning about dtype inference.
    if image_metadata.empty:
        image_metadata = new_row
    else:
        image_metadata = pd.concat([image_metadata, new_row], ignore_index=True)
    likes_cache[filename] = {"likes": 0, "dislikes": 0, "hearts": 0}
    save_likes_cache(likes_cache)
    return filename
def get_image_gallery():
    """Build the gallery value: ``(path, caption)`` pairs, in table order, for
    every image in ``image_metadata`` whose file still exists on disk."""
    gallery = []
    for path in image_metadata["Filename"].tolist():
        if not os.path.exists(path):
            continue  # file removed outside the app; skip it
        gallery.append((path, get_image_caption(path)))
    return gallery
def get_image_caption(filename):
    """Return the gallery caption for *filename*.

    Files with an entry in ``likes_cache`` get a three-line caption: the
    filename, ``Prompt: <prompt>``, and the like/dislike/heart counts. Any
    other file gets the bare filename.

    If ``image_metadata`` has no row for the file, the prompt is shown as empty
    instead of raising IndexError. That happens, for example, with vote entries
    restored from disk for images not generated in this process.
    """
    global likes_cache, image_metadata
    counts = likes_cache.get(filename)
    if counts is None:
        return filename
    prompts = image_metadata.loc[image_metadata["Filename"] == filename, "Prompt"].values
    prompt = prompts[0] if len(prompts) else ""
    likes = counts.get("likes", 0)
    dislikes = counts.get("dislikes", 0)
    hearts = counts.get("hearts", 0)
    return f"{filename}\nPrompt: {prompt}\n👍 {likes} 👎 {dislikes} ❤️ {hearts}"
def delete_all_images():
    """Delete every tracked image file and reset all gallery state.

    Each file listed in ``image_metadata`` that still exists is removed. The
    table is then replaced by an empty one with the same columns, and
    ``likes_cache`` is cleared and persisted.

    Returns:
        ``(gallery, rows)`` values for the gallery and metadata components.
    """
    global image_metadata, likes_cache
    for path in image_metadata["Filename"]:
        # Re-check per entry: the same filename may appear in more than one row.
        if not os.path.exists(path):
            continue
        os.remove(path)
    empty_columns = ["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"]
    image_metadata = pd.DataFrame(columns=empty_columns)
    likes_cache = {}
    save_likes_cache(likes_cache)
    gallery = get_image_gallery()
    return gallery, image_metadata.values.tolist()
def delete_image(filename):
    """Remove one image: its file (if present), its metadata row(s) and its votes.

    Args:
        filename: the selected image's filename; a falsy value deletes nothing
            from disk.

    Returns:
        ``(gallery, rows)`` values for the gallery and metadata components.
    """
    global image_metadata, likes_cache
    on_disk = bool(filename) and os.path.exists(filename)
    if on_disk:
        os.remove(filename)
    keep_mask = image_metadata["Filename"] != filename
    image_metadata = image_metadata[keep_mask]
    if filename in likes_cache:
        del likes_cache[filename]
        save_likes_cache(likes_cache)
    gallery = get_image_gallery()
    return gallery, image_metadata.values.tolist()
def vote(filename, vote_type):
    """Record one vote for *filename* and return refreshed UI values.

    Args:
        filename: the selected image's filename. Non-string values and files
            without a ``likes_cache`` entry are ignored.
        vote_type: ``"likes"``, ``"dislikes"`` or ``"hearts"`` (case-insensitive).
            Unrecognised types are ignored.

    The count is incremented in ``likes_cache`` (and persisted). It is also
    mirrored into the matching ``Likes`` / ``Dislikes`` / ``Hearts`` column of
    ``image_metadata``, so the metadata table shows current totals instead of
    the zeros written at save time.

    Returns:
        ``(gallery, rows)`` values for the gallery and metadata components.
    """
    global likes_cache, image_metadata
    key = vote_type.lower()
    counts = likes_cache.get(filename) if isinstance(filename, str) else None
    if counts is not None and key in counts:
        counts[key] += 1
        save_likes_cache(likes_cache)
        column = key.capitalize()
        if column in image_metadata.columns:
            # Work on an owned copy: image_metadata may be a filtered slice (see delete_image).
            image_metadata = image_metadata.copy()
            image_metadata.loc[image_metadata["Filename"] == filename, column] = counts[key]
    return get_image_gallery(), image_metadata.values.tolist()
def predict(
    input_image,
    prompt,
    negative_prompt,
    seed,
    guidance_scale=8.5,
    controlnet_conditioning_scale=0.5,
    strength=1.0,
    controlnet_start=0.0,
    controlnet_end=1.0,
    guassian_sigma=2.0,
    intensity_threshold=3,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from *input_image*, guided by its Anyline line art.

    Steps:
        1. Pad to a square and resize to 1024x1024 RGB.
        2. Encode the prompt and negative prompt with Compel.
        3. Extract the Anyline edge map.
        4. Run the SDXL ControlNet img2img pipeline for 30 steps.
        5. Save and register the result.

    Args:
        input_image: uploaded PIL image; ``None`` raises ``gr.Error``.
        prompt, negative_prompt: prompt text (Compel syntax accepted).
        seed: RNG seed; wrapped into ``[0, 2**64)`` before seeding torch.
        guidance_scale, strength: passed straight to the pipeline.
        controlnet_conditioning_scale: ControlNet residual weight.
        controlnet_start, controlnet_end: fractions of the denoising schedule
            between which the ControlNet is applied; must satisfy
            ``0 <= start < end <= 1``.
        guassian_sigma, intensity_threshold: Anyline detector settings.
            ``guassian_sigma`` is the keyword spelling controlnet_aux expects.
        progress: Gradio progress tracker mirroring tqdm bars.

    Returns:
        A tuple for the UI outputs:
            - ``(padded, generated)`` for the image slider;
            - the padded input;
            - the Anyline map;
            - download-link HTML;
            - the refreshed gallery;
            - the metadata rows.

    Raises:
        gr.Error: if no image was uploaded or the ControlNet window is invalid.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    start = float(controlnet_start)
    end = float(controlnet_end)
    # diffusers rejects start >= end (or values outside [0, 1]); fail with a readable UI message instead.
    if not 0.0 <= start < end <= 1.0:
        raise gr.Error("ControlNet Start must be lower than ControlNet End (both between 0 and 1).")

    padded_image = pad_image(input_image).resize((1024, 1024)).convert("RGB")
    conditioning, pooled = compel([prompt, negative_prompt])
    # Values at the slider's top end can land just above 2**64 - 1 after float rounding,
    # which torch.manual_seed rejects; wrapping leaves every in-range seed unchanged.
    generator = torch.manual_seed(int(seed) % 2**64)

    last_time = time.time()
    anyline_image = anyline(
        padded_image,
        detect_resolution=1280,
        guassian_sigma=max(0.01, guassian_sigma),
        intensity_threshold=intensity_threshold,
    )
    images = pipe(
        image=padded_image,
        control_image=anyline_image,
        strength=strength,
        # Row 0 of the Compel batch is the prompt, row 1 the negative prompt.
        prompt_embeds=conditioning[0:1],
        pooled_prompt_embeds=pooled[0:1],
        negative_prompt_embeds=conditioning[1:2],
        negative_pooled_prompt_embeds=pooled[1:2],
        width=1024,
        height=1024,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        # The pipeline's keywords are control_guidance_start/end; the previous
        # controlnet_start/end names were silently swallowed, so the sliders did nothing.
        control_guidance_start=start,
        control_guidance_end=end,
        generator=generator,
        num_inference_steps=30,
        guidance_scale=guidance_scale,
        eta=1.0,
    )
    print(f"Time taken: {time.time() - last_time}")

    generated_image = images.images[0]
    filename = save_image(generated_image, prompt)
    download_link = create_download_link(filename)
    return (
        (padded_image, generated_image),
        padded_image,
        anyline_image,
        download_link,
        get_image_gallery(),
        image_metadata.values.tolist(),
    )
# Page styling: centre the intro block, cap the app width, hide Gradio's footer.
css = """
#intro {
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
.gradio-container {max-width: 1200px !important}
footer {visibility: hidden}
"""
# UI layout and event wiring. Three tabs: generation, gallery + voting, and
# metadata management. Cached examples are rendered below the tabs.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎨 ArtForge: MistoLine ControlNet Masterpiece Gallery

        Create, curate, and compete with AI-enhanced images using MistoLine ControlNet. Join our creative multiplayer experience! 🖼️🏆✨

        This demo showcases the capabilities of [TheMistoAI/MistoLine](https://huggingface.co/TheMistoAI/MistoLine) ControlNet with SDXL.

        - SDXL Controlnet: [TheMistoAI/MistoLine](https://huggingface.co/TheMistoAI/MistoLine)
        - [Anyline with Controlnet Aux](https://github.com/huggingface/controlnet_aux)
        - For upscaling, see [Enhance This Demo](https://huggingface.co/spaces/radames/Enhance-This-HiDiffusion-SDXL)
        """,
        elem_id="intro",
    )

    with gr.Tab("Generate Images"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Input Image")
                prompt = gr.Textbox(
                    label="Prompt",
                    info="The prompt is very important to get the desired results. Please try to describe the image as best as you can. Accepts Compel Syntax",
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=2**64 - 1,
                    value=1415926535897932,
                    step=1,
                    label="Seed",
                    randomize=True,
                )
                with gr.Accordion(label="Advanced", open=False):
                    guidance_scale = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=8.5,
                        step=0.001,
                        label="Guidance Scale",
                    )
                    controlnet_conditioning_scale = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=0.5,
                        label="ControlNet Conditioning Scale",
                    )
                    strength = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=1,
                        label="Strength",
                    )
                    controlnet_start = gr.Slider(
                        minimum=0,
                        maximum=1,
                        step=0.001,
                        value=0.0,
                        label="ControlNet Start",
                    )
                    controlnet_end = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        step=0.001,
                        value=1.0,
                        label="ControlNet End",
                    )
                    guassian_sigma = gr.Slider(
                        minimum=0.01,
                        maximum=10.0,
                        step=0.1,
                        value=2.0,
                        label="(Anyline) Guassian Sigma",
                    )
                    intensity_threshold = gr.Slider(
                        minimum=0,
                        maximum=255,
                        step=1,
                        value=3,
                        label="(Anyline) Intensity Threshold",
                    )
                btn = gr.Button("Generate")
            with gr.Column(scale=2):
                with gr.Group():
                    image_slider = ImageSlider(position=0.5)
                with gr.Row():
                    padded_image = gr.Image(type="pil", label="Padded Image")
                    anyline_image = gr.Image(type="pil", label="Anyline Image")
                download_link = gr.HTML(label="Download Generated Image")

    with gr.Tab("Gallery and Voting"):
        image_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=4, height="auto")
        with gr.Row():
            like_button = gr.Button("👍 Like")
            dislike_button = gr.Button("👎 Dislike")
            heart_button = gr.Button("❤️ Heart")
            delete_image_button = gr.Button("🗑️ Delete Selected Image")
        # Filename of the image last clicked in the gallery (None until a selection is made).
        selected_image = gr.State(None)

    with gr.Tab("Metadata and Management"):
        metadata_df = gr.Dataframe(
            label="Image Metadata",
            headers=["Filename", "Prompt", "Likes", "Dislikes", "Hearts", "Created"],
            interactive=False,
        )
        delete_all_button = gr.Button("🗑️ Delete All Images")

    inputs = [
        image_input,
        prompt,
        negative_prompt,
        seed,
        guidance_scale,
        controlnet_conditioning_scale,
        strength,
        controlnet_start,
        controlnet_end,
        guassian_sigma,
        intensity_threshold,
    ]
    outputs = [image_slider, padded_image, anyline_image, download_link, image_gallery, metadata_df]

    btn.click(fn=predict, inputs=inputs, outputs=outputs)

    def _on_gallery_select(evt: gr.SelectData):
        """Map a gallery click to the clicked image's filename (None if out of range).

        Gradio injects the event payload only into a parameter annotated with
        ``gr.SelectData``. ``evt.index`` is the position in the gallery, which is
        built by ``get_image_gallery`` in the same order.
        """
        gallery = get_image_gallery()
        index = evt.index
        if isinstance(index, int) and 0 <= index < len(gallery):
            return gallery[index][0]
        return None

    image_gallery.select(fn=_on_gallery_select, inputs=[], outputs=[selected_image])
    like_button.click(fn=lambda x: vote(x, 'likes'), inputs=[selected_image], outputs=[image_gallery, metadata_df])
    dislike_button.click(fn=lambda x: vote(x, 'dislikes'), inputs=[selected_image], outputs=[image_gallery, metadata_df])
    heart_button.click(fn=lambda x: vote(x, 'hearts'), inputs=[selected_image], outputs=[image_gallery, metadata_df])
    delete_image_button.click(fn=delete_image, inputs=[selected_image], outputs=[image_gallery, metadata_df])
    delete_all_button.click(fn=delete_all_images, inputs=[], outputs=[image_gallery, metadata_df])
    # Populate the gallery and table on every page load.
    demo.load(fn=lambda: (get_image_gallery(), image_metadata.values.tolist()), outputs=[image_gallery, metadata_df])

    gr.Examples(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        examples=[
            [
                "./examples/city.png",
                "hyperrealistic surreal cityscape scene at sunset, buildings",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                13113544138610326000,
                8.5,
                0.481,
                1.0,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/lara.jpeg",
                "photography of lara croft 8k high definition award winning",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                5436236241,
                8.5,
                0.8,
                1.0,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/cybetruck.jpeg",
                "photo of tesla cybertruck futuristic car 8k high definition on a sand dune in mars, future",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                383472451451,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/jesus.png",
                "a photorealistic painting of Jesus Christ, 4k high definition",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                13317204146129588000,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/anna-sullivan-DioLM8ViiO8-unsplash.jpg",
                "A crowded stadium with enthusiastic fans watching a daytime sporting event, the stands filled with colorful attire and the sun casting a warm glow",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                5623124123512,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/img_aef651cb-2919-499d-aa49-6d4e2e21a56e_1024.jpg",
                "a large red flower on a black background 4k high definition",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic",
                23123412341234,
                8.5,
                0.8,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
            [
                "./examples/huggingface.jpg",
                "photo realistic huggingface human emoji costume, round, yellow, (human skin)+++ (human texture)+++",
                "blurry, ugly, duplicate, poorly drawn, deformed, mosaic, emoji cartoon, drawing, pixelated",
                12312353423,
                15.206,
                0.364,
                0.8,
                0.0,
                0.9,
                2,
                3,
            ],
        ],
        cache_examples=True,
    )
# Gradio 4 (required by gradio_imageslider) no longer accepts `concurrency_count`
# in queue(). `default_concurrency_limit=1` keeps one job per event at a time,
# with at most 20 requests waiting in the queue.
demo.queue(default_concurrency_limit=1, max_size=20).launch(debug=True)