import torch
from datasets import load_dataset
from transformers import BlipProcessor, BlipForConditionalGeneration

from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME
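# Run caption generation on the GPU when one is available; BLIP inference on
# CPU works but is slow for thousands of images.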
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ImageCaptioner:
"""
A class for generating captions for images using a pre-trained model.
Args:
        dataset (str): Name or path of the dataset to load from the Hugging Face Hub.
        processor (str): Name or path of the pre-trained processor used for image preprocessing.
        model (str): Name or path of the pre-trained model used for caption generation.
        prompt (str): The conditioning prompt prepended to each generated caption.
Attributes:
dataset: The loaded dataset.
processor: The pre-trained processor model.
model: The pre-trained caption generation model.
prompt: The conditioning prompt for generating captions.
Methods:
process_dataset: Preprocesses the dataset.
generate_caption: Generates a caption for a single image.
generate_captions: Generates captions for all images in the dataset.
"""
def __init__(self, dataset: str, processor: str, model: str, prompt: str = "Product photo of"):
self.dataset = load_dataset(dataset, split="test")
        # Cap the split at 10,000 examples for demonstration; min() guards
        # against splits smaller than the cap.
        self.dataset = self.dataset.select(range(min(10000, len(self.dataset))))
self.processor = BlipProcessor.from_pretrained(processor)
self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
self.prompt = prompt
def process_dataset(self):
"""
Preprocesses the dataset by renaming the image column and removing unwanted columns.
Returns:
The preprocessed dataset.
"""
        # Normalize the image column name. Only rename when the column is not
        # already called "image": datasets raises a ValueError when renaming a
        # column to a name that already exists.
        if "image" not in self.dataset.column_names and "pixel_values" in self.dataset.column_names:
            self.dataset = self.dataset.rename_column("pixel_values", "image")
if "label" in self.dataset.column_names:
self.dataset = self.dataset.remove_columns(["label"])
# Add an empty 'text' column for captions if it doesn't exist
if "text" not in self.dataset.column_names:
new_column = [""] * len(self.dataset)
self.dataset = self.dataset.add_column("text", new_column)
return self.dataset
def generate_caption(self, example):
"""
Generates a caption for a single image.
Args:
example (dict): A dictionary containing the image data.
Returns:
dict: The dictionary with the generated caption.
"""
image = example["image"].convert("RGB")
inputs = self.processor(images=image, return_tensors="pt").to(device)
prompt_inputs = self.processor(text=[self.prompt], return_tensors="pt").to(device)
outputs = self.model.generate(**inputs, **prompt_inputs)
blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
example["text"] = blip_caption
return example
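
    # Illustrative sketch (not part of the original script): captioning one
    # image per forward pass underutilizes the GPU. This variant batches the
    # images and prompts through the processor and decodes all outputs at once.
    # Use it via self.dataset.map(self.generate_caption_batch, batched=True,
    # batch_size=16), where the batch size is an assumption to tune.
    def generate_caption_batch(self, batch):
        """
        Generates captions for a batch of images in a single forward pass.
        """
        images = [img.convert("RGB") for img in batch["image"]]
        inputs = self.processor(
            images=images,
            text=[self.prompt] * len(images),
            padding=True,
            return_tensors="pt",
        ).to(device)
        outputs = self.model.generate(**inputs)
        batch["text"] = self.processor.batch_decode(outputs, skip_special_tokens=True)
        return batch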
def generate_captions(self):
"""
Generates captions for all images in the dataset.
Returns:
Dataset: The dataset with generated captions.
"""
self.dataset = self.process_dataset()
self.dataset = self.dataset.map(self.generate_caption, batched=False)
return self.dataset
# Initialize ImageCaptioner
ic = ImageCaptioner(
dataset=PRODUCTS_10k_DATASET,
processor=CAPTIONING_MODEL_NAME,
model=CAPTIONING_MODEL_NAME,
prompt='Commercial photography of'
)
# Generate captions for the dataset
products10k_dataset = ic.generate_captions()
# Save the dataset to the hub
products10k_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")
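# Note: push_to_hub assumes you are authenticated with the Hugging Face Hub
# (e.g. via `huggingface-cli login` or an HF_TOKEN environment variable) and
# have write access to the target repository.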