chats-bug committed
Commit df766f8
1 Parent(s): 7d06c4c

Introduced fine-tuning for the models

Files changed (2):
  1. app.py +33 -8
  2. model.py +108 -19
app.py CHANGED
@@ -2,13 +2,23 @@ import gradio as gr
 import torch
 from PIL import Image
 
-from model import GitBaseCocoModel
+from model import GitBaseCocoModel, BlipBaseModel
+
+MODELS = {
+    "Git-Base-COCO": GitBaseCocoModel,
+    "Blip Base": BlipBaseModel,
+}
 
 def generate_captions(
     image,
-    max_len,
     num_captions,
+    max_length,
+    temperature,
+    top_k,
+    top_p,
+    repetition_penalty,
+    diversity_penalty,
+    model_name,
 ):
     """
     Generates captions for the given image.
@@ -28,14 +38,23 @@ def generate_captions(
     """
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    checkpoint = "microsoft/git-base-coco"
 
-    model = GitBaseCocoModel(device, checkpoint)
+    model = MODELS[model_name](device)
+
+    # Decoding options are passed as keywords so they reach the wrapper's **kwargs.
+    captions = model.generate(
+        image,
+        max_length=max_length,
+        num_captions=num_captions,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        diversity_penalty=diversity_penalty,
+    )
 
-    caption = model.generate(image, max_len, num_captions)
     # Convert list to a single string separated by newlines.
-    caption = "\n".join(caption)
-    return caption
+    captions = "\n".join(captions)
+    return captions
 
 title = "Git-Base-COCO Image Captioning"
 description = "A model for generating captions for images."
@@ -44,8 +63,14 @@ interface = gr.Interface(
     fn=generate_captions,
     inputs=[
         gr.inputs.Image(type="pil", label="Image"),
-        gr.inputs.Slider(minimum=20, maximum=100, step=5, default=50, label="Maximum Caption Length"),
         gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Captions to Generate"),
+        gr.inputs.Slider(minimum=20, maximum=100, step=5, default=50, label="Maximum Caption Length"),
+        gr.inputs.Slider(minimum=0.1, maximum=10.0, step=0.1, default=1.0, label="Temperature"),
+        gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top K"),
+        gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=1.0, label="Top P"),
+        gr.inputs.Slider(minimum=1.0, maximum=10.0, step=0.1, default=1.0, label="Repetition Penalty"),
+        gr.inputs.Slider(minimum=0.0, maximum=10.0, step=0.1, default=0.0, label="Diversity Penalty"),
+        gr.inputs.Dropdown(list(MODELS.keys()), label="Model"),
     ],
     outputs=[
         gr.outputs.Textbox(label="Caption"),
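
As a quick sanity check on the new signature, here is a minimal sketch of driving generate_captions outside the Gradio UI; the image path sample.jpg is hypothetical, and the values mirror the slider defaults above:

    from PIL import Image
    from app import generate_captions  # assumes app.py can be imported without launching the interface

    image = Image.open("sample.jpg").convert("RGB")  # any RGB PIL image works

    captions = generate_captions(
        image,
        num_captions=2,          # "Number of Captions to Generate" slider
        max_length=50,           # "Maximum Caption Length" slider
        temperature=1.0,         # "Temperature" slider
        top_k=50,                # "Top K" slider
        top_p=0.9,               # "Top P" slider
        repetition_penalty=1.0,  # "Repetition Penalty" slider
        diversity_penalty=0.0,   # "Diversity Penalty" slider
        model_name="Blip Base",  # key of the MODELS dict
    )
    print(captions)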
model.py CHANGED
@@ -1,7 +1,81 @@
-from transformers import AutoProcessor, AutoModelForCausalLM
+from transformers import AutoProcessor, AutoModelForCausalLM, BlipForConditionalGeneration
 
-class GitBaseCocoModel:
-    def __init__(self, device, checkpoint="microsoft/git-base-coco"):
+class ImageCaptionModel:
+    def __init__(
+        self,
+        device,
+        processor,
+        model,
+    ) -> None:
+        self.device = device
+        self.processor = processor
+        self.model = model
+        self.model.to(self.device)
+
+    def generate(
+        self,
+        image,
+        num_captions=1,
+        max_length=50,
+        num_beam_groups=1,
+        temperature=1.0,
+        top_k=50,
+        top_p=1.0,
+        repetition_penalty=1.0,
+        diversity_penalty=0.0,
+    ):
+        """
+        Generates captions for the given image.
+
+        -----
+        Parameters:
+        image: PIL.Image
+            The image to generate captions for.
+        num_captions: int
+            The number of captions to generate.
+        max_length: int
+            The maximum length of the generated caption.
+        num_beam_groups: int
+            The number of beam groups to use for beam search in order to maintain diversity. Must be between 1 and num_beams; 1 means no group beam search.
+        temperature: float
+            The value used to modulate the next-token probabilities. Must be strictly positive. Defaults to 1.0.
+        top_k: int
+            The number of highest-probability vocabulary tokens to keep for top-k filtering. Defaults to 50.
+        top_p: float
+            If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        repetition_penalty: float
+            The parameter for repetition penalty. 1.0 means no penalty. Defaults to 1.0.
+        diversity_penalty: float
+            The parameter for diversity penalty. 0.0 means no penalty. Defaults to 0.0.
+        """
+        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
+
+        requested_captions = num_captions
+        if diversity_penalty != 0.0:
+            # Diverse beam search needs more than one beam group and a
+            # beam count that is divisible by the number of groups.
+            num_beam_groups = 2
+            num_captions = num_captions if num_captions % 2 == 0 else num_captions + 1
+
+        generated_ids = self.model.generate(
+            pixel_values=pixel_values,
+            max_length=max_length,
+            num_beams=num_captions,
+            num_beam_groups=num_beam_groups,
+            num_return_sequences=num_captions,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            diversity_penalty=diversity_penalty,
+        )
+
+        generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
+
+        # Return only as many captions as were originally requested.
+        return generated_caption[:requested_captions]
+
+
+class GitBaseCocoModel(ImageCaptionModel):
+    def __init__(self, device):
         """
         A wrapper class for the Git-Base-COCO model. It is a pretrained model for image captioning.
 
@@ -16,12 +90,12 @@ class GitBaseCocoModel:
         Returns:
             None
         """
-        self.checkpoint = checkpoint
-        self.device = device
-        self.processor = AutoProcessor.from_pretrained(self.checkpoint)
-        self.model = AutoModelForCausalLM.from_pretrained(self.checkpoint).to(self.device)
+        checkpoint = "microsoft/git-base-coco"
+        processor = AutoProcessor.from_pretrained(checkpoint)
+        model = AutoModelForCausalLM.from_pretrained(checkpoint)
+        super().__init__(device, processor, model)
 
-    def generate(self, image, max_len=50, num_captions=1):
+    def generate(self, image, max_length=50, num_captions=1, **kwargs):
         """
         Generates captions for the given image.
 
@@ -34,14 +108,29 @@ class GitBaseCocoModel:
         num_captions: int
             The number of captions to generate.
         """
-        pixel_values = self.processor(
-            images=image, return_tensors="pt"
-        ).pixel_values.to(self.device)
-        generated_ids = self.model.generate(
-            pixel_values=pixel_values,
-            max_length=max_len,
-            num_beams=num_captions,
-            num_return_sequences=num_captions,
-        )
-        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)
-
+        # Keyword arguments avoid swapping num_captions and max_length,
+        # whose positions differ in the base-class signature.
+        captions = super().generate(image, num_captions=num_captions, max_length=max_length, **kwargs)
+        return captions
+
+
+class BlipBaseModel(ImageCaptionModel):
+    def __init__(self, device):
+        checkpoint = "Salesforce/blip-image-captioning-base"
+        processor = AutoProcessor.from_pretrained(checkpoint)
+        model = BlipForConditionalGeneration.from_pretrained(checkpoint)
+        super().__init__(device, processor, model)
+
+    def generate(self, image, max_length=50, num_captions=1, **kwargs):
+        """
+        Generates captions for the given image.
+
+        -----
+        Parameters:
+        image: PIL.Image
+            The image to generate captions for.
+        max_length: int
+            The maximum length of the caption.
+        num_captions: int
+            The number of captions to generate.
+        """
+        captions = super().generate(image, num_captions=num_captions, max_length=max_length, **kwargs)
+        return captions
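
For completeness, a minimal usage sketch of the new wrapper classes, assuming both checkpoints download successfully; note that a non-zero diversity_penalty switches the base generate() into diverse beam search, which is why it forces two beam groups and an even beam count:

    import torch
    from PIL import Image
    from model import GitBaseCocoModel, BlipBaseModel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    image = Image.open("sample.jpg").convert("RGB")  # hypothetical input image

    # Plain beam search: num_beams equals num_captions.
    git_model = GitBaseCocoModel(device)
    print(git_model.generate(image, max_length=50, num_captions=3))

    # Diverse beam search: generate() bumps the beam count to an even
    # value and uses num_beam_groups=2 so the penalty can take effect.
    blip_model = BlipBaseModel(device)
    print(blip_model.generate(image, num_captions=2, diversity_penalty=0.5))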