import gradio as gr
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
lora_weights_path = "taesiri/BunsBunny-LLama-3.2-11B-Vision-Instruct-DummyTask2"
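# Note: the meta-llama base checkpoint is gated on the Hugging Face Hub,
# so downloading it may require logging in with a Hugging Face access token.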
processor = AutoProcessor.from_pretrained(base_model_path)
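# device_map="auto" places the weights across the available GPU(s)/CPU
# and requires the `accelerate` package to be installed.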
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
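
# Attach the fine-tuned LoRA adapter weights on top of the base model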
model = PeftModel.from_pretrained(model, lora_weights_path)


def inference(image, question):
    # Prepare input: a single-turn chat message with one image and the user's question
    messages = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)

    # Run inference
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=2048)

    # Decode output
    result = processor.decode(output[0], skip_special_tokens=True)
    return result
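
# Quick local sanity check (assumes a sample image named "example.jpg" exists next to this script):
#   print(inference(Image.open("example.jpg"), "What is in this image?"))
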
# Create Gradio interface
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Enter your question"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="Image Analysis AI",
    description="Upload an image and ask a question about it. The AI will analyze and respond.",
)

if __name__ == "__main__":
    demo.launch()