import torch
from datasets import load_dataset
from transformers import BlipProcessor, BlipForConditionalGeneration

# Assuming PRODUCTS_10k_DATASET and CAPTIONING_MODEL_NAME are defined in config.py
from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME
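# For illustration only (actual values live in config.py, not shown here):
#   PRODUCTS_10k_DATASET  -> Hub id or local path of the Products-10k dataset
#   CAPTIONING_MODEL_NAME -> a BLIP checkpoint such as "Salesforce/blip-image-captioning-base"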

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ImageCaptioner:
    """
    A class for generating captions for images using a pre-trained model.

    Args:
        dataset (str): Hub id or local path of the dataset passed to `load_dataset`.
        processor (str): Name or path of the pretrained BLIP processor.
        model (str): Name or path of the pretrained BLIP captioning model.
        prompt (str): The conditioning prompt to use for caption generation.

    Attributes:
        dataset: The loaded dataset.
        processor: The pre-trained processor model.
        model: The pre-trained caption generation model.
        prompt: The conditioning prompt for generating captions.

    Methods:
        process_dataset: Preprocesses the dataset.
        generate_caption: Generates a caption for a single image.
        generate_captions: Generates captions for all images in the dataset.
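
    Example (illustrative usage, with the config constants imported at module level):
        >>> captioner = ImageCaptioner(
        ...     dataset=PRODUCTS_10k_DATASET,
        ...     processor=CAPTIONING_MODEL_NAME,
        ...     model=CAPTIONING_MODEL_NAME,
        ... )
        >>> captioned_dataset = captioner.generate_captions()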
    """

    def __init__(self, dataset: str, processor: str, model: str, prompt: str = "Product photo of"):
        self.dataset = load_dataset(dataset, split="test")
        # Keep at most 10,000 examples for demonstration; min() guards against shorter splits
        self.dataset = self.dataset.select(range(min(10000, len(self.dataset))))
        self.processor = BlipProcessor.from_pretrained(processor)
        self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
        self.prompt = prompt

    def process_dataset(self):
        """
        Preprocesses the dataset: normalizes the image column name to 'image',
        drops the 'label' column if present, and adds an empty 'text' column for captions.

        Returns:
            The preprocessed dataset.
        """
        # Rename 'pixel_values' to 'image' only when needed; renaming a column to a
        # name that already exists raises a ValueError in `datasets`
        if "image" not in self.dataset.column_names and "pixel_values" in self.dataset.column_names:
            self.dataset = self.dataset.rename_column("pixel_values", "image")

        if "label" in self.dataset.column_names:
            self.dataset = self.dataset.remove_columns(["label"])

        # Add an empty 'text' column for captions if it doesn't exist
        if "text" not in self.dataset.column_names:
            new_column = [""] * len(self.dataset)
            self.dataset = self.dataset.add_column("text", new_column)

        return self.dataset

    def generate_caption(self, example):
        """
        Generates a caption for a single image.

        Args:
            example (dict): A dictionary containing the image data.

        Returns:
            dict: The dictionary with the generated caption.
        """
        image = example["image"].convert("RGB")
        # Encode the image and the conditioning prompt in a single processor call,
        # the documented pattern for BLIP conditional captioning
        inputs = self.processor(images=image, text=self.prompt, return_tensors="pt").to(device)
        outputs = self.model.generate(**inputs)
        blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
        example["text"] = blip_caption
        return example
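
    # Illustrative sketch (not part of the original script): a batched variant of
    # `generate_caption` for use with `Dataset.map(..., batched=True)`; encoding
    # several images per forward pass cuts per-example overhead on GPU.
    # Usage: self.dataset.map(self.generate_caption_batch, batched=True, batch_size=16)
    def generate_caption_batch(self, batch):
        images = [img.convert("RGB") for img in batch["image"]]
        inputs = self.processor(
            images=images, text=[self.prompt] * len(images), return_tensors="pt"
        ).to(device)
        outputs = self.model.generate(**inputs)
        batch["text"] = self.processor.batch_decode(outputs, skip_special_tokens=True)
        return batch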

    def generate_captions(self):
        """
        Generates captions for all images in the dataset.

        Returns:
            Dataset: The dataset with generated captions.
        """
        self.dataset = self.process_dataset()
        self.dataset = self.dataset.map(self.generate_caption, batched=False)
        return self.dataset

# Initialize ImageCaptioner
ic = ImageCaptioner(
    dataset=PRODUCTS_10k_DATASET,
    processor=CAPTIONING_MODEL_NAME,
    model=CAPTIONING_MODEL_NAME,
    prompt='Commercial photography of'
)

# Generate captions for the dataset
products10k_dataset = ic.generate_captions()
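
# Optional sanity check (illustrative, not part of the original pipeline): print a
# few generated captions before uploading anything to the Hub
for row in products10k_dataset.select(range(3)):
    print(row["text"])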

# Save the dataset to the hub
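# Note: push_to_hub requires Hub authentication, e.g. `huggingface-cli login`
# or a token in the HF_TOKEN environment variable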
products10k_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions")