import json
import os

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from peft import PeftModel
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

# Login to Hugging Face
if "HF_TOKEN" not in os.environ:
    raise ValueError("Please set the HF_TOKEN environment variable with your Hugging Face token")
login(token=os.environ["HF_TOKEN"])

# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
lora_weights_path = "taesiri/BungsBunny-LLama-3.2-11B-Vision-Instruct-Medium"

processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model = PeftModel.from_pretrained(model, lora_weights_path)
model.tie_weights()


@spaces.GPU
def inference(image):
    # Prepare the multimodal prompt: one image plus a text instruction
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe the image in JSON"},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(
        image, input_text, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)

    # Run inference
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=2048)

    # Decode output and keep only the assistant's turn of the transcript
    result = processor.decode(output[0], skip_special_tokens=True)
    json_str = result.strip().split("assistant\n", 1)[-1].strip()

    try:
        # First JSON parse to handle an escaped JSON string
        first_parse = json.loads(json_str)
        try:
            # Second JSON parse to get the actual JSON object
            json_object = json.loads(first_parse)
            # Return an indented JSON string (2 spaces)
            return json.dumps(json_object, indent=2)
        except (json.JSONDecodeError, TypeError):
            # If the second parse fails (or first_parse is already a dict/list),
            # return the result of the first parse, indented
            if isinstance(first_parse, (dict, list)):
                return json.dumps(first_parse, indent=2)
            return str(first_parse)
    except json.JSONDecodeError:
        # If both JSON parses fail, return the original string
        return json_str


# Create Gradio interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")

    with gr.Row():
        # Container for the image takes full width
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
                height=500,  # Increased height for larger display
            )

    with gr.Row():
        # Container for the text output takes full width
        with gr.Column(scale=1):
            text_output = gr.Textbox(
                label="Response",
                elem_id="response-text",
                lines=25,
                max_lines=25,
            )

    # Button to trigger the analysis
    submit_btn = gr.Button("Analyze Image", variant="primary")
    submit_btn.click(fn=inference, inputs=[image_input], outputs=[text_output])

demo.launch()