# search-demo/utils/refine_metadata.py

import torch
from PIL import Image
from transformers import BitsAndBytesConfig, pipeline


def ask_llava(metadata, image_path):
    """
    Use LLaVA to generate a refined, embedding-friendly description of a
    fashion item from its metadata and image.
    """
    # Unpack metadata
    category = metadata.get('category', '')
    subcategory = metadata.get('subcategory', '')
    material = metadata.get('material', '')
    gender = metadata.get('gender', '')
    brand = metadata.get('brand', '')
    name = metadata.get('name', '')
    description = metadata.get('description', '')

    # Load the image and build the prompt for LLaVA
    image = Image.open(image_path)
    #prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, use your knowledge of fashion trends, styles, colors, gender preferences and brand information, as well as your ability to describe, analyze and understand the image of the item, to refine the metadata. Your goal is to improve the embedding process for models like FashionCLIP and MARGO-FashionSigLip by creating a more nuanced and detailed description that would boost the performance of the models. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name} - Description: {description} Refine and expand the metadata by incorporating information from the image and about the fashion item's style, cut, pattern, color scheme, brand, and any notable details. Include insights on current fashion trends and how the item fits within those trends. Be mindful that it should be around 77 tokens only; therefore, be concise and keep the description direct and useful for text-to-image and text-to-text search. Return the refined metadata as a single paragraph.\nASSISTANT:"""
    prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, return enhanced metadata structured as a single sentence with each field separated by a comma (do not include the field names, just keep the same order). Keep it very concise and simple, but make it more understandable for the embedding models that will be used for search. Also do a color analysis and add an extra field for the color of the item. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name}.\nASSISTANT:"""

    # Generate the description
    outputs = img2text_pipeline(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
    description = outputs[0]["generated_text"]

    # Keep only the assistant's answer, dropping the echoed prompt
    return description.split("ASSISTANT: ", 1)[-1]


def refine_metadata(catalog, column):
    """
    Run ask_llava over every row of the catalog DataFrame and store the
    refined metadata in the given column.
    """
    catalog[column] = ""

    # Iterate over the DataFrame and process each item
    for index, row in catalog.iterrows():
        metadata = {
            'category': row['L1'],
            'subcategory': row['L2'],
            'material': row['MaterialName'],
            'gender': row['Gender'],
            'brand': row['BrandName'],
            'name': row['Name'],
            'description': row['Description']
        }

        # Ensure the image ID is converted to a string
        #image_path = "/content/drive/MyDrive/images/" + str(row["Id"]) + ".jpg"
        image_path = "/home/user/app/images/" + str(row["Id"]) + ".jpg"

        # Generate the refined metadata using LLaVA
        refined_metadata = ask_llava(metadata, image_path)

        # Store the result back in the DataFrame
        catalog.at[index, column] = refined_metadata

    return catalog


# Load LLaVA once at import time, with 4-bit quantization when a GPU is available
img2text_pipeline = None
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
model_id = "llava-hf/llava-1.5-7b-hf"

if torch.cuda.is_available():
    img2text_pipeline = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
else:
    img2text_pipeline = pipeline("image-to-text", model=model_id)
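

# --- Usage sketch (illustrative only) ---
# A minimal example of how refine_metadata might be called, assuming a
# pandas DataFrame with the column names referenced above ('L1', 'L2',
# 'MaterialName', 'Gender', 'BrandName', 'Name', 'Description', 'Id') and
# an image named <Id>.jpg under /home/user/app/images/. The item values
# below are hypothetical.
if __name__ == "__main__":
    import pandas as pd

    catalog = pd.DataFrame([{
        "Id": 12345,
        "L1": "Clothing",
        "L2": "Dresses",
        "MaterialName": "Cotton",
        "Gender": "Women",
        "BrandName": "ExampleBrand",
        "Name": "Floral Midi Dress",
        "Description": "A light summer dress with a floral print.",
    }])

    catalog = refine_metadata(catalog, "refined_metadata")
    print(catalog["refined_metadata"].iloc[0])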