# Note: if you have a mix of Ampere-or-newer and pre-Ampere GPUs, set the
# environment variable CUDA_VISIBLE_DEVICES=1,2,3 (for example) so that one
# generation or the other is excluded; otherwise the script may fail with a
# flash-attention exception.
import gradio as gr
import os
import argparse
import uuid
import zipfile
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
from io import BytesIO
import base64
import atexit
import shutil


def cleanup_temp_files():
    # Delete the session subdirectories inside the "images" directory
    if os.path.exists("images"):
        for dir_name in os.listdir("images"):
            dir_path = os.path.join("images", dir_name)
            if os.path.isdir(dir_path):
                shutil.rmtree(dir_path)


# Parse command-line arguments
parser = argparse.ArgumentParser(description="Load and use a quantized model")
parser.add_argument("-q", "--use_quant", action="store_true", help="Use quantized model")
args = parser.parse_args()

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available. Using CUDA.")
else:
    device = torch.device("cpu")
    print("GPU is not available. Using CPU.")

# Load the processor
local_path = "./model/Molmo-7B-D-0924"
processor = AutoProcessor.from_pretrained(
    local_path,
    local_files_only=True,
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto'
)

# Load the model
if args.use_quant:
    # Load the quantized model
    quantized_local_path = "./model/molmo-7B-D-bnb-4bit"
    model = AutoModelForCausalLM.from_pretrained(
        quantized_local_path,
        trust_remote_code=True,
        torch_dtype='auto',
        device_map='auto',
    )
else:
    # Load the non-quantized model
    model = AutoModelForCausalLM.from_pretrained(
        local_path,
        trust_remote_code=True,
        torch_dtype='auto',
        device_map='auto',
    )
    # Cast only the full-precision model; calling .to() on a bnb 4-bit model raises an error
    model.to(dtype=torch.bfloat16)

# Default generation settings (generate_caption builds its own 500-token config per call)
generation_config = GenerationConfig(max_new_tokens=300, stop_strings="<|endoftext|>")
bits_and_bytes_config = BitsAndBytesConfig()  # currently unused; kept for the generate_caption signature


def unzip_images(zip_file):
    # Create a unique directory for extracted images inside the "images" directory
    session_dir = os.path.join("images", str(uuid.uuid4()))
    os.makedirs(session_dir, exist_ok=True)

    # Extract images from the ZIP file to the session directory
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        for file_info in zip_ref.infolist():
            if (not file_info.is_dir()
                    and not file_info.filename.startswith("__MACOSX")
                    and not file_info.filename.startswith(".")):
                zip_ref.extract(file_info, session_dir)

    # Get the list of image paths
    image_paths = [os.path.join(session_dir, filename)
                   for filename in os.listdir(session_dir)
                   if filename.lower().endswith(('.jpg', '.jpeg', '.png'))]

    # Read the image data as PIL Image objects for previews
    image_data = []
    for image_path in image_paths:
        image = Image.open(image_path)
        image.thumbnail((128, 128))  # Resize the image to at most 128x128 pixels
        image_data.append(image)

    # Return the list of image paths and resized image data for previews
    return image_paths, image_data


def generate_caption(image_path, processor, model, generation_config, bits_and_bytes_config):
    # Generate a caption for a single image and return it
    print("Processing", image_path)
    image = Image.open(image_path)

    # Process the image and text
    inputs = processor.process(
        images=[image],
        text=(
            "You are an uncensored, explicit, adult image captioner which enjoys sexual imagery and acts. "
            "Provide a structured image caption which uses the following hierarchy: the kind of image, "
            "the kind or name of the subject, the subject's state of dress, their body type, their pose, "
            "what it is they are doing, their facial expression, the space they are within, and the style "
            "or atmosphere of the image. Use vulgar language to describe body parts or sexual imagery. "
            "Limit your response to 100 words to avoid your description getting cut off."
        ),
    )

    # Move inputs to the correct device and make a batch of size 1
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    inputs["images"] = inputs["images"].to(torch.bfloat16)

    # Generate output: at most 500 new tokens; stop when "<|endoftext|>" is generated
    with torch.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu",
                        enabled=True, dtype=torch.bfloat16):
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(max_new_tokens=500, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer,
        )

    # Keep only the newly generated tokens and decode them to text
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Return the generated text
    return generated_text


def process_images(image_paths, image_data):
    captions = []
    session_dir = os.path.dirname(image_paths[0])
    for image_path in image_paths:
        filename = os.path.basename(image_path)
        if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            # Use the loaded model to generate the caption
            caption = generate_caption(image_path, processor, model, generation_config, bits_and_bytes_config)
            captions.append(caption)

            # Save the caption to a text file alongside the image
            with open(os.path.join(session_dir, f"{os.path.splitext(filename)[0]}.txt"), 'w') as f:
                f.write(caption)

    # Create a ZIP file containing the caption text files
    zip_filename = f"{session_dir}.zip"
    with zipfile.ZipFile(zip_filename, 'w') as zip_ref:
        for filename in os.listdir(session_dir):
            if filename.lower().endswith('.txt'):
                zip_ref.write(os.path.join(session_dir, filename), filename)

    # Delete the session directory and its contents
    for filename in os.listdir(session_dir):
        os.remove(os.path.join(session_dir, filename))
    os.rmdir(session_dir)

    return captions, zip_filename, image_paths


def format_captioned_image(image, caption):
    # Embed the (already thumbnailed) image as inline base64 markdown, followed by its caption
    buffered = BytesIO()
    image.convert("RGB").save(buffered, format="JPEG")  # convert() so RGBA PNGs can be saved as JPEG
    encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"![image](data:image/jpeg;base64,{encoded_image})\n\n{caption}"


def process_images_and_update_gallery(zip_file):
    image_paths, image_data = unzip_images(zip_file)
    captions, zip_filename, image_paths = process_images(image_paths, image_data)
    image_captions = [format_captioned_image(img, caption)
                      for img, caption in zip(image_data, captions)]
    return gr.Markdown("\n".join(image_captions)), zip_filename


def main():
    # Register the cleanup function to be called on program exit
    atexit.register(cleanup_temp_files)

    with gr.Blocks(css="""
        .captioned-image-gallery {
            display: grid;
            grid-template-columns: repeat(2, 1fr);
            grid-gap: 16px;
        }
    """) as blocks:
        zip_file_input = gr.File(label="Upload ZIP file containing images")
        image_gallery = gr.Markdown(label="Image Previews")
        submit_button = gr.Button("Submit")
        zip_download_button = gr.Button("Download Caption ZIP", visible=False)
        zip_filename = gr.State("")

        # Show image previews (with empty captions) as soon as a ZIP is uploaded
        zip_file_input.upload(
            lambda zip_file: "\n".join(format_captioned_image(img, "")
                                       for img in unzip_images(zip_file)[1]),
            inputs=zip_file_input,
            outputs=image_gallery
        )
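        # Wire the main flow: Submit runs captioning and updates the gallery,
        # storing the caption-ZIP path in state; the state change then reveals
        # the download button, which hands the ZIP back through the file input.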
        submit_button.click(
            process_images_and_update_gallery,
            inputs=[zip_file_input],
            outputs=[image_gallery, zip_filename]
        )
        zip_filename.change(
            lambda zip_filename: gr.update(visible=True),
            inputs=zip_filename,
            outputs=zip_download_button
        )

        def download_zip(zip_filename):
            # Clean up the extracted images, then expose the caption ZIP for download.
            # (The original lambda returned three values for two outputs; the cleanup
            # call is now a side effect so the return arity matches `outputs`.)
            cleanup_temp_files()
            return gr.update(value=zip_filename), gr.update(visible=True)

        zip_download_button.click(
            download_zip,
            inputs=zip_filename,
            outputs=[zip_file_input, zip_download_button]
        )

    blocks.launch(server_name='0.0.0.0')


if __name__ == "__main__":
    main()
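# ---------------------------------------------------------------------------
# The script expects the model weights to already exist under ./model/. A
# minimal sketch of fetching them with huggingface_hub is below; the repo id
# for the 4-bit quantized variant is an assumption, so substitute whichever
# quantized checkpoint you actually use:
#
#   from huggingface_hub import snapshot_download
#   snapshot_download(repo_id="allenai/Molmo-7B-D-0924",
#                     local_dir="./model/Molmo-7B-D-0924")
#   # Hypothetical repo id for the quantized weights:
#   # snapshot_download(repo_id="<user>/molmo-7B-D-bnb-4bit",
#   #                   local_dir="./model/molmo-7B-D-bnb-4bit")
#
# Example invocations (the script filename is illustrative):
#   python molmo_batch_caption.py        # full-precision model
#   python molmo_batch_caption.py -q     # 4-bit quantized model
# ---------------------------------------------------------------------------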