
---
language:
- fa
datasets:
- BaSalam/vision-catalogs-llava-format-v3
pipeline_tag: image-text-to-text
---

LLaVA Model Card

Model details

This model is llava-hf/llava-1.5-7b-hf fine-tuned on Basalam product data to extract the visual attributes of products. Outputs are returned in JSON format and can be parsed directly.

How to use the model

Below is an example script to run generation in float16 precision on a GPU device:

import requests
from PIL import Image
import torch
import json

from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"

# Load the fine-tuned LLaVA model in float16 and move it to the first GPU.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)

def prompt_formatter(entity):
    # Describes the expected output schema inside the prompt.
    json_format = """'attributes': {'attribute_name_1': <list of attribute values>, 'attribute_name_2': <list of attribute values>, ...}"""
    # The instruction below is in Persian. It translates to:
    # "For the given product, extract the product's visual attributes in JSON
    # format. The JSON structure should look like this: {json_format}. The
    # product is from an Iranian online marketplace, so the JSON output must
    # be in Persian. Product: '{entity}'."
    final_prompt = f"""برای محصول داده شده، ویژگی‌های تصویری محصول را در قالب جیسون (json) استخراج کن. ساختار JSON باید به این شکل باشد: {json_format}. محصول از یک بازار اینترنتی ایرانی است پس خروجی Json باید به زبان فارسی باشد.
محصول: '{entity}'."""
    return final_prompt

prompt = prompt_formatter(entity='تیشرت مردانه')  # entity: "men's T-shirt"

# LLaVA chat format: one user turn containing the text prompt and an image placeholder.
conversation = [
    {
      "role": "user",
      "content": [
          {"type": "text", "text": prompt},
          {"type": "image"},
        ],
    },
]

# Render the conversation with the chat template, appending the generation prompt.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "https://statics.basalam.com/public-16/users/6eOEg/01-24/qJ34XziHu7Orp3GToVWTms1nKvCv0X86Ux7tQLtuRoyTXTxyQ4.jpg_800X800X70.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
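# Alternatively, load a local product photo instead of fetching it over HTTP:
# raw_image = Image.open("/path/to/product.jpg").convert("RGB")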

# Preprocess the image and text, moving tensors to the GPU in float16.
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

# Greedy decoding; allow up to 384 new tokens for the JSON output.
output = model.generate(**inputs, max_new_tokens=384, do_sample=False)

# Decode the full sequence, then strip the echoed prompt and the
# "ASSISTANT: " prefix so that only the model's answer remains.
generated_text = processor.decode(output[0], skip_special_tokens=True)[len(prompt.replace('<image>', ' ')):]
output = generated_text.replace('ASSISTANT: ', '')
json_output = json.loads(output)
print(json_output)
[
  {
    "attributes": {
      "نوع": [
        "تیشرت مردانه"
      ],
      "طرح چاپی": [
        "MVP"
      ],
      "رنگ": [
        "زرد",
        "آبی",
        "سفید",
        "مشکی",
        "کرم",
        "سبز"
      ],
      "سایز": [
        "L",
        "XL",
        "2XL",
        "3XL"
      ]
    }
  }
]
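In this example the attribute keys are Persian: نوع (type), طرح چاپی (print design), رنگ (color), سایز (size). The model is not guaranteed to emit strictly valid JSON for every input, so a defensive parse can help. Below is a minimal sketch; the safe_parse_attributes helper is hypothetical, not part of this repository:

import json

def safe_parse_attributes(raw_output):
    """Try to parse the model output as JSON; return None on failure."""
    try:
        return json.loads(raw_output)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, in case the model wrapped
        # the JSON in extra text.
        start, end = raw_output.find('{'), raw_output.rfind('}')
        if start != -1 and end > start:
            try:
                return json.loads(raw_output[start:end + 1])
            except json.JSONDecodeError:
                return None
        return None

attributes = safe_parse_attributes(output)
if attributes is None:
    print('Model did not return parseable JSON:', output)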

Model optimization

4-bit quantization through the bitsandbytes library

First make sure to install bitsandbytes (pip install bitsandbytes) and that you have access to a CUDA-compatible GPU device. Then change the model-loading snippet above as follows:

model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True,
+   load_in_4bit=True
)
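With recent versions of transformers, 4-bit loading is configured through a BitsAndBytesConfig object rather than the load_in_4bit flag. A minimal sketch, assuming recent transformers and bitsandbytes releases:

import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # compute in fp16 over 4-bit weights
)

model = LlavaForConditionalGeneration.from_pretrained(
    "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    device_map="auto",  # let accelerate place the quantized weights
)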

Use Flash-Attention 2 to further speed up generation

First make sure to install flash-attn; refer to the original Flash Attention repository for installation instructions. Then change the model-loading snippet above as follows:

model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True,
+   use_flash_attention_2=True
).to(0)
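On newer transformers releases, the use_flash_attention_2 flag is deprecated in favor of the attn_implementation argument. A minimal sketch, assuming flash-attn is installed and the GPU supports it:

import torch
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
).to(0)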