# Hugging Face Spaces status shown when this source was captured: "Build error".
import gradio as gr
import requests
from diffusers import StableDiffusionPipeline
from PIL import Image
# Load both diffusers pipelines once, at import time, by Hub repository id
# (presumably downloaded on first run and cached afterwards).
# NOTE(review): StableDiffusionPipeline is a text-to-image generator: it takes
# a text prompt and produces images, so it cannot caption or tag an uploaded
# image. Confirm whether these checkpoints are needed by this app at all.
general_model, anime_model = (
    StableDiffusionPipeline.from_pretrained(checkpoint)
    for checkpoint in (
        "CompVis/stable-diffusion-v1-4",
        "hakurei/waifu-diffusion",
    )
)
# Placeholder functions, to be replaced by the actual implementations
def check_anime_image(image):
    """Look up *image* in an anime reverse-image-search service (stub).

    Intended to be backed by SauceNAO or a similar service that also returns
    similar images and their tags. For now no lookup is performed and the
    image is always reported as "not anime".

    Args:
        image: The PIL image to look up (currently ignored).

    Returns:
        A ``(is_anime, similar_images, original_tags)`` tuple. The stub always
        returns ``(False, [], [])``, with fresh lists on every call.
    """
    similar_images = []
    original_tags = []
    is_anime = False
    return is_anime, similar_images, original_tags
def describe_image_general(image):
    """Describe a general (non-anime) image as a list of tags.

    Not implemented yet. ``general_model`` is a ``StableDiffusionPipeline``,
    a text-to-image generator whose first argument is a text prompt. Calling
    it with an image fails inside diffusers with an unrelated ``ValueError``
    and can never yield tags. An image-captioning or tagging model is needed
    here instead.

    Args:
        image: The (already resized) PIL image to describe.

    Raises:
        NotImplementedError: Always, until a real captioning/tagging model
            is wired in.
    """
    raise NotImplementedError(
        "describe_image_general needs an image-captioning/tagging model; "
        "general_model is a StableDiffusionPipeline (text-to-image) and "
        "cannot describe an input image."
    )
def describe_image_anime(image):
    """Describe an anime image as a list of tags.

    Not implemented yet. ``anime_model`` is a ``StableDiffusionPipeline``,
    a text-to-image generator whose first argument is a text prompt. Calling
    it with an image fails inside diffusers with an unrelated ``ValueError``
    and can never yield tags. An anime image tagger is needed here instead.

    Args:
        image: The (already resized) PIL image to describe.

    Raises:
        NotImplementedError: Always, until a real tagging model is wired in.
    """
    raise NotImplementedError(
        "describe_image_anime needs an anime image-tagging model; "
        "anime_model is a StableDiffusionPipeline (text-to-image) and "
        "cannot describe an input image."
    )
def merge_tags(tags1, tags2):
    """Combine two collections of tag strings into one list without duplicates.

    Tags from *tags1* come first, followed by any tags from *tags2* not seen
    yet, each in first-seen order. Surrounding whitespace is stripped before
    comparing, and blank entries (such as those produced by splitting an empty
    text box) are dropped.

    Args:
        tags1: First iterable of tag strings.
        tags2: Second iterable of tag strings.

    Returns:
        A new list of unique, non-blank, stripped tags.
    """
    # dict keeps insertion order, unlike set(), whose iteration order for str
    # depends on per-process hash randomisation. With a dict, the merged list
    # is deterministic.
    unique = dict.fromkeys(
        tag.strip() for tag in (*tags1, *tags2) if tag.strip()
    )
    return list(unique)
# Gradio app functions
def process_image(image, mode):
    """Tag *image* using the pipeline selected by *mode*.

    Args:
        image: PIL image from the Gradio upload component, or ``None`` when
            nothing has been uploaded yet.
        mode: ``"Anime"`` routes through ``check_anime_image`` and
            ``describe_image_anime``. Any other value (including ``None``)
            uses ``describe_image_general``.

    Returns:
        ``(tags, original_tags)``: the tags produced for the image, and the
        tags reported by the anime lookup (an empty list outside the anime
        path). Status messages such as ``"Not an anime image"`` are returned
        as the single entry of *tags*.
    """
    if image is None:
        # Clicking "Describe" before uploading passes None. Report it like
        # the other status messages instead of crashing on ``.resize``.
        return ["No image provided"], []

    # Fixed-size input for the models; aspect ratio is not preserved.
    image = image.resize((256, 256))

    if mode != "Anime":
        return describe_image_general(image), []

    # Similar-image results are not shown in the UI yet.
    is_anime, _similar_images, original_tags = check_anime_image(image)
    if not is_anime:
        return ["Not an anime image"], []
    return describe_image_anime(image), original_tags
def _tags_to_text(tags):
    """Render a tag collection as newline-separated text (one tag per line)."""
    if tags is None:
        return ""
    # A bare string must not be iterated: "\n".join("abc") would put every
    # character on its own line.
    if isinstance(tags, str):
        return tags
    # str() guards against non-string items, which would make join() raise.
    return "\n".join(str(tag) for tag in tags)


def describe(image, mode):
    """Gradio callback for the "Describe" button.

    Runs ``process_image`` and renders both resulting tag collections as
    newline-separated text for the "Described Tags" and "Original Tags"
    text areas.

    Args:
        image: PIL image from the upload component (``None`` if empty).
        mode: Selected dropdown value, e.g. ``"Anime"`` or ``"General"``.

    Returns:
        A pair of ``gr.update(...)`` payloads for the two output text areas.
    """
    tags, original_tags = process_image(image, mode)
    return (
        gr.update(value=_tags_to_text(tags)),
        gr.update(value=_tags_to_text(original_tags)),
    )
def merge(tags, original_tags):
    """Gradio callback for the "Merge Tags" button.

    Splits both text areas into one tag per line, combines them with
    ``merge_tags``, and returns the merged tags as newline-separated text.

    Args:
        tags: Contents of the "Described Tags" text area (may be empty/None).
        original_tags: Contents of the "Original Tags" text area.

    Returns:
        The merged tags, one per line.
    """

    def to_lines(text):
        # splitlines() returns [] for "" (split("\n") would return [""]) and
        # also handles "\r\n". Blank lines are dropped so they never become tags.
        return [line.strip() for line in (text or "").splitlines() if line.strip()]

    merged_tags = merge_tags(to_lines(tags), to_lines(original_tags))
    return "\n".join(merged_tags)
# Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        # `tool="editor"` was removed from gr.Image in Gradio 4, so it is no
        # longer passed here.
        image_input = gr.Image(type="pil", label="Upload/Paste Image")
        # Preselect a mode so the callback does not receive None by default.
        mode = gr.Dropdown(
            choices=["Anime", "General"], value="General", label="Mode"
        )
    with gr.Row():
        describe_button = gr.Button("Describe")
        merge_button = gr.Button("Merge Tags")
    # Gradio has no `gr.TabGroup`; the tab container component is `gr.Tabs`.
    with gr.Tabs():
        with gr.TabItem("Described Tags"):
            described_tags = gr.TextArea(label="Described Tags")
        with gr.TabItem("Original Tags"):
            original_tags = gr.TextArea(label="Original Tags")
    merged_tags = gr.TextArea(label="Merged Tags")

    # Event listeners are registered inside the Blocks context.
    describe_button.click(
        describe,
        inputs=[image_input, mode],
        outputs=[described_tags, original_tags],
    )
    merge_button.click(
        merge,
        inputs=[described_tags, original_tags],
        outputs=merged_tags,
    )

if __name__ == "__main__":
    demo.launch()