import os
import base64
import io
import sqlite3
import torch
import gradio as gr
import pandas as pd
from PIL import Image
import requests
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from datasets import load_dataset
import traceback
from tqdm import tqdm
import zipfile
# Define constants for vikhyatk/moondream2 model
MOON_DREAM_MODEL_ID = "vikhyatk/moondream2"
MOON_DREAM_REVISION = "2024-08-26"
# Define constants for the Qwen2-VL models
QWEN2_VL_MODELS = [
'Qwen/Qwen2-VL-7B-Instruct',
'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4',
'OpenGVLab/InternVL2-1B',
'Qwen/Qwen2-VL-72B',
]
# List of models offered in the UI (the Qwen2-VL group plus additional vision-language models)
available_models = [
*QWEN2_VL_MODELS, # Expands the QWEN2_VL_MODELS list into the available_models
'microsoft/Phi-3-vision-128k-instruct',
'vikhyatk/moondream2'
]
# List of available Hugging Face datasets
dataset_options = [
"gokaygokay/panorama_hdr_dataset",
"OpenGVLab/CRPE"
]
# List of text prompts to use
text_prompts = [
"Provide a detailed description of the image contents, including all visible objects, people, activities, and extract any text present within the image using Optical Character Recognition (OCR). Organize the extracted text in a structured table format with columns for original text, its translation into English, and the language it is written in.",
"Offer a thorough description of all elements within the image, from objects to individuals and their activities. Ensure any legible text seen in the image is extracted using Optical Character Recognition (OCR). Provide an accurate narrative that encapsulates the full content of the image.",
"Create a four-sentence caption for the image. Start by specifying the style and type, such as painting, photograph, or digital art. In the next sentences, detail the contents and the composition clearly and concisely. Use language suited for prompting a text-to-image model, separating descriptive terms with commas instead of 'or'. Keep the description direct, avoiding interpretive phrases or abstract expressions",
]
# SQLite setup
# def init_db():
# conn = sqlite3.connect('image_outputs.db')
# cursor = conn.cursor()
# cursor.execute('''
# CREATE TABLE IF NOT EXISTS image_outputs (
# id INTEGER PRIMARY KEY AUTOINCREMENT,
# image BLOB,
# prompt TEXT,
# output TEXT,
# model_name TEXT
# )
# ''')
# conn.commit()
# conn.close()
def image_to_binary(image_path):
with open(image_path, 'rb') as file:
return file.read()
# def store_in_db(image_path, prompt, output, model_name):
# conn = sqlite3.connect('image_outputs.db')
# cursor = conn.cursor()
# image_blob = image_to_binary(image_path)
# cursor.execute('''
# INSERT INTO image_outputs (image, prompt, output, model_name)
# VALUES (?, ?, ?, ?)
# ''', (image_blob, prompt, output, model_name))
# conn.commit()
# conn.close()
# Function to encode an image to base64 for HTML display
def encode_image(image):
img_buffer = io.BytesIO()
image.save(img_buffer, format="PNG")
img_str = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
return f'<img src="data:image/png;base64,{img_str}" style="max-width:500px;"/>'
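# Usage sketch (hedged): encode_image expects an in-memory PIL.Image (not a file path) and
# returns an HTML <img> tag that renders inside gr.HTML or DataFrame.to_html(escape=False).
# The file name below is a placeholder:
#     html_img = encode_image(Image.open("example.png"))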
# Function to load and display images from the selected Hugging Face dataset
def load_dataset_images(dataset_name, num_images):
try:
dataset = load_dataset(dataset_name, split='train')
images = []
        # Slicing a Dataset returns a dict of columns, so select rows explicitly instead
        for i, item in enumerate(dataset.select(range(min(num_images, len(dataset))))):
            if 'image' in item:
                img = item['image']
                print(type(img))
encoded_img = encode_image(img)
metadata = f"Width: {img.width}, Height: {img.height}"
if 'hdr' in item:
metadata += f", HDR: {item['hdr']}"
images.append(f"<div style='display: inline-block; margin: 10px; text-align: center;'><h3>Image {i+1}</h3>{encoded_img}<p>{metadata}</p></div>")
if not images:
return "No images could be loaded from this dataset. Please check the dataset structure."
return "".join(images)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        return f"Error loading dataset: {e}"
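# Usage sketch (hedged): preview a few dataset images as an HTML gallery. Note that the
# function only picks up a column literally named 'image', so datasets that store images
# under another column (e.g. 'png_image') will return the "no images" message.
#     html_gallery = load_dataset_images("OpenGVLab/CRPE", num_images=3)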
# Function to generate output
def generate_output(model, processor, prompt, image, model_name, device):
try:
image_bytes = io.BytesIO()
image.save(image_bytes, format="PNG")
image_bytes = image_bytes.getvalue()
if model_name in QWEN2_VL_MODELS:
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_bytes},
{"type": "text", "text": prompt},
]
}
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[text],
images=[Image.open(io.BytesIO(image_bytes))],
padding=True,
return_tensors="pt",
)
inputs = {k: v.to(device) for k, v in inputs.items()}
generated_ids = model.generate(**inputs, max_new_tokens=1024)
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)]
response_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
return response_text
elif model_name == 'microsoft/Phi-3-vision-128k-instruct':
messages = [{"role": "user", "content": f"<|image_1|>\n{prompt}"}]
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, [image], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
generate_ids = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=1024)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response_text = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
return response_text
elif model_name == 'vikhyatk/moondream2':
tokenizer = AutoTokenizer.from_pretrained(MOON_DREAM_MODEL_ID, revision=MOON_DREAM_REVISION)
enc_image = model.encode_image(image)
response_text = model.answer_question(enc_image, prompt, tokenizer)
return response_text
        return f"Unsupported model: {model_name}"
    except Exception as e:
        traceback.print_exc()
        return f"Error during generation with model {model_name}: {e}"
# Function to list and encode images from a directory
def list_images(directory_path):
images = []
for filename in os.listdir(directory_path):
if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(directory_path, filename)
            encoded_img = encode_image(Image.open(image_path))  # encode_image expects a PIL image, not a file path
images.append({
"filename": filename,
"image": encoded_img
})
return images
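# Usage sketch (hedged): list_images scans a local directory (the path below is a placeholder)
# and returns one dict per image with its filename and an HTML-encoded preview.
#     gallery = list_images("./images")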
# Function to extract images from a ZIP file
def extract_images_from_zip(zip_file):
images = []
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
for file_info in zip_ref.infolist():
if file_info.filename.lower().endswith(('.png', '.jpg', '.jpeg')):
with zip_ref.open(file_info) as file:
try:
img = Image.open(file)
img = img.convert("RGB") # Ensure the image is in RGB mode
                        images.append({
                            "filename": file_info.filename,
                            "image": img  # keep the PIL image; run_inference encodes it for display later
                        })
except Exception as e:
print(f"Error opening image {file_info.filename}: {e}")
return images
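# Usage sketch (hedged): extract_images_from_zip returns a list of dicts, each holding the
# original filename and a decoded RGB PIL.Image; the archive name below is a placeholder.
#     batch = extract_images_from_zip("images.zip")
#     print([entry["filename"] for entry in batch])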
# Gradio interface function for running inference
def run_inference(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype, trust_remote_code,use_flash_attn, use_zip, zip_file):
data = []
torch_dtype_value = torch.float16 if torch_dtype == "torch.float16" else torch.float32
    device_map_value = ("cuda" if torch.cuda.is_available() else "cpu") if device_map == "auto" else device_map
model_processors = {}
for model_name in model_names:
try:
if model_name in QWEN2_VL_MODELS:
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch_dtype_value,
device_map=device_map_value
).eval()
processor = AutoProcessor.from_pretrained(model_name)
elif model_name == 'microsoft/Phi-3-vision-128k-instruct':
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map=device_map_value,
torch_dtype=torch_dtype_value,
trust_remote_code=trust_remote_code,
                _attn_implementation=("flash_attention_2" if use_flash_attn else "eager")  # Phi-3-vision selects its attention backend via _attn_implementation; use_flash_attn is not a from_pretrained argument
).eval()
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
elif model_name == 'vikhyatk/moondream2':
model = AutoModelForCausalLM.from_pretrained(
MOON_DREAM_MODEL_ID,
trust_remote_code=True,
revision=MOON_DREAM_REVISION
).eval()
processor = None # No processor needed for this model
model_processors[model_name] = (model, processor)
except Exception as e:
print(f"Error loading model {model_name}: {e}")
try:
# Load images from the ZIP file if use_zip is True
if use_zip:
images = extract_images_from_zip(zip_file)
print ("Number of images in zip:" , len(images))
for img in tqdm(images):
try:
                    img_data = img['image']
                    if not isinstance(img_data, str):
                        # Reuse encode_image to turn the PIL image into an HTML <img> tag
                        img_data = encode_image(img_data)
                    row_data = {"Image": img_data}
for model_name in model_names:
if model_name in model_processors:
model, processor = model_processors[model_name]
for prompt in prompts:
try:
# Ensure image is defined
image = img['image']
response_text = generate_output(model, processor, prompt, image, model_name, device_map_value)
row_data[f"{model_name}_Response_{prompt}"] = response_text
except Exception as e:
row_data[f"{model_name}_Response_{prompt}"] = f"Error during generation with model {model_name}: {e}"
traceback.print_exc()
data.append(row_data)
except Exception as e:
print(f"Error processing image {img['filename']}: {e}")
traceback.print_exc()
# Load the dataset if use_zip is False
else:
dataset = load_dataset(dataset_input, split='train')
for i in tqdm(range(num_images_input)):
if dataset_input == "OpenGVLab/CRPE":
image = dataset[i]['image']
elif dataset_input == "gokaygokay/panorama_hdr_dataset":
image = dataset[i]['png_image']
else:
image = dataset[i]['image']
encoded_img = encode_image(image)
row_data = {"Image": encoded_img}
for model_name in model_names:
if model_name in model_processors:
model, processor = model_processors[model_name]
for prompt in prompts:
try:
response_text = generate_output(model, processor, prompt, image, model_name, device_map_value)
row_data[f"{model_name}_Response_{prompt}"] = response_text
except Exception as e:
row_data[f"{model_name}_Response_{prompt}"] = f"Error during generation with model {model_name}: {e}"
data.append(row_data)
except Exception as e:
print(f"Error loading dataset: {e}")
traceback.print_exc()
return pd.DataFrame(data).to_html(escape=False)
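# Usage sketch (hedged): run_inference can also be called outside the Gradio UI. A minimal
# CPU-friendly call, assuming the dataset split loads as the code above expects, might be:
#     html_table = run_inference(
#         model_names=["vikhyatk/moondream2"],
#         dataset_input="OpenGVLab/CRPE",
#         num_images_input=1,
#         prompts=[text_prompts[2]],
#         device_map="auto", torch_dtype="torch.float16", trust_remote_code=True,
#         use_flash_attn=False, use_zip=False, zip_file=None,
#     )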
def show_image(image):
return image # Simply display the selected image
# Gradio UI setup
def create_gradio_interface():
css = """
#output {
height: 500px;
overflow: auto;
}
"""
with gr.Blocks(css=css) as demo:
# Title
gr.Markdown("# VLM-Image-Analysis: A Vision-and-Language Modeling Framework.")
gr.Markdown("""
- Handles a batch of images from a ZIP file, OR
- Processes images from a Hugging Face dataset
- Compatible with PNG, JPG, and JPEG formats
- Compatible with multiple vision-language models: Qwen2-VL-7B-Instruct, Qwen2-VL-2B-Instruct-GPTQ-Int4, InternVL2-1B, Qwen2-VL-72B, Phi-3-vision-128k-instruct, and moondream2""")
# image_path = os.path.abspath("static/image.jpg")
# gr.Image(value=image_path, label="HF Image", width=300, height=300)
# init_image = gr.Image(label="Selected Image", type="filepath")
# # Use gr.Examples to showcase a set of example images
# gr.Examples(
# examples=[
# ["static/image.jpg"],
# ],
# inputs=[init_image],
# label="Example Images"
# )
# init_image.change(show_image, inputs=init_image, outputs=init_image)
with gr.Tab("VLM model and Dataset selection"):
gr.Markdown("### Dataset Selection: HF or load from a ZIP file.")
with gr.Accordion("Advanced Settings", open=True):
with gr.Row():
# with gr.Column():
use_zip_input = gr.Checkbox(label="Use ZIP File", value=False)
zip_file_input = gr.File(label="Upload ZIP File of Images", file_types=[".zip"])
dataset_input = gr.Dropdown(choices=dataset_options, label="Select Dataset", value=dataset_options[1], visible=True)
num_images_input = gr.Radio(choices=[1, 5, 20], label="Number of Images", value=5)
gr.Markdown("### VLM Model Selection")
with gr.Row():
with gr.Column():
models_input = gr.CheckboxGroup(choices=available_models, label="Select Models", value=available_models[4])
prompts_input = gr.CheckboxGroup(choices=text_prompts, label="Select Prompts", value=text_prompts[2])
submit_btn = gr.Button("Run Inference")
with gr.Row():
output_display = gr.HTML(label="Results")
with gr.Tab("GPU Device Settings"):
device_map_input = gr.Radio(choices=["auto", "cpu", "cuda"], label="Device Map", value="auto")
torch_dtype_input = gr.Radio(choices=["torch.float16", "torch.float32"], label="Torch Dtype", value="torch.float16")
trust_remote_code_input = gr.Checkbox(label="Trust Remote Code", value=True)
            use_flash_attn = gr.Checkbox(label="Use flash-attn 2 (Ampere GPUs or newer)", value=False)
def run_inference_wrapper(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype, trust_remote_code,use_flash_attn, use_zip, zip_file):
return run_inference(model_names, dataset_input, num_images_input, prompts, device_map, torch_dtype, trust_remote_code,use_flash_attn, use_zip, zip_file)
def toggle_dataset_visibility(use_zip):
return gr.update(visible=not use_zip)
submit_btn.click(
fn=run_inference_wrapper,
inputs=[models_input, dataset_input, num_images_input, prompts_input, device_map_input, torch_dtype_input, trust_remote_code_input,use_flash_attn, use_zip_input, zip_file_input],
outputs=output_display
)
use_zip_input.change(
fn=toggle_dataset_visibility,
inputs=use_zip_input,
outputs=dataset_input
)
demo.launch(debug=True, share=False)
if __name__ == "__main__":
    create_gradio_interface()