LLaVA-LoRA Adapter

This is a LoRA adapter for the LLaVA model, fine-tuned for spatial description tasks.

Base Model

This adapter is trained on top of llava-hf/llava-1.5-7b-hf.

Training

The model was fine-tuned using LoRA with the following configuration:

  • Rank: 8
  • Alpha: 32
  • Target modules: q_proj, v_proj, k_proj
  • Dataset: PersReFex validation set
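Rank and alpha combine into one scaling factor, alpha / r = 32 / 8 = 4. Each adapted projection then computes W x + (alpha / r) · B A x, where W is frozen and only the low-rank factors A and B are trained. The NumPy sketch below shows that update. The layer sizes are toy values rather than the model's real dimensions, and zero-initialising B follows the usual LoRA recipe; this card does not confirm it. In PEFT, the listed settings would be written as `LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj", "k_proj"])`. The card does not list dropout or other options.

```python
import numpy as np

# Rank and alpha from the configuration above; d_in/d_out are toy sizes for illustration.
r, alpha = 8, 32
d_in, d_out = 64, 64

rng = np.random.default_rng(0)
W = rng.standard_normal((d_out, d_in))      # frozen pretrained weight (stands in for e.g. q_proj)
A = rng.standard_normal((r, d_in)) * 0.01   # trainable down-projection (r x d_in)
B = np.zeros((d_out, r))                    # trainable up-projection, zero-initialised (assumption)

scaling = alpha / r  # 32 / 8 = 4.0


def lora_forward(x: np.ndarray) -> np.ndarray:
    """Frozen projection plus the scaled low-rank update: W x + (alpha / r) * B A x."""
    return W @ x + scaling * (B @ (A @ x))


x = rng.standard_normal(d_in)
# Because B starts at zero, the adapted layer initially matches the frozen layer exactly.
assert np.allclose(lora_forward(x), W @ x)

# Trainable parameters added per adapted projection: r * (d_in + d_out)
extra_params = r * (d_in + d_out)
print(scaling, extra_params)  # 4.0 1024
```

For a 4096 × 4096 attention projection in the 7B language model, the same formula gives 8 × (4096 + 4096) = 65,536 added parameters per adapted matrix. PEFT matches a plain list of `target_modules` by name suffix. A list like this one would therefore also match the identically named `q_proj`/`k_proj`/`v_proj` layers in the CLIP vision tower. This card alone does not say which towers were adapted.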

Usage

```python
import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the base model in bfloat16 and move it to the GPU
base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Attach the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(
    base_model,
    "ZinengTang/llava-lora-spatial",
)

# Single-turn conversation with one text instruction and one image slot
init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": init_prompt_instruct},
            {"type": "image"},  # placeholder; the actual image is passed to the processor below
        ],
    },
]
speaker_image = Image.open("your_image_path")  # replace with the path to your image
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Tokenize the prompt and preprocess the image.
# No max_length here: truncating the prompt could drop image placeholder tokens.
inputs = processor(
    images=speaker_image,
    text=prompt,
    return_tensors="pt",
).to("cuda", torch.bfloat16)  # casts pixel_values to bfloat16; integer token ids are only moved

with torch.no_grad():
    generated = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=256,  # counts generated tokens only; max_length would also count the prompt
        num_beams=1,
        do_sample=True,
        temperature=0.7,
    )

generated_message = processor.batch_decode(generated, skip_special_tokens=True)
print(generated_message)
# Keep only the assistant's reply, truncated to its first 100 characters
generated_message = generated_message[0].split("ASSISTANT: ")[-1][:100]
print(generated_message)
```
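The last step of the usage snippet keeps the text after the final `ASSISTANT: ` marker and truncates it to 100 characters. The helper below sketches that post-processing as a reusable function. The name `extract_reply` and the sample decoded string are hypothetical and serve only as illustration.

```python
from typing import Optional


def extract_reply(decoded: str, marker: str = "ASSISTANT:", max_chars: Optional[int] = 100) -> str:
    """Return the text after the last `marker`, stripped and optionally truncated.

    If `marker` does not occur, the whole decoded string is returned (stripped).
    """
    reply = decoded.rsplit(marker, 1)[-1].strip()
    return reply if max_chars is None else reply[:max_chars]


# Hypothetical decoded output in the LLaVA-1.5 "USER: ... ASSISTANT: ..." shape
decoded = "USER: \nDescribe the location of the blue sphere. ASSISTANT: It sits left of the table."
print(extract_reply(decoded))  # It sits left of the table.
```

The helper matches on `ASSISTANT:` without the trailing space and then strips the result, so small whitespace differences in the decoded text do not break the split. Pass `max_chars=None` to keep the full reply.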