File size: 3,136 Bytes
98d86db
 
 
 
 
 
dae3891
98d86db
 
 
 
 
 
 
 
 
 
 
 
 
 
dae3891
98d86db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dae3891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98d86db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
from pathlib import Path
import os
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
import io

class Caption:
    """Image captioning via the nlpconnect/vit-gpt2-image-captioning model.

    Loads the VisionEncoderDecoder model, its image processor, and tokenizer
    once at construction, then generates short English captions for images
    supplied either as filesystem paths or as in-memory byte buffers.
    """

    # Single source of truth for the Hugging Face model id.
    _MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"

    def __init__(self):
        self.model = VisionEncoderDecoderModel.from_pretrained(self._MODEL_ID)
        self.feature_extractor = ViTImageProcessor.from_pretrained(self._MODEL_ID)
        self.tokenizer = AutoTokenizer.from_pretrained(self._MODEL_ID)

        # Run on GPU when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only; no training happens here

        # Beam-search generation settings shared by all predict methods.
        self.max_length = 16
        self.num_beams = 4
        self.gen_kwargs = {"max_length": self.max_length, "num_beams": self.num_beams}

    def predict_step(self, image_paths):
        """Caption images stored on disk.

        Args:
            image_paths: iterable of filesystem paths to image files.

        Returns:
            list[str]: one stripped caption per input image, in order.
        """
        images = []
        for image_path in image_paths:
            # `with` closes the underlying file handle (the original leaked it);
            # convert("RGB") always returns a new in-memory image, so it is
            # safe to use after the file is closed and it normalizes the mode.
            with Image.open(image_path) as i_image:
                images.append(i_image.convert("RGB"))

        # Delegate to the shared pipeline instead of duplicating it.
        return self.process_images(images)

    def predict_from_memory(self, image_buffers):
        """Caption images held in memory (e.g. io.BytesIO buffers).

        Buffers that fail to decode are skipped (best-effort, as before);
        a message is printed for each failure.

        Args:
            image_buffers: iterable of file-like objects containing image data.

        Returns:
            list[str]: one caption per successfully decoded image.
        """
        images = []
        for image_buffer in image_buffers:
            # Rewind BytesIO buffers so a previously-read buffer still decodes.
            if isinstance(image_buffer, io.BytesIO):
                image_buffer.seek(0)
            try:
                with Image.open(image_buffer) as i_image:
                    images.append(i_image.convert("RGB"))
            except Exception as e:
                # Deliberate best-effort: skip bad buffers rather than abort.
                print(f"Failed to process image buffer: {str(e)}")
                continue

        return self.process_images(images)

    def process_images(self, images):
        """Run the caption pipeline on already-loaded PIL images.

        Args:
            images: list of PIL.Image.Image objects in RGB mode.

        Returns:
            list[str]: stripped captions, one per image; [] for empty input
            (the feature extractor would otherwise raise on an empty batch).
        """
        if not images:
            return []

        pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device)

        # no_grad avoids building an autograd graph during pure inference.
        with torch.no_grad():
            output_ids = self.model.generate(pixel_values, **self.gen_kwargs)

        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [pred.strip() for pred in preds]

    def get_args(self):
        """Parse command-line arguments for the CLI entry point.

        Returns:
            argparse.Namespace with `input_img_paths` (str, a single image path).
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "-i",
            "--input_img_paths",
            type=str,
            default="farmer.jpg",
            help="img for caption",
        )
        return parser.parse_args()
	
if __name__ == "__main__":
    # CLI entry point: caption the single image given via -i/--input_img_paths.
    captioner = Caption()
    cli_args = captioner.get_args()
    captions = captioner.predict_step([cli_args.input_img_paths])
    print(captions)