from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration

class ImageCaptionModel:
	def __init__(
		self,
		device,
		processor,
		model,
	) -> None:
		"""
		Initializes the model for generating captions for images.

		-----
		Parameters:
		device: str
			The device to use for the model. Must be either "cpu" or "cuda".
		processor: transformers.AutoProcessor
			The preprocessor to use for the model.
		model: transformers.AutoModelForCausalLM or transformers.BlipForConditionalGeneration
			The model to use for generating captions.

		-----
		Returns:
		None
		"""
		self.device = device
		self.processor = processor
		self.model = model
		self.model.to(self.device)

	def generate(
		self,
		image,
		num_captions: int = 1,
		max_length: int = 50,
		temperature: float = 1.0,
		top_k: int = 50,
		top_p: float = 1.0,
		repetition_penalty: float = 1.0,
		diversity_penalty: float = 0.0,
	):
		"""
		Generates captions for the given image.

		-----
		Parameters:
		preprocessor: transformers.PreTrainedTokenizerFast
			The preprocessor to use for the model.
		model: transformers.PreTrainedModel	
			The model to use for generating captions.
		image: PIL.Image
			The image to generate captions for.
		num_captions: int
			The number of captions to generate.
		temperature: float
			The temperature to use for sampling. The value used to module the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive. Defaults to 1.0.
		top_k: int
			The number of highest probability vocabulary tokens to keep for top-k-filtering. A large value of top_k will keep more probabilities for each token leading to a better but slower generation. Defaults to 50.
		top_p: float
			The value that will be used by default in the generate method of the model for top_p. If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
		repetition_penalty: float
			The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
		diversity_penalty: float
			The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.

		"""
		# Type checking and making sure the values are valid.
		assert isinstance(num_captions, int) and num_captions > 0, "num_captions must be a positive integer."
		assert isinstance(max_length, int) and max_length > 0, "max_length must be a positive integer."
		assert isinstance(temperature, (int, float)) and temperature > 0.0, "temperature must be a positive float."
		assert isinstance(top_k, int) and top_k > 0, "top_k must be a positive integer."
		assert isinstance(top_p, (int, float)) and top_p > 0.0, "top_p must be a positive float."
		assert isinstance(repetition_penalty, (int, float)) and repetition_penalty >= 1.0, "repetition_penalty must be a float greater than or equal to 1.0."
		assert isinstance(diversity_penalty, (int, float)) and diversity_penalty >= 0.0, "diversity_penalty must be a non-negative float."

		pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device) # Convert the image to pixel values.

		# Generate captions ids.
		if num_captions == 1:
			generated_ids = self.model.generate(
				pixel_values=pixel_values,
				max_length=max_length,
				num_return_sequences=1,
				temperature=temperature,
				top_k=top_k,
				top_p=top_p,
			)
		else:
			generated_ids = self.model.generate(
				pixel_values=pixel_values,
				max_length=max_length,
				num_beams=num_captions, # num_beams must be >= num_return_sequences and divisible by num_beam_groups.
				num_beam_groups=num_captions, # One beam per group so that group beam search, together with diversity_penalty, yields diverse captions.
				num_return_sequences=num_captions, # Return one caption per beam group.
				temperature=temperature,
				top_k=top_k,
				top_p=top_p,
				repetition_penalty=repetition_penalty,
				diversity_penalty=diversity_penalty,
			)
			
		# Decode the generated ids to get the captions.
		generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

		return generated_caption


class GitBaseCocoModel(ImageCaptionModel):
	def __init__(self, device):
		"""
		A wrapper class for the Git-Base-COCO model. It is a pretrained model for image captioning.

		-----
		Parameters:
		device: str
			The device to run the model on, either "cpu" or "cuda".
		checkpoint: str
			The checkpoint to load the model from.

		-----
		Returns:
		None
		"""
		checkpoint = "microsoft/git-base-coco"
		processor = AutoProcessor.from_pretrained(checkpoint)
		model = AutoModelForCausalLM.from_pretrained(checkpoint)
		super().__init__(device, processor, model)


class BlipBaseModel(ImageCaptionModel):
	def __init__(self, device):
		"""
		A wrapper class for the Blip-Base model. It is a pretrained model for image captioning.

		-----
		Parameters:
		device: str
			The device to run the model on, either "cpu" or "cuda".
		checkpoint: str
			The checkpoint to load the model from.

		-----
		Returns:
		None
		"""
		self.checkpoint = "Salesforce/blip-image-captioning-base"
		processor = AutoProcessor.from_pretrained(self.checkpoint)
		model = BlipForConditionalGeneration.from_pretrained(self.checkpoint)
		super().__init__(device, processor, model)
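

# Example usage: a minimal sketch, not part of the original module. It assumes the
# `torch` and `Pillow` packages are installed and that an image file named
# "example.jpg" exists in the working directory; the filename and generation
# settings below are illustrative placeholders.
if __name__ == "__main__":
	import torch
	from PIL import Image

	device = "cuda" if torch.cuda.is_available() else "cpu"
	model = BlipBaseModel(device)

	image = Image.open("example.jpg").convert("RGB")
	# Request several captions; diversity_penalty > 0.0 encourages the beam groups
	# to produce distinct outputs.
	captions = model.generate(image, num_captions=3, max_length=50, diversity_penalty=0.6)
	for caption in captions:
		print(caption)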