import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

# Model configuration
MID = "apple/FastVLM-0.5B"
# Sentinel id spliced into input_ids where the image belongs (value taken from
# the model card; the model's remote code is expected to look for it).
IMAGE_TOKEN_INDEX = -200
# Textual marker placed in the chat message so the rendered prompt can be split
# at the exact position where the image sentinel must be inserted.
IMAGE_PLACEHOLDER = "<image>"
DEFAULT_PROMPT = "Describe this image in detail."

HEADER_MD = """
# 🖼️ FastVLM Image Captioning

Upload an image to generate a detailed caption using Apple's FastVLM-0.5B model.
You can use the default prompt or provide your own custom prompt.
"""

FOOTER_MD = """
---
**Model:** [apple/FastVLM-0.5B](https://huggingface.co/apple/FastVLM-0.5B)

**Note:** This Space uses ZeroGPU for dynamic GPU allocation.
"""

# Tokenizer and model are loaded lazily on the first request (see load_model).
tok = None
model = None


def load_model():
    """Load the tokenizer and model once and cache them in module globals.

    Uses CUDA with float16 when a GPU is available, otherwise CPU with float32.

    Returns:
        Tuple ``(tok, model)`` of the cached tokenizer and model.
    """
    global tok, model
    if tok is None or model is None:
        print("Loading model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)

        # Fallback: GPU if available, else CPU
        if torch.cuda.is_available():
            device = "cuda"
            dtype = torch.float16
        else:
            device = "cpu"
            dtype = torch.float32  # safer on CPU

        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=dtype,
            device_map=device,  # can be "cuda" or "cpu"
            trust_remote_code=True,
        )
        print(f"Model loaded on {device.upper()} successfully!")
    return tok, model


# NOTE: the ZeroGPU decorator below is currently disabled, so load_model()
# decides between GPU and CPU on its own. Re-enable it if the Space is
# deployed on ZeroGPU hardware.
# @spaces.GPU(duration=60)
def caption_image(image, custom_prompt=None):
    """
    Generate a caption for the input image.

    Args:
        image: PIL Image from Gradio
        custom_prompt: Optional custom prompt to use instead of default.
            Empty or whitespace-only values fall back to the default prompt.

    Returns:
        Generated caption text, or a human-readable error message if the
        image is missing or generation fails.
    """
    if image is None:
        return "Please upload an image first."

    try:
        # Load model if not already loaded
        tok, model = load_model()

        # Convert image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Use custom prompt or default (a blank textbox must not yield an
        # empty instruction)
        prompt = (custom_prompt or "").strip() or DEFAULT_PROMPT

        # Build chat message; the placeholder marks where the image goes
        messages = [
            {"role": "user", "content": f"{IMAGE_PLACEHOLDER}\n{prompt}"}
        ]

        # Render to string to place the image token correctly
        rendered = tok.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        # Split at the image placeholder. Our placeholder is the first thing
        # in the user content, so maxsplit=1 is safe even if the user's own
        # prompt happens to contain the same marker later on.
        pre, post = rendered.split(IMAGE_PLACEHOLDER, 1)

        # Tokenize text around the image token
        pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

        # Insert IMAGE token id at placeholder position
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
        input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)

        # Preprocess image using model's vision tower
        px = model.get_vision_tower().image_processor(
            images=image, return_tensors="pt"
        )["pixel_values"]
        px = px.to(model.device, dtype=model.dtype)

        # Generate caption
        with torch.no_grad():
            out = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=128,
                do_sample=False,  # Deterministic generation
                temperature=1.0,
            )

        # Keep only the newly generated tokens. If the returned sequence
        # echoes the prompt ids, drop that prefix; otherwise the output is
        # assumed to already contain just the answer. Splitting the decoded
        # text on the word "assistant" would truncate any caption that
        # legitimately contains that word.
        sequence = out[0]
        prompt_len = input_ids.shape[1]
        if sequence.shape[0] >= prompt_len and torch.equal(
            sequence[:prompt_len], input_ids[0]
        ):
            sequence = sequence[prompt_len:]

        return tok.decode(sequence, skip_special_tokens=True).strip()

    except Exception as e:
        # UI boundary: surface the failure in the output box instead of
        # crashing the request.
        return f"Error generating caption: {e}"


def _caption_on_upload(image, prompt):
    """Auto-caption a newly uploaded image only when no custom prompt is set.

    Returns None (which leaves the output box empty) when the image was
    cleared or the user has typed a custom prompt.
    """
    if image is None or prompt:
        return None
    return caption_image(image, prompt)


# Create Gradio interface
with gr.Blocks(title="FastVLM Image Captioning") as demo:
    gr.Markdown(HEADER_MD)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="image-upload",
            )

            custom_prompt = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="Leave empty for default: 'Describe this image in detail.'",
                lines=2,
            )

            with gr.Row():
                clear_btn = gr.ClearButton([image_input, custom_prompt])
                generate_btn = gr.Button("Generate Caption", variant="primary")

        with gr.Column():
            output = gr.Textbox(
                label="Generated Caption",
                lines=8,
                max_lines=100,
                show_copy_button=True,
            )

    # Event handlers
    generate_btn.click(
        fn=caption_image,
        inputs=[image_input, custom_prompt],
        outputs=output,
    )

    # Also generate on image upload if no custom prompt
    image_input.change(
        fn=_caption_on_upload,
        inputs=[image_input, custom_prompt],
        outputs=output,
    )

    gr.Markdown(FOOTER_MD)

if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )