import argparse
import io
import os

import google.generativeai as genai
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

class Caption:
    def __init__(self):
        # The Gemini API key is read from the environment rather than
        # hard-coded in source.
        self.api_key = os.environ.get("GOOGLE_API_KEY")
        genai.configure(api_key=self.api_key)
        self.gemini_model = genai.GenerativeModel(model_name="gemini-pro-vision")

        # Local ViT-GPT2 captioning pipeline, used by predict_step,
        # predict_from_memory, and process_images.
        self.model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        self.feature_extractor = ViTImageProcessor.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning"
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.max_length = 16
        self.num_beams = 4
        self.gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}


    def predict_step(self, image_paths):
        # Load each image from disk, normalise to RGB, then caption the batch.
        images = []
        for image_path in image_paths:
            i_image = Image.open(image_path)
            if i_image.mode != "RGB":
                i_image = i_image.convert(mode="RGB")
            images.append(i_image)

        return self.process_images(images)
    
    def predict_from_memory(self, image_buffers):
        images = []

        for image_buffer in image_buffers:
            # Ensure the buffer is positioned at the start
            if isinstance(image_buffer, io.BytesIO):
                image_buffer.seek(0)
            try:
                i_image = Image.open(image_buffer)
                if i_image.mode != "RGB":
                    i_image = i_image.convert("RGB")
                images.append(i_image)
            except Exception as e:
                print(f"Failed to process image buffer: {e}")
                continue

        return self.process_images(images)
    
    def process_images(self, images):
        # Batch-encode the PIL images and generate captions with beam search.
        pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device)
        output_ids = self.model.generate(pixel_values, **self.gen_kwargs)
        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [pred.strip() for pred in preds]

    def predict_image_caption_gemini(self, img):
        # Stream a detailed caption for a single PIL image from Gemini.
        prompt = "Describe the main focus of this image in detail."
        response = self.gemini_model.generate_content([prompt, img], stream=True)
        response.resolve()
        print("Derived data:", response.text)
        return response.text
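
    # Illustrative usage sketch (an editor-added example; assumes the
    # GOOGLE_API_KEY environment variable is set):
    #   cap = Caption()
    #   print(cap.predict_image_caption_gemini(Image.open("farmer.jpg")))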

    def get_args(self):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-i",
            "--input_img_paths",
            type=str,
            default="farmer.jpg",
            help="image to caption",
        )
        args = parser.parse_args()
        return args
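

# A minimal usage sketch (an illustrative addition, not part of the CLI):
# captions an in-memory image with predict_from_memory, mirroring a caller
# that holds raw bytes, e.g. from an upload. The file name "farmer.jpg"
# simply reuses the sample default from get_args.
def demo_predict_from_memory():
    captioner = Caption()
    with open("farmer.jpg", "rb") as f:
        buffer = io.BytesIO(f.read())
    # predict_from_memory accepts a list of file-like objects.
    print(captioner.predict_from_memory([buffer]))
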
if __name__ == "__main__":
    model = Caption()
    args = model.get_args()
    image_paths = [args.input_img_paths]
    print(model.predict_step(image_paths))