# Hugging Face Spaces page status at capture time: Build error
import csv
from io import BytesIO

import gradio as gr
import numpy as np
import onnxruntime as ort
import requests
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
# Initialize models (downloaded from the Hugging Face Hub at import time).
ANIME_MODEL_REPO = "SmilingWolf/wd-convnext-tagger-v3"
PHOTO_MODEL_REPO = "facebook/florence-base-in21k-retrieval"

anime_model_path = hf_hub_download(ANIME_MODEL_REPO, "model.onnx")
anime_model = ort.InferenceSession(anime_model_path)
# NOTE(review): confirm PHOTO_MODEL_REPO exists on the Hub and loads as a
# zero-shot image classification model; a failure here aborts app startup.
photo_model = AutoModelForZeroShotImageClassification.from_pretrained(PHOTO_MODEL_REPO)
processor = AutoProcessor.from_pretrained(PHOTO_MODEL_REPO)

# Load labels for the anime model. selected_tags.csv has a header row, and the
# tag text lives in its "name" column. The first column is a numeric tag id,
# which is what the previous split(',')[0] parsing picked up by mistake.
labels_path = hf_hub_download(ANIME_MODEL_REPO, "selected_tags.csv")
with open(labels_path, newline="", encoding="utf-8") as f:
    labels = [row["name"] for row in csv.DictReader(f)]
def preprocess_image(image, size=448):
    """Convert an image into the input tensor for the WD v3 ONNX tagger.

    Args:
        image: A ``PIL.Image.Image``, or an array-like image such as the numpy
            array that ``gr.Image`` passes to callbacks by default.
        size: Side length, in pixels, of the square model input.

    Returns:
        A float32 numpy array of shape ``(1, size, size, 3)``: channels-last,
        BGR channel order, raw 0-255 pixel values (no /255 scaling).
    """
    if not isinstance(image, Image.Image):
        # gr.Image hands callbacks numpy arrays unless type="pil" is set.
        image = Image.fromarray(np.asarray(image))

    if image.mode in ("RGBA", "LA", "P"):
        # Flatten transparency onto white; a plain convert("RGB") would make
        # transparent pixels black.
        rgba = image.convert("RGBA")
        flattened = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        flattened.alpha_composite(rgba)
        image = flattened
    image = image.convert("RGB")

    # Pad to a white square so the resize keeps the aspect ratio instead of
    # stretching the picture.
    side = max(image.size)
    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, ((side - image.width) // 2, (side - image.height) // 2))
    square = square.resize((size, size), Image.LANCZOS)

    pixels = np.asarray(square, dtype=np.float32)
    pixels = pixels[:, :, ::-1]  # RGB -> BGR
    # The tagger's ONNX input is NHWC with 0-255 values, so there is no CHW
    # transpose and no /255 here.
    # NOTE(review): re-check via anime_model.get_inputs()[0].shape if the
    # model is swapped.
    return np.ascontiguousarray(pixels[np.newaxis, ...])
# Seconds to wait for any single booru HTTP request before giving up.
_BOORU_TIMEOUT = 30


def _first_post(data):
    """Return the first post mapping in a booru JSON payload, or None if it has none.

    Accepts a bare post dict, a list of posts, or a dict wrapping the list
    under a "post" key (Gelbooru's json=1 format).
    """
    if isinstance(data, dict) and "post" in data:
        data = data["post"]
    if isinstance(data, list):
        return data[0] if data else None
    return data if isinstance(data, dict) else None


def get_booru_image(booru, image_id):
    """Download a booru post's image together with its tag list.

    Args:
        booru: One of "Gelbooru", "Danbooru" or "rule34.xxx".
        image_id: Numeric post ID, as a string or int. Surrounding whitespace
            is ignored.

    Returns:
        ``(image, tags)``: the downloaded ``PIL.Image.Image`` and a list of
        tag strings.

    Raises:
        ValueError: unsupported booru, non-numeric ID, no such post, or a
            post without a downloadable file URL.
        requests.RequestException: network failure or an HTTP error status.
    """
    if booru not in ("Gelbooru", "Danbooru", "rule34.xxx"):
        raise ValueError("Unsupported booru")
    post_id = str(image_id).strip()
    # isascii() rejects Unicode digits such as "²" that isdigit() accepts.
    if not (post_id.isascii() and post_id.isdigit()):
        raise ValueError(f"Image ID must be a number, got {image_id!r}")

    if booru == "Danbooru":
        url, params = f"https://danbooru.donmai.us/posts/{post_id}.json", None
    else:
        host = "https://gelbooru.com" if booru == "Gelbooru" else "https://api.rule34.xxx"
        url = f"{host}/index.php"
        # Passing the query as params lets requests URL-encode every value.
        params = {"page": "dapi", "s": "post", "q": "index", "json": 1, "id": post_id}

    response = requests.get(url, params=params, timeout=_BOORU_TIMEOUT)
    response.raise_for_status()
    # An empty body is treated as "no such post" rather than parsed as JSON.
    post = _first_post(response.json()) if response.text.strip() else None
    if post is None:
        raise ValueError(f"No post {post_id} found on {booru}")

    image_url = post.get("file_url")
    if not image_url:
        raise ValueError(f"Post {post_id} on {booru} has no downloadable file URL")
    # Gelbooru/rule34.xxx name the space-separated tag list "tags";
    # Danbooru names it "tag_string".
    tags = (post.get("tags") or post.get("tag_string") or "").split()

    img_response = requests.get(image_url, timeout=_BOORU_TIMEOUT)
    img_response.raise_for_status()
    img = Image.open(BytesIO(img_response.content))
    return img, tags
def _anime_tags(image, top_k):
    """Return the top_k labels from the WD ONNX tagger, highest score first."""
    input_image = preprocess_image(image)
    input_name = anime_model.get_inputs()[0].name
    output_name = anime_model.get_outputs()[0].name
    probs = anime_model.run([output_name], {input_name: input_image})[0]
    # Reverse before slicing so top_k=0 yields no tags; [-0:] would keep all.
    top_indices = probs[0].argsort()[::-1][:top_k]
    return [labels[i] for i in top_indices]


def _photo_tags(image, top_k):
    """Return the top_k class labels from the photo model, highest logit first."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = photo_model(**inputs)
    # NOTE(review): zero-shot classifiers usually need candidate text labels
    # and expose logits_per_image; confirm this model returns `.logits` keyed
    # by config.id2label.
    scores = outputs.logits[0]
    k = min(top_k, scores.shape[-1])
    top_ids = scores.topk(k).indices.tolist()
    # id2label lives on the model config; the processor has no .config attribute.
    return [photo_model.config.id2label[i] for i in top_ids]


def transcribe_image(image, image_type, transcriber, booru_tags=None, top_k=50):
    """Tag an image with the anime or photo model and return a tag string.

    Args:
        image: Image to tag (PIL image or the numpy array from ``gr.Image``).
        image_type: Step-1 image type ("Anime" or "Photo/Other"). Selects the
            model when no transcriber is chosen.
        transcriber: Model chosen in the "Transcriber" dropdown ("Anime" or
            "Photo/Other"). Takes precedence over image_type when set.
        booru_tags: Optional extra tags, appended after the model's tags.
        top_k: Number of highest-scoring model tags to keep.

    Returns:
        The tags joined with ", ", model tags first in score order, duplicates
        removed.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("No image to transcribe.")

    model_choice = transcriber or image_type
    if model_choice == "Anime":
        tags = _anime_tags(image, top_k)
    else:
        tags = _photo_tags(image, top_k)

    if booru_tags:
        # dict.fromkeys de-duplicates while keeping a deterministic order,
        # unlike set().
        tags = list(dict.fromkeys([*tags, *booru_tags]))
    return ", ".join(tags)
def update_image(image_type, booru, image_id, uploaded_image):
    """Load the step-1 image from a booru post or from an uploaded file.

    Args:
        image_type: Value of the "Image type" dropdown.
        booru: Value of the "Boorus" dropdown (a booru name or "Upload"), or None.
        image_id: Post ID typed into the "Image ID" textbox.
        uploaded_image: Value of the upload button, or None if nothing was uploaded.

    Returns:
        Updates for ``(image_display, transcribe_btn, booru_tags)``: the image
        preview (made visible when there is an image), the transcribe button's
        visibility, and the booru tag list (None for uploads).

    Raises:
        gr.Error: if fetching from the booru fails, so the UI shows a message
            instead of a raw traceback.
    """
    wants_booru = (
        image_type == "Anime"
        and booru not in (None, "Upload")
        and str(image_id or "").strip() != ""
    )
    if wants_booru:
        try:
            image, tags = get_booru_image(booru, image_id)
        except (requests.RequestException, OSError, ValueError, KeyError, IndexError) as err:
            raise gr.Error(f"Could not load post {image_id} from {booru}: {err}") from err
        # image_display starts hidden, so the value must be sent with visible=True.
        return gr.update(value=image, visible=True), gr.update(visible=True), tags

    if uploaded_image is not None:
        # The upload button may hand over a tempfile-like object rather than a
        # path; gr.Image needs the path.
        path = getattr(uploaded_image, "name", uploaded_image)
        return gr.update(value=path, visible=True), gr.update(visible=True), None

    return gr.update(value=None, visible=False), gr.update(visible=False), None
def on_image_type_change(image_type):
    """Adjust the step-1 controls after the image type changes.

    Returns:
        Updates for ``(anime_options, upload_btn, transcriber)``:
        - the booru options column is shown only for "Anime";
        - the upload button is always shown;
        - the transcriber choices are reordered so the one matching the
          image type comes first.
    """
    is_anime = image_type == "Anime"
    transcriber_order = ["Anime", "Photo/Other"] if is_anime else ["Photo/Other", "Anime"]
    return (
        gr.update(visible=is_anime),
        gr.update(visible=True),
        gr.update(choices=transcriber_order),
    )
def transcribe_and_update(image, image_type, transcriber, booru_tags=None):
    """Transcribe the step-1 image and copy it to the step-2 tab.

    Returns:
        ``(image, tags)`` for ``transcribe_image_display`` and ``tags_output``.
    """
    tags = transcribe_image(image, image_type, transcriber, booru_tags)
    return image, tags


def _reveal_loaded(image, tags):
    """Update the step-1 controls after an image load.

    Shows the preview when it holds an image, and shows the
    booru-tag button only when booru tags were fetched.
    """
    return gr.update(visible=image is not None), gr.update(visible=bool(tags))


with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        image_display = gr.Image(label="Image to transcribe", visible=False)
        booru_tags = gr.State(None)
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])

    load_inputs = [image_type, booru, image_id, upload_btn]
    load_outputs = [image_display, transcribe_btn, booru_tags]
    # After every load, reveal the preview and the booru-tag button as
    # appropriate; both start hidden.
    get_image_btn.click(update_image, inputs=load_inputs, outputs=load_outputs).then(
        _reveal_loaded, inputs=[image_display, booru_tags],
        outputs=[image_display, transcribe_with_tags_btn])
    upload_btn.upload(update_image, inputs=load_inputs, outputs=load_outputs).then(
        _reveal_loaded, inputs=[image_display, booru_tags],
        outputs=[image_display, transcribe_with_tags_btn])

    # The plain button transcribes with the model only; booru tags are merged
    # solely by the dedicated "with booru tags" button.
    transcribe_btn.click(transcribe_and_update,
                         inputs=[image_display, image_type, transcriber],
                         outputs=[transcribe_image_display, tags_output])
    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])
    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

app.launch()