VikramSingh178 committed
Commit 8a5e693 • 1 Parent(s): fc250c3

refactor: Update image captioning script to use Salesforce/blip-image-captioning-large model

Files changed (1)
  1. scripts/products10k_captions.py +28 -39
scripts/products10k_captions.py CHANGED
@@ -1,55 +1,44 @@
 from datasets import load_dataset
-from config import (PRODUCTS_10k_DATASET,CAPTIONING_MODEL_NAME)
-from transformers import (BlipProcessor, BlipForConditionalGeneration)
+from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME
+from transformers import BlipProcessor, BlipForConditionalGeneration
 from tqdm import tqdm
 import torch
 
-
-
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 class ImageCaptioner:
-
-    def __init__(self, dataset:str,processor:str,model:str):
-        self.dataset = load_dataset(dataset)
+    def __init__(self, dataset: str, processor: str, model: str):
+        self.dataset = load_dataset(dataset, split="train")
         self.processor = BlipProcessor.from_pretrained(processor)
         self.model = BlipForConditionalGeneration.from_pretrained(model).to(device)
-
-
+
     def process_dataset(self):
-        self.dataset = self.dataset.rename_column(original_column_name='pixel_values',new_column_name='image')
-        self.dataset = self.dataset.remove_columns(column_names=['label'])
+        # Assuming 'pixel_values' is the column name for images in the dataset
+        self.dataset = self.dataset.rename_column("pixel_values", "image")
+        # Remove unwanted columns
+        if "label" in self.dataset.column_names:
+            self.dataset = self.dataset.remove_columns(["label"])
         return self.dataset
-
-
+
     def generate_captions(self):
         self.dataset = self.process_dataset()
-        self.dataset['image']=[image.convert("RGB") for image in self.dataset["image"]]
-        print(self.dataset['image'][0])
-        for image in tqdm(self.dataset['image']):
-            inputs = self.processor(image, return_tensors="pt").to(device)
-            out = self.model(**inputs)
-
-
-
-
-
-
-
-
-
-ic = ImageCaptioner(dataset=PRODUCTS_10k_DATASET,processor=CAPTIONING_MODEL_NAME,model=CAPTIONING_MODEL_NAME)
-
 
+        for idx in tqdm(range(len(self.dataset))):
+            image = self.dataset[idx]["image"].convert("RGB")
+            inputs = self.processor(images=image, return_tensors="pt").to(device)
+            outputs = self.model.generate(**inputs)
+            blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True)
+            self.dataset[idx]["caption"] = blip_caption
+            print(f"Caption for image {idx}: {blip_caption}")
 
+        # Optionally, you can save the dataset with captions to disk
+        # self.dataset.save_to_disk('path_to_save_dataset')
 
+        return self.dataset
 
-
-
-
-
-
-
+ic = ImageCaptioner(
+    dataset=PRODUCTS_10k_DATASET,
+    processor=CAPTIONING_MODEL_NAME,
+    model=CAPTIONING_MODEL_NAME,
+)
+ic.generate_captions()
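
For reference, the script imports PRODUCTS_10k_DATASET and CAPTIONING_MODEL_NAME from a local config module that is not part of this diff. Below is a minimal sketch of what that module might contain, assuming the Salesforce/blip-image-captioning-large checkpoint named in the commit title; the dataset repository id is a hypothetical placeholder, not taken from this commit.

# config.py: illustrative sketch only, not part of this commit
# NOTE: the dataset repo id below is a placeholder/assumption; substitute the
# actual Products-10K dataset repository used by this project.
PRODUCTS_10k_DATASET = "your-namespace/products-10k"  # hypothetical Hub id
# Captioning checkpoint named in the commit title.
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-large"

With those constants defined, running the script (e.g. python scripts/products10k_captions.py) loads the train split, renames the image column, and generates a BLIP caption for each image at module level.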