import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
from transformers.image_utils import load_image
from pathlib import Path
import time
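
# Assumed dependencies (not pinned anywhere in this file): torch, gradio and
# transformers, plus bitsandbytes and accelerate, which transformers requires
# for loading 4-bit quantized weights.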
model_name_or_path = "Minthy/ToriiGate-v0.3"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Global variables to store model and processor
global_model = None
global_processor = None
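# The model is loaded lazily on first use and cached in the globals above, so
# repeated Gradio callbacks reuse the already-loaded weights instead of
# reloading them on every request.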
def load_model():
    """Load the model and processor once, then reuse the cached instances."""
    global global_model, global_processor
    if global_model is None:
        print("Loading model for the first time...")
        # Always use 4-bit NF4 quantization so the model fits in 16GB VRAM
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        # 4-bit bitsandbytes models cannot be moved with .to(); let
        # device_map place the weights on the target device instead
        global_model = AutoModelForVision2Seq.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            quantization_config=nf4_config,
            device_map=DEVICE,
        )
        global_processor = AutoProcessor.from_pretrained(model_name_or_path)
    return global_model, global_processor
def generate_caption(image, description_type, booru_tags=""):
    """Generate a caption for a single PIL image."""
    model, processor = load_model()
    # Prompt wording (including the "structuted" spelling) is kept verbatim
    # from the ToriiGate model card, since the model was trained on it
    if description_type == "JSON-like":
        user_prompt = "Describe the picture in structuted json-like format."
    elif description_type == "Detailed":
        user_prompt = "Give a long and detailed description of the picture."
    else:
        user_prompt = "Describe the picture briefly."
    if booru_tags:
        user_prompt += ' Also here are booru tags for better understanding of the picture, you can use them as reference.'
        user_prompt += f' <tags>\n{booru_tags}\n</tags>'
    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored. Help user with his task."}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": user_prompt}
            ]
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generated_ids = model.generate(**inputs, max_new_tokens=500)
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep only the assistant's reply; [-1] avoids an IndexError if the
    # 'Assistant: ' marker is ever missing from the decoded text
    caption = generated_texts[0].split('Assistant: ')[-1].strip()
    return caption
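
# A minimal standalone usage sketch for generate_caption, outside the Gradio
# app flow (assumes a local file named "sample.png" exists -- hypothetical):
#
#     from PIL import Image
#     print(generate_caption(Image.open("sample.png"), "Brief", booru_tags="1girl, solo"))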
def process_batch(files, description_type, booru_tags="", progress=gr.Progress(track_tqdm=True)):
    """Caption a list of uploaded files, yielding incremental progress updates."""
    results = []
    captions_text = ""
    total_files = len(files)
    start_time = time.time()
    for idx, file in enumerate(files, 1):
        try:
            image = load_image(file.name)
            caption = generate_caption(image, description_type, booru_tags)
            # Append the caption, separated from the previous one by a blank line
            if captions_text:
                captions_text += "\n\n"
            captions_text += caption
            # Update the results list for the (hidden) dataframe
            results.append((Path(file.name).name, caption))
            # Compute progress statistics after the image has been processed,
            # so speed and time-remaining estimates reflect completed work
            elapsed_time = time.time() - start_time
            images_per_second = idx / elapsed_time if elapsed_time > 0 else 0
            remaining_time = (elapsed_time / idx) * (total_files - idx)
            progress_status = f"Processed: {idx}/{total_files} images | Speed: {images_per_second:.2f} img/s | Remaining: {remaining_time/60:.1f} min"
            # Yield progress status and captions separately
            yield results, progress_status, captions_text
        except Exception as e:
            error_msg = f"Error processing {Path(file.name).name}: {str(e)}"
            print(error_msg)
            if captions_text:
                captions_text += "\n\n"
            captions_text += f"[ERROR] {error_msg}"
            # Don't reuse progress_status here: it is undefined if the very
            # first image fails before the stats above are computed
            yield results, f"Error on image {idx}/{total_files}", captions_text
    # Final update
    yield results, "✅ Processing complete!", captions_text
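
# process_batch is a generator: Gradio treats a generator event handler as a
# streaming update, pushing each yielded (results, status, captions) tuple to
# the three output components as it arrives.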
# Gradio Interface
with gr.Blocks(title="ToriiGate Image Captioner") as demo:
    gr.Markdown("# ToriiGate Image Captioner")
    gr.Markdown("Generate captions for anime images using the ToriiGate-v0.3 model (4-bit quantized)")

    with gr.Tab("Single Image"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                description_type = gr.Radio(
                    choices=["JSON-like", "Detailed", "Brief"],
                    value="JSON-like",
                    label="Description Type"
                )
                booru_tags = gr.Textbox(
                    lines=3,
                    label="Booru Tags (Optional)",
                    placeholder="Enter comma-separated booru tags..."
                )
                submit_btn = gr.Button("Generate Caption")
            with gr.Column():
                output_text = gr.Textbox(label="Generated Caption", lines=10)

        submit_btn.click(
            generate_caption,
            inputs=[input_image, description_type, booru_tags],
            outputs=output_text
        )

    with gr.Tab("Batch Processing"):
        with gr.Row():
            with gr.Column():
                input_files = gr.File(file_count="multiple", label="Input Images")
                batch_description_type = gr.Radio(
                    choices=["JSON-like", "Detailed", "Brief"],
                    value="JSON-like",
                    label="Description Type"
                )
                batch_booru_tags = gr.Textbox(
                    lines=3,
                    label="Booru Tags (Optional)",
                    placeholder="Enter comma-separated booru tags..."
                )
                batch_submit_btn = gr.Button("Process Batch")
            with gr.Column():
                progress_status = gr.Textbox(
                    label="Progress",
                    lines=2,
                    show_copy_button=False
                )
                output_text_batch = gr.Textbox(
                    label="Generated Captions",
                    lines=25,
                    show_copy_button=True
                )
                output_gallery = gr.Dataframe(
                    headers=["Filename", "Caption"],
                    label="Generated Captions (Table View)",
                    visible=False  # Hidden; used only as a structured sink for results
                )

        batch_submit_btn.click(
            process_batch,
            inputs=[input_files, batch_description_type, batch_booru_tags],
            outputs=[output_gallery, progress_status, output_text_batch]
        )
if __name__ == "__main__":
    # Load the model at startup so the first request doesn't pay the load cost
    load_model()
    # share=True also exposes a temporary public Gradio link
    demo.launch(share=True)