import torch
from PIL import Image
from transformers import BitsAndBytesConfig, pipeline
def ask_llava(metadata, image_path):
    """
    Generate a refined item description from an image and its metadata using LLaVA.
    """
    # Unpack metadata
    category = metadata.get('category', '')
    subcategory = metadata.get('subcategory', '')
    material = metadata.get('material', '')
    gender = metadata.get('gender', '')
    brand = metadata.get('brand', '')
    name = metadata.get('name', '')
    description = metadata.get('description', '')
    # Load the product image and build the prompt for LLaVA
    image = Image.open(image_path)
    # Earlier, more verbose prompt kept for reference:
    #prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, use your knowledge of fashion trends, styles, colors, gender preferences and brand information as well as your ability to describe, analyze and understand the image of the item to refine the metadata. Your goal is to improve the embedding process for models like FashionCLIP and MARGO-FashionSigLip by creating a more nuanced and detailed description that would boost the performance of the models. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name} - Description: {description} Refine and expand the metadata by incorporating information from the image and about the fashion item's style, cut, pattern, color scheme, brand, and any notable details. Include insights on current fashion trends and how the item fits within those trends. Be mindful that the result should be around 77 tokens only, therefore try to be concise and keep the description direct and useful for text-to-image and text-to-text search. Return the refined metadata as a single paragraph.\nASSISTANT:"""
    prompt = f"""USER: <image>\nYou are an expert in fashion and visual analysis. Given the following metadata and an image, return enhanced metadata structured in a single sentence with each field separated by a comma (do not include the field name, just use the same order). Keep it very concise and simple but make it more understandable for embedding models that will be used for search purposes. Also do a color analysis and add an extra field for the color of the item. Metadata Provided: - Category: {category} - Subcategory: {subcategory} - Material: {material} - Gender: {gender} - Brand: {brand} - Name: {name}.\nASSISTANT:"""
    # Generate the description; the pipeline returns the full text (prompt
    # included), so keep only what follows "ASSISTANT: "
    outputs = img2text_pipeline(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
    description = outputs[0]["generated_text"]
    return description.split("ASSISTANT: ", 1)[1]
def refine_metadata(catalog, column):
    """
    Add a refined-metadata column to the catalog DataFrame, one row at a time.
    """
    catalog[column] = ""
    # Iterate over the DataFrame and process each item
    for index, row in catalog.iterrows():
        metadata = {
            'category': row['L1'],
            'subcategory': row['L2'],
            'material': row['MaterialName'],
            'gender': row['Gender'],
            'brand': row['BrandName'],
            'name': row['Name'],
            'description': row['Description']
        }
        # Ensure the image ID is converted to a string
        #image_path = "/content/drive/MyDrive/images/" + str(row["Id"]) + ".jpg"
        image_path = "/home/user/app/images/" + str(row["Id"]) + ".jpg"
        # Generate the image description using LLaVA
        # (was a recursive call to refine_metadata, which never terminates)
        refined_metadata = ask_llava(metadata, image_path)
        # Store the result back in the DataFrame
        catalog.at[index, column] = refined_metadata
    return catalog
# Load LLaVA once at module level; 4-bit quantization keeps the 7B model
# within GPU memory
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
model_id = "llava-hf/llava-1.5-7b-hf"
if torch.cuda.is_available():
    img2text_pipeline = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
else:
    # bitsandbytes 4-bit loading requires a GPU, so fall back to full precision
    img2text_pipeline = pipeline("image-to-text", model=model_id)
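# Minimal usage sketch (an illustration, not part of the original app): build a
# one-row catalog with the column names refine_metadata expects and run the
# refinement. The row values and Id below are hypothetical; an image must exist
# at /home/user/app/images/<Id>.jpg for ask_llava to open it.
if __name__ == "__main__":
    import pandas as pd

    catalog = pd.DataFrame([{
        "L1": "Apparel",
        "L2": "Jackets",
        "MaterialName": "Leather",
        "Gender": "Men",
        "BrandName": "ExampleBrand",
        "Name": "Classic Biker Jacket",
        "Description": "A classic black leather biker jacket.",
        "Id": 12345,  # expects /home/user/app/images/12345.jpg
    }])
    catalog = refine_metadata(catalog, "refined_metadata")
    print(catalog["refined_metadata"].iloc[0])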