bilgeyucel committed on
Commit
3e6ae58
1 Parent(s): dc74825

Upload image_captioner.py

Files changed (1)
  1. image_captioner.py +75 -0
image_captioner.py ADDED
from typing import Dict, List

import logging

from haystack import component
from haystack.lazy_imports import LazyImport
from PIL import Image

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    import torch
    from transformers import (
        AutoTokenizer,
        BlipForConditionalGeneration,
        BlipProcessor,
        ViTImageProcessor,
        VisionEncoderDecoderModel,
    )


@component
class ImageCaptioner:
    """Generates captions for images using a BLIP or ViT-GPT2 image-captioning model."""

    def __init__(
        self,
        model_name: str = "Salesforce/blip-image-captioning-base",
    ):
        torch_and_transformers_import.check()
        self.model_name = model_name

        if model_name == "nlpconnect/vit-gpt2-image-captioning":
            # ViT encoder + GPT-2 decoder captioning model
            self.model = VisionEncoderDecoderModel.from_pretrained(model_name)
            self.feature_extractor = ViTImageProcessor.from_pretrained(model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.gen_kwargs = {"max_length": 16, "num_beams": 4}
        else:
            # Default: BLIP captioning model
            self.processor = BlipProcessor.from_pretrained(model_name)
            self.model = BlipForConditionalGeneration.from_pretrained(model_name)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    @component.output_types(captions=List[str])
    def run(self, image_file_paths: List[str]) -> Dict[str, List[str]]:
        # Load all images and make sure they are in RGB mode
        images = []
        for image_path in image_file_paths:
            i_image = Image.open(image_path)
            if i_image.mode != "RGB":
                i_image = i_image.convert(mode="RGB")
            images.append(i_image)

        if self.model_name == "nlpconnect/vit-gpt2-image-captioning":
            pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)

            output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
            preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        else:
            # Move the processed inputs to the same device as the model
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)
            output_ids = self.model.generate(**inputs)
            preds = self.processor.batch_decode(output_ids, skip_special_tokens=True)

        preds = [pred.strip() for pred in preds]
        return {"captions": preds}


# captioner = ImageCaptioner(model_name="Salesforce/blip-image-captioning-base")
# result = captioner.run(image_file_paths=["selfie.png"])
# print(result)
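
For context, a minimal sketch of how the component might be wired into a Haystack 2.x Pipeline, assuming the standard Pipeline API and that ImageCaptioner is importable from this file; the image path reuses the placeholder from the commented example above:

from haystack import Pipeline

# Assumes image_captioner.py is on the Python path
from image_captioner import ImageCaptioner

pipeline = Pipeline()
pipeline.add_component("captioner", ImageCaptioner(model_name="Salesforce/blip-image-captioning-base"))

# Pipeline inputs are keyed by component name; outputs mirror the declared output_types
result = pipeline.run({"captioner": {"image_file_paths": ["selfie.png"]}})
print(result["captioner"]["captions"])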