import gradio as gr
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
import requests
from PIL import Image
from io import BytesIO
import onnxruntime as ort
from huggingface_hub import hf_hub_download
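# Note: Florence-2 is loaded with trust_remote_code, and its remote modeling
# code typically pulls in extra runtime dependencies (e.g. einops and timm);
# if the Space build fails, check that those are listed in requirements.txt.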
# Initialize models
anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
anime_model = ort.InferenceSession(anime_model_path)
photo_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

# Load labels for the anime model. Rows in selected_tags.csv look like
# "tag_id,name,category,count", so the tag name is the second column.
labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
with open(labels_path, 'r') as f:
    anime_labels = [line.strip().split(',')[1] for line in f.readlines()[1:]]  # Skip header

def preprocess_image(image):
    # The WD taggers expect a 448x448 BGR image in NHWC layout with raw
    # 0-255 float values (no channel transpose, no 1/255 normalization).
    image = image.convert('RGB')
    image = image.resize((448, 448), Image.LANCZOS)
    image = np.array(image).astype(np.float32)
    image = image[:, :, ::-1]  # RGB -> BGR
    return image[np.newaxis, ...]
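
# Optional alternative to the fixed top-50 cut used below: WD tagger outputs
# are commonly filtered by a probability threshold instead. A minimal sketch;
# the 0.35 default is an assumption, not a value from this Space.
def filter_tags_by_threshold(probs, labels, threshold=0.35):
    # Keep every label whose predicted probability clears the threshold.
    return [labels[i] for i, p in enumerate(probs) if p >= threshold]
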
def transcribe_image(image, image_type, transcriber, booru_tags=None):
    # Branch on the transcriber selected in Step 2; image_type stays in the
    # signature so the Step 1 buttons can pass it through unchanged.
    if transcriber == "Anime":
        input_image = preprocess_image(image)
        input_name = anime_model.get_inputs()[0].name
        output_name = anime_model.get_outputs()[0].name
        probs = anime_model.run([output_name], {input_name: input_image})[0]
        # Get top 50 tags
        top_indices = probs[0].argsort()[-50:][::-1]
        tags = [anime_labels[i] for i in top_indices]
        # Merge in booru tags when they were fetched in Step 1.
        if booru_tags:
            tags = list(dict.fromkeys(booru_tags + tags))  # dedupe, keep order
        return ", ".join(tags)
    else:
        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(images=image, text=prompt, return_tensors="pt")
        generated_ids = photo_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # The caption is already a single string; joining it would split it
        # into characters, so return it directly.
        return generated_text
def get_booru_image(booru, image_id):
    """Fetch an image and its tag list from the selected booru's public API."""
    if booru == "Gelbooru":
        url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    elif booru == "Danbooru":
        url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    elif booru == "rule34.xxx":
        url = f"https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    else:
        raise ValueError("Unsupported booru")
    response = requests.get(url)
    response.raise_for_status()
    data = response.json()
    # Response shapes differ per booru: Gelbooru wraps posts in a "post" key,
    # Danbooru returns a single post dict, rule34.xxx returns a bare list.
    if booru == "Gelbooru":
        post = data["post"][0]
    elif booru == "Danbooru":
        post = data
    else:
        post = data[0]
    image_url = post["file_url"]
    # Danbooru calls the tag field "tag_string"; the others use "tags".
    tags = post.get("tags", post.get("tag_string", "")).split()
    img_response = requests.get(image_url)
    img = Image.open(BytesIO(img_response.content))
    return img, tags
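
# Caveat: Danbooru omits "file_url" on posts restricted to higher account
# tiers, so a KeyError above usually means the post is not publicly fetchable.
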
def update_image(image_type, booru, image_id, uploaded_image):
    if image_type == "Anime" and booru != "Upload":
        image, booru_tags = get_booru_image(booru, image_id)
        return image, gr.update(visible=True), booru_tags
    elif uploaded_image is not None:
        return uploaded_image, gr.update(visible=True), None
    else:
        return None, gr.update(visible=False), None
def on_image_type_change(image_type):
    if image_type == "Anime":
        return gr.update(visible=True), gr.update(visible=True), gr.update(choices=["Anime", "Photo/Other"])
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(choices=["Photo/Other", "Anime"])

with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")
    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        # type="pil" so handlers receive PIL images rather than numpy arrays
        image_display = gr.Image(label="Image to transcribe", visible=False, type="pil")
        booru_tags = gr.State(None)
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)
    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe", type="pil")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])
    get_image_btn.click(update_image,
                        inputs=[image_type, booru, image_id, upload_btn],
                        outputs=[image_display, transcribe_btn, booru_tags])
    upload_btn.upload(update_image,
                      inputs=[image_type, booru, image_id, upload_btn],
                      outputs=[image_display, transcribe_btn, booru_tags])

    def transcribe_and_update(image, image_type, transcriber, booru_tags):
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    transcribe_btn.click(transcribe_and_update,
                         inputs=[image_display, image_type, transcriber, booru_tags],
                         outputs=[transcribe_image_display, tags_output])
    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])
    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

app.launch()